「線形代数の基礎」をJavaで実装してみる2

線形代数の基礎」はこちらのページです。 https://tutorials.chainer.org/ja/05_Basics_of_Linear_Algebra.html

スカラ値の乗算

ベクトル

    public Vector multiply(float scalar) {
        float[] scalars = new float[this.scalars.length];
        for (int i = 0;i < this.scalars.length;i++) {
            scalars[i] = this.scalars[i] * scalar;
        }
        return new Vector(scalars);
    }

行列

    public Matrix multiply(float scalar) {
        float[][] o2scalars = new float[this.o2scalars.length][this.o2scalars[0].length];

        for (int i = 0;i < this.o2scalars.length;i++) {
            for (int j = 0;j < this.o2scalars[i].length;j++) {
                o2scalars[i][j] = this.o2scalars[i][j] * scalar;
            }
        }
        return new Matrix(o2scalars);
    }

各要素にスカラ値を掛けます。

Vector v1 = new Vector(new float[] {1, 2, 3});
Vector v2 = v1.multiply(10);
System.out.println(v2);

Matrix m1 = new Matrix(new float[][] {
        {1, 2, 3},
        {4, 5, 6},
});
Matrix m2 = m1.multiply(10);
System.out.println(m2);

実行結果

[10.0, 20.0, 30.0]
2 x 3
| 10.0| 20.0| 30.0|
| 40.0| 50.0| 60.0|

次はベクトルの内積です

    public float innerProduct(Vector object) {
        if (this.isVertical || !object.isVertical) {
            throw new RuntimeException("vertical error");
        }

        float sum = 0;
        for (int i = 0;i < scalars.length;i++) {
            sum += scalars[i] * object.scalars[i];
        }

        return sum;
    }

内積は横ベクトルと縦ベクトルの積の場合可能です。

        Vector v10 = new Vector(new float[] {1, 2, 3}, false);
        Vector v11 = new Vector(new float[] {4, 5, 6}, true);
        float result = v10.innerProduct(v11);
        System.out.println(result);

実行結果

32.0

こちらが行列積になります。

    public Matrix matrixMultiplication(Matrix object) {
        final float[][] our = o2scalars;
        float[][] newScalars = new float[our.length][our[0].length];
        for (int i = 0;i < our.length;i++) {
            for (int j = 0;j < our[0].length;j++) {

                float sum = 0;
                for (int k = 0;k < our.length;k++) {
                    sum += our[i][k] * object.o2scalars[k][j];
                }
                newScalars[i][j] = sum;
            }
        }

        return new Matrix(newScalars);
    }
        Matrix m10 = new Matrix(new float[][] {
                {1, 2},
                {3, 4},
        });
        Matrix m11 = new Matrix(new float[][] {
                {5, 6},
                {7, 8},
        });
        System.out.println(m10.matrixMultiplication(m11));

結果

2 x 2
| 19.0| 22.0|
| 43.0| 50.0|

積算の実装については以上です。

「線形代数の基礎」をJavaで実装してみる

線形代数の基礎」はこちらのページです。 https://tutorials.chainer.org/ja/05_Basics_of_Linear_Algebra.html

テンソル

public class Tensor {
    protected final int order;
    public Tensor(int order) {
        this.order = order;
    }
}
public class LinearAlgebraTest {
    public static void main(String[] args) {
        Tensor o1 = new Tensor(1); // 1階のテンソル
        Tensor o2 = new Tensor(2); // 2階のテンソル
    }
}

order は N階のテンソルを表します。 Java的にはN次元の配列ということになります。

ベクトル

ベクトルクラスを定義します。

public class Vector extends Tensor {
    private final float[] scalars;
    public Vector(float[] scalars) {
        super(1); // ベクトルは1階のテンソル
        this.scalars = scalars;
    }
}

ベクトルはテンソルを継承して、order は 1 固定です。

加算を実装します。

Vectorに以下を追加

    public Vector add(Vector object) {
        float[] scalars = new float[this.scalars.length];
        for (int i = 0;i < this.scalars.length;i++) {
            scalars[i] = this.scalars[i] + object.scalars[i];
        }
        return new Vector(scalars);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        if (scalars.length > 0) {
            sb.append(scalars[0]);
            for (int i = 1;i < scalars.length;i++) {
                sb.append(", ").append(scalars[i]);
            }
        }
        sb.append("]");
        return sb.toString();
    }

呼び出し

        Vector v1 = new Vector(new float[]{1, 2, 3});
        Vector v2 = new Vector(new float[]{4, 5, 6});

        System.out.println(v1);
        System.out.println(v2);

        Vector v3 = v1.add(v2);
        System.out.println(v3);

実行結果

[1.0, 2.0, 3.0]
[4.0, 5.0, 6.0]
[5.0, 7.0, 9.0]

行列

こんな感じの実装にしてみます。

public class Matrix extends Tensor {
    private final float[][] o2scalars;
    public Matrix(float[][] o2scalars) {
        super(2); // 行列は2階のテンソル
        this.o2scalars = o2scalars;
    }
}

加算を実装してみます。

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(o2scalars.length).append(" x ").append(o2scalars[0].length).append("\n");
        for (int i = 0;i < o2scalars.length;i++) {
            sb.append("|");
            for (int j = 0;j < o2scalars[i].length;j++) {
                sb.append(String.format("% 5.1f|", o2scalars[i][j]));
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    public Matrix add(Matrix object) {
        float[][] o2scalars = new float[this.o2scalars.length][this.o2scalars[0].length];

        for (int i = 0;i < this.o2scalars.length;i++) {
            for (int j = 0;j < this.o2scalars[i].length;j++) {
                o2scalars[i][j] = this.o2scalars[i][j] + object.o2scalars[i][j];
            }
        }
        return new Matrix(o2scalars);
    }
        Matrix m1 = new Matrix(new float[][] {
                {1,2,3},
                {4,5,6},
        });
        Matrix m2 = new Matrix(new float[][] {
                {7,8,9},
                {10,11,12},
        });
        System.out.println(m1);
        System.out.println(m2);

        Matrix m3 = m1.add(m2);
        System.out.println(m3);

実行結果です。

2 x 3
|  1.0|  2.0|  3.0|
|  4.0|  5.0|  6.0|

2 x 3
|  7.0|  8.0|  9.0|
| 10.0| 11.0| 12.0|

2 x 3
|  8.0| 10.0| 12.0|
| 14.0| 16.0| 18.0|

次回は行列の積を実装してみたいと思います。

近似計算の比較

計算をするときいくつかの処理を高速化のために近似値計算で済ます方法があります。 こちらのブログで計算方法が紹介されているので拝借します。 https://martin.ankerl.com/2007/10/04/optimized-pow-approximation-for-java-and-c-c/

package math;

import org.apache.commons.math3.util.FastMath;

public class PowTest1 {
    public static void main(String[] args) {
        for (int i = 0;i < 3;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += Math.pow(j, 2);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "Math.pow", sum, (System.currentTimeMillis() - start));
        }
        for (int i = 0;i < 3;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += FastMath.pow(j, 2);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "FastMath.pow", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 3;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += pow1(j, 2);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "pow1", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 3;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += pow2(j, 2);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "pow2", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 3;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += pow3(j, 2);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "pow3", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 3;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += pow4(j, 2);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "pow4", sum, (System.currentTimeMillis() - start));
        }
    }

    public static double pow1(final double a, final double b) {
        final int x = (int) (Double.doubleToLongBits(a) >> 32);
        final int y = (int) (b * (x - 1072632447) + 1072632447);
        return Double.longBitsToDouble(((long) y) << 32);
    }

    public static double pow2(final double a, final double b) {
        final long tmp = Double.doubleToLongBits(a);
        final long tmp2 = (long)(b * (tmp - 4606921280493453312L)) + 4606921280493453312L;
        return Double.longBitsToDouble(tmp2);
    }

    public static double pow3(final double a, final double b) {
        final double x = (Double.doubleToLongBits(a) >> 32);
        final long tmp2 = (long) (1512775 * (x - 1072632447) / 1512775 * b + (1072693248 - 60801));
        return Double.longBitsToDouble(tmp2 << 32);
    }
    public static double pow4(final double a, final double b) {
        final int tmp = (int) (Double.doubleToLongBits(a) >> 32);
        final int tmp2 = (int) (b * (tmp - 1072632447) + 1072632447);
        return Double.longBitsToDouble(((long) tmp2) << 32);
    }
}

macOS Oracle JDK 12.0.2で実行しました。

Math.pow      sum= 333332833333127550 7ms
Math.pow      sum= 333332833333127550 5ms
Math.pow      sum= 333332833333127550 4ms
FastMath.pow  sum= 333332833333127550 165ms
FastMath.pow  sum= 333332833333127550 148ms
FastMath.pow  sum= 333332833333127550 144ms
pow1          sum= 332595653188566270 10ms
pow1          sum= 332595653188566270 9ms
pow1          sum= 332595653188566270 9ms
pow2          sum= 332595653188566270 9ms
pow2          sum= 332595653188566270 7ms
pow2          sum= 332595653188566270 8ms
pow3          sum= 332595653188566270 16ms
pow3          sum= 332595653188566270 16ms
pow3          sum= 332595653188566270 15ms
pow4          sum= 332595653188566270 10ms
pow4          sum= 332595653188566270 10ms
pow4          sum= 332595653188566270 10ms

Math.powが最速でした。FastMath.powはかなり遅いです。 pow1 ~ pow4は計算結果は変わらないですが、速度に差が出ています。

次はexpです。

package math;

import org.apache.commons.math3.util.FastMath;

public class ExpTest1 {
    public static void main(String[] args) {
        for (int i = 0;i < 5;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += Math.exp(j * 0.0001f);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "Math.exp", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 5;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += FastMath.exp(j * 0.0001f);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "FastMath.exp", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 5;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += exp1(j * 0.0001f);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "exp1", sum, (System.currentTimeMillis() - start));
        }
    }
    public static double exp1(double val) {
        final long tmp = (long) (1512775 * val + (1072693248 - 60801));
        return Double.longBitsToDouble(tmp << 32);
    }
}

結果です。

Math.exp      sum= 268797601492947300000000000000000000000000000000 15ms
Math.exp      sum= 268797601492947300000000000000000000000000000000 24ms
Math.exp      sum= 268797601492947300000000000000000000000000000000 23ms
Math.exp      sum= 268797601492947300000000000000000000000000000000 23ms
Math.exp      sum= 268797601492947300000000000000000000000000000000 23ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 45ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 30ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 30ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 31ms
FastMath.exp  sum= 268797601492947300000000000000000000000000000000 31ms
exp1          sum= 267998235537123060000000000000000000000000000000 9ms
exp1          sum= 267998235537123060000000000000000000000000000000 9ms
exp1          sum= 267998235537123060000000000000000000000000000000 8ms
exp1          sum= 267998235537123060000000000000000000000000000000 8ms
exp1          sum= 267998235537123060000000000000000000000000000000 9ms

今回はexp1が最速でした。FastMath.expの結果は悪く無いですが、一番遅かったです。

最後はsqrtです

package math;

import org.apache.commons.math3.util.FastMath;

public class SqrtTest1 {
    public static void main(String[] args) {
        for (int i = 0;i < 4;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += Math.sqrt(j);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "Math.exp", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 4;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += FastMath.sqrt(j);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "FastMath.sqrt", sum, (System.currentTimeMillis() - start));
        }

        for (int i = 0;i < 4;i++) {
            long start;
            double sum = 0;
            start = System.currentTimeMillis();
            for (int j = 0;j < 1000000;j++) {
                sum += sqrt1(j);
            }
            System.out.format("%-13s sum= %8.0f %dms\n", "sqrt1", sum, (System.currentTimeMillis() - start));
        }
    }
    public static double sqrt1(final double a) {
        final long x = Double.doubleToLongBits(a) >> 32;
        double y = Double.longBitsToDouble((x + 1072632448) << 31);
        return y;
    }
}

結果です。

Math.exp      sum= 666666166 7ms
Math.exp      sum= 666666166 9ms
Math.exp      sum= 666666166 5ms
Math.exp      sum= 666666166 3ms
FastMath.sqrt sum= 666666166 10ms
FastMath.sqrt sum= 666666166 8ms
FastMath.sqrt sum= 666666166 7ms
FastMath.sqrt sum= 666666166 4ms
sqrt1         sum= 666888545 4ms
sqrt1         sum= 666888545 4ms
sqrt1         sum= 666888545 4ms
sqrt1         sum= 666888545 4ms

最速はMath.expでした。 アルゴリズム的には似通ったものになるのか差はほとんどありませんでした。

GraalVMのnative-imageを試す

https://www.graalvm.org/docs/getting-started/

こちらの手順に沿ってダウンロードして、エイリアスの設定

alias java8=~/graalvm-ce-19.2.0.1/Contents/Home/bin/java
alias javac8=~/graalvm-ce-19.2.0.1/Contents/Home/bin/javac

以下のソースで実験

// VectorTest3.java
public class VectorTest3 {
    private final static int NUM = 100 * 1000 * 1000;

    private static float dotProduct(float[] vec_a, float[] vec_b) {
        float sum = 0;
        for (int i = 0; i < vec_a.length; i++) {
            sum += vec_a[i] * vec_b[i];
        }
        return sum;
    }

    private static void bench1() {
        float[] vec_a = new float[NUM];
        float[] vec_b = new float[NUM];
        for (int i = 0;i < NUM;i++) {
            vec_a[i] = (float)Math.random();
            vec_b[i] = (float)Math.random();
        }
        long start = System.currentTimeMillis();
        float sum = dotProduct(vec_a, vec_b);
        System.out.format("bench1 - %d ms\n", (System.currentTimeMillis() - start));
    }
    public static void main(String[] args) {
        for(int i = 0;i < 10;i++) {
            bench1();
        }
    }
}

前回のpanamaビルドで実行

$ javac14 src/main/java/VectorTest3.java \
                                          -d out/
$ java14 -cp out/ VectorTest3
bench1 - 137 ms
bench1 - 124 ms
bench1 - 130 ms
bench1 - 112 ms
bench1 - 121 ms
bench1 - 130 ms
bench1 - 120 ms
bench1 - 120 ms
bench1 - 122 ms
bench1 - 117 ms

最速で117msです。 GraalVMの方では

$ javac8 src/main/java/VectorTest3.java \
                                          -d out/
$ java8 -cp out/ VectorTest3
bench1 - 131 ms
bench1 - #
# A fatal error has been detected by the Java Runtime Environment:
#
#  Internal Error (deoptimization.cpp:808), pid=62160, tid=0x0000000000001603
#  fatal error: java/lang/Long$LongCache must be initialized

クラッシュ...

native-imageをします。

# インストール
$ ~/graalvm-ce-19.2.0.1/Contents/Home/bin/gu install native-image

$  ~/graalvm-ce-19.2.0.1/Contents/Home/bin/native-image -cp out/ VectorTest3
Build on Server(pid: 44023, port: 50514)
[vectortest3:44023]    classlist:     143.75 ms
[vectortest3:44023]        (cap):   2,228.31 ms
[vectortest3:44023]        setup:   3,358.35 ms
[vectortest3:44023]   (typeflow):   2,793.14 ms
[vectortest3:44023]    (objects):   2,172.78 ms
[vectortest3:44023]   (features):     170.70 ms
[vectortest3:44023]     analysis:   5,226.55 ms
[vectortest3:44023]     (clinit):      96.69 ms
[vectortest3:44023]     universe:     324.64 ms
[vectortest3:44023]      (parse):     432.65 ms
[vectortest3:44023]     (inline):     969.11 ms
[vectortest3:44023]    (compile):   4,320.70 ms
[vectortest3:44023]      compile:   6,017.41 ms
[vectortest3:44023]        image:     490.40 ms
[vectortest3:44023]        write:     184.50 ms
[vectortest3:44023]      [total]:  15,869.97 ms

$ ls -la vectortest3
-rwxr-xr-x  1 tak  staff  4496880 Oct  5 11:00 vectortest3

約4MBの実行ファイルが生成されました。 実行します。

$ ./vectortest3
bench1 - 127 ms
bench1 - 130 ms
bench1 - 125 ms
bench1 - 126 ms
bench1 - 126 ms
bench1 - 127 ms
bench1 - 136 ms
bench1 - 126 ms
bench1 - 129 ms
bench1 - 129 ms

なんか遅くなっている気がします。

$ otool -L vectortest3
vectortest3:
    /usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1281.0.0)
    /System/Library/Frameworks/CoreFoundation.framework/Versions/A/CoreFoundation (compatibility version 150.0.0, current version 1670.10.0)
    /usr/lib/libz.1.dylib (compatibility version 1.0.0, current version 1.2.11)

静的リンクはlibzが必要そうです。

objdump -d vectortest3  | grep vfmadd

vfmaddでgrepしましたが特に使われていないのでベクトルができていない模様

Java Vector API を試す

https://nowokay.hatenablog.com/entry/2019/09/05/015537

こちらの記事を参考にビルドします。

CのコードとJavaのコードを比較していきます。

# vector.h
float dotProduct512(float* vec1, float* vec2, int num);
float dotProduct256(float* vec1, float* vec2, int num);
float dotProduct(float* vec1, float* vec2, int num);
#vector.c
#include "vector.h"
#include <immintrin.h>

float dotProduct512(float* vec1, float* vec2, int num)
{
    __m512 avx_sum = _mm512_setzero_ps();
    int limit = num - num % 16;
    for (int i = 0;i < limit;i += 16) {
        const __m512 a512 = _mm512_loadu_ps((float*)&vec1[i]);
        const __m512 b512 = _mm512_loadu_ps((float*)&vec2[i]);
        avx_sum = _mm512_fmadd_ps(a512, b512, avx_sum);
    }

    float __attribute__((aligned(32))) out[16] = {};
    _mm512_storeu_ps(out, avx_sum);
    float sum = 0;
    for (int i = 0;i < 16;i++) {
        sum += out[i];
    }
    for (int i = limit;i < num;i++) {
        sum += vec1[i] * vec2[i];
    }
    return sum;
}

float dotProduct256(float* vec1, float* vec2, int num)
{
    __m256 avx_sum = _mm256_setzero_ps();
    int limit = num - num % 8;
    for (int i = 0;i < limit;i += 8) {
        const __m256 a256 = _mm256_loadu_ps((float*)&vec1[i]);
        const __m256 b256 = _mm256_loadu_ps((float*)&vec2[i]);
        avx_sum = _mm256_fmadd_ps(a256, b256, avx_sum);
    }

    float __attribute__((aligned(32))) out[16] = {};
    _mm256_store_ps(out, avx_sum);
    float sum = 0;
    for (int i = 0;i < 8;i++) {
        sum += out[i];
    }
    for (int i = limit;i < num;i++) {
        sum += vec1[i] * vec2[i];
    }
    return sum;
}

float dotProduct(float* vec1, float* vec2, int num)
{
    float sum = 0;
    for (int i = 0;i < num;i++) {
        sum += vec1[i] * vec2[i];
    }
    return sum;
}
# vec_test.c

#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include "vector.h"

#define NUM 100 * 1000 * 1000

unsigned long getMicroSec()
{
    struct timespec time1;
    clock_gettime(CLOCK_REALTIME,&time1);

    unsigned long micros = time1.tv_sec * 1000000;
    micros += time1.tv_nsec / 1000;
    return micros;
}

void bench1()
{

    float *vec_a, *vec_b;
    vec_a = (float*)malloc(sizeof(float) * NUM);
    vec_b = (float*)malloc(sizeof(float) * NUM);

    for(int i = 0;i < NUM;i++) {
        vec_a[i] = ((float)rand()) / NUM;
        vec_b[i] = ((float)rand()) / NUM;
    }

    unsigned long start = getMicroSec();
    float sum = dotProduct(vec_a, vec_b, NUM);
    printf("bench1 - %lu ms\n", (getMicroSec() - start) / 1000);

    free(vec_a);
    free(vec_b);
}

void bench2()
{

    float *vec_a, *vec_b;
    vec_a = (float*)malloc(sizeof(float) * NUM);
    vec_b = (float*)malloc(sizeof(float) * NUM);

    for(int i = 0;i < NUM;i++) {
        vec_a[i] = ((float)rand()) / NUM;
        vec_b[i] = ((float)rand()) / NUM;
    }
    unsigned long start = getMicroSec();

    float sum = dotProduct256(vec_a, vec_b, NUM);
    printf("bench2 - %lu ms\n", (getMicroSec() - start) / 1000);

    free(vec_a);
    free(vec_b);
}

int main(void)
{
    bench1();
    bench2();
    return 0;
}

以下のコマンドでビルドします。

gcc -O2 -mavx512f vec_test.c  -o vec_test vector
gcc -O2 vec_test.c  -o vec_test vector

実行結果です。

 $ ./vec_test
bench1 - 135 ms
bench2 - 52 ms

FMA命令を使用した方が早くなります。 次にJavaです。

// VectorTest2.java
import jdk.incubator.vector.*;

public class VectorTest2 {
    private final static int NUM = 100 * 1000 * 1000;

    private static float dotProduct256(float[] vec_a, float[] vec_b) {
        var SP = FloatVector.SPECIES_256;

        int limit = vec_a.length - vec_a.length % 8;
        var fv_sum = FloatVector.fromValues(SP, 0, 0, 0, 0, 0, 0, 0, 0);
        for (int i = 0; i < limit; i += 8) {
            var fv_a = FloatVector.fromArray(SP, vec_a, i);
            var fv_b = FloatVector.fromArray(SP, vec_b, i);
            fv_sum = fv_a.fma(fv_b, fv_sum);
        }

        float[] outArray = new float[8];
        fv_sum.intoArray(outArray, 0);

        float sum = 0;
        for (float f: outArray) {
            sum += f;
        }

        for (int i = limit; i < vec_a.length; i += 1) {
            sum += vec_a[i] * vec_b[i];
        }
        return sum;
    }
    private static float dotProduct(float[] vec_a, float[] vec_b) {
        float sum = 0;
        for (int i = 0; i < vec_a.length; i++) {
            sum += vec_a[i] * vec_b[i];
        }
        return sum;
    }

    private static void bench1() {
        float[] vec_a = new float[NUM];
        float[] vec_b = new float[NUM];
        for (int i = 0;i < NUM;i++) {
            vec_a[i] = (float)Math.random();
            vec_b[i] = (float)Math.random();
        }
        long start = System.currentTimeMillis();
        float sum = dotProduct(vec_a, vec_b);
        System.out.format("bench1 - %d ms\n", (System.currentTimeMillis() - start));
    }
    private static void bench2() {
        float[] vec_a = new float[NUM];
        float[] vec_b = new float[NUM];
        for (int i = 0;i < NUM;i++) {
            vec_a[i] = (float)Math.random();
            vec_b[i] = (float)Math.random();
        }
        long start = System.currentTimeMillis();
        float sum = dotProduct256(vec_a, vec_b);
        System.out.format("bench2 - %d ms\n", (System.currentTimeMillis() - start));
    }
    public static void main(String[] args) {
        for(int i = 0;i < 20;i++) {
            bench1();
        }
        for(int i = 0;i < 20;i++) {
            bench2();
        }
    }
}

ビルドして実行します。

$ javac14 src/main/java/VectorTest2.java \
                    --add-modules jdk.incubator.vector \
                    -d out/
$ java14 -cp out/ VectorTest2
bench1 - 114 ms
bench1 - 127 ms
bench1 - 116 ms
bench1 - 128 ms
bench1 - 108 ms
bench1 - 140 ms
bench1 - 117 ms
bench1 - 114 ms
bench1 - 112 ms
bench1 - 109 ms
bench1 - 144 ms
bench1 - 120 ms
bench1 - 120 ms
bench1 - 110 ms
bench1 - 117 ms
bench1 - 107 ms
bench1 - 107 ms
bench1 - 110 ms
bench1 - 107 ms
bench1 - 109 ms
bench2 - 447 ms
bench2 - 125 ms
bench2 - 163 ms
bench2 - 110 ms
bench2 - 117 ms
bench2 - 117 ms
bench2 - 117 ms
bench2 - 116 ms
bench2 - 123 ms
bench2 - 116 ms
bench2 - 129 ms
bench2 - 116 ms
bench2 - 116 ms
bench2 - 116 ms
bench2 - 117 ms
bench2 - 115 ms
bench2 - 117 ms
bench2 - 116 ms
bench2 - 118 ms
bench2 - 116 ms

VectorAPIを使用しない場合107msが最速ですが、使用した場合116msが最速です。 なぜ使用すると遅くなるのかは不明ですが、まだまだ最適化が必要なのでしょう。


Intel公式資料

Java Doc

AVX512をEC2で試す

今回はEC2のc5.largeインスタンスを使います。

CPU情報は

$ cat /proc/cpuinfo
processor   : 0
vendor_id   : GenuineIntel
cpu family  : 6
model       : 85
model name  : Intel(R) Xeon(R) Platinum 8124M CPU @ 3.00GHz
stepping    : 4
microcode   : 0x200005e
cpu MHz     : 3408.548
cache size  : 25344 KB
physical id : 0
siblings    : 2
core id     : 0
cpu cores   : 1
apicid      : 0
initial apicid  : 0
fpu     : yes
fpu_exception   : yes
cpuid level : 13
wp      : yes
flags       : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
bugs        : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds
bogomips    : 5999.99
clflush size    : 64
cache_alignment : 64
address sizes   : 46 bits physical, 48 bits virtual
power management:

processor   : 1
vendor_id   : GenuineIntel
cpu family  : 6
model       : 85
model name  : Intel(R) Xeon(R) Platinum 8124M CPU @ 3.00GHz
stepping    : 4
microcode   : 0x200005e
cpu MHz     : 3411.384
cache size  : 25344 KB
physical id : 0
siblings    : 2
core id     : 0
cpu cores   : 1
apicid      : 1
initial apicid  : 1
fpu     : yes
fpu_exception   : yes
cpuid level : 13
wp      : yes
flags       : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
bugs        : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds
bogomips    : 5999.99
clflush size    : 64
cache_alignment : 64
address sizes   : 46 bits physical, 48 bits virtual
power management:

c5の上位インスタンスではIntel DL Boostに対応していますが、下位インスタンスでは対応していません。

f:id:taku-woohar:20190922175408p:plain
Intel DL Boost対応表
2019/7/24のIntel AIのブログです。

実行のコードです。

vector.h

float sumProduct(float* vec1, float* vec2, int num);

vector.c

#include "vector.h"
#include <immintrin.h>

float sumProduct(float* vec1, float* vec2, int num)
{
    __m512 avx_sum = _mm512_setzero_ps();
    for (int i = 0;i < num;i += 16) {
        const __m512 a512 = _mm512_loadu_ps((double*)&vec1[i]);
        const __m512 b512 = _mm512_loadu_ps((double*)&vec2[i]);
        avx_sum = _mm512_fmadd_ps(a512, b512, avx_sum);
    }

    float __attribute__((aligned(32))) out[16] = {};
    _mm512_storeu_ps(out, avx_sum);
    float sum = 0;
    for (int i = 0;i < 16;i++) {
        sum += out[i];
    }
    return sum;
}

test.c

#include<stdio.h>
#include "vector.h"

int main(void)
{
    float a[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
    float b[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
    printf("%f\n", sumProduct(a, b, 16));
    return 0;
}

以下のコマンドでビルドします。

gcc -O2 -mavx512f vector.c -c  -o vector
gcc -O0 -mavx512f test.c -o test vector
./test
408.000000

参考サイト

https://colfaxresearch.com/knl-avx512/