「線形代数の基礎」をJavaで実装してみる2
「線形代数の基礎」はこちらのページです。 https://tutorials.chainer.org/ja/05_Basics_of_Linear_Algebra.html
スカラ値の乗算
ベクトル
public Vector multiply(float scalar) { float[] scalars = new float[this.scalars.length]; for (int i = 0;i < this.scalars.length;i++) { scalars[i] = this.scalars[i] * scalar; } return new Vector(scalars); }
行列
public Matrix multiply(float scalar) { float[][] o2scalars = new float[this.o2scalars.length][this.o2scalars[0].length]; for (int i = 0;i < this.o2scalars.length;i++) { for (int j = 0;j < this.o2scalars[i].length;j++) { o2scalars[i][j] = this.o2scalars[i][j] * scalar; } } return new Matrix(o2scalars); }
各要素にスカラ値を掛けます。
Vector v1 = new Vector(new float[] {1, 2, 3}); Vector v2 = v1.multiply(10); System.out.println(v2); Matrix m1 = new Matrix(new float[][] { {1, 2, 3}, {4, 5, 6}, }); Matrix m2 = m1.multiply(10); System.out.println(m2);
実行結果
[10.0, 20.0, 30.0] 2 x 3 | 10.0| 20.0| 30.0| | 40.0| 50.0| 60.0|
次はベクトルの内積です
public float innerProduct(Vector object) { if (this.isVertical || !object.isVertical) { throw new RuntimeException("vertical error"); } float sum = 0; for (int i = 0;i < scalars.length;i++) { sum += scalars[i] * object.scalars[i]; } return sum; }
内積は横ベクトルと縦ベクトルの積の場合可能です。
Vector v10 = new Vector(new float[] {1, 2, 3}, false); Vector v11 = new Vector(new float[] {4, 5, 6}, true); float result = v10.innerProduct(v11); System.out.println(result);
実行結果
32.0
こちらが行列積になります。
public Matrix matrixMultiplication(Matrix object) { final float[][] our = o2scalars; float[][] newScalars = new float[our.length][our[0].length]; for (int i = 0;i < our.length;i++) { for (int j = 0;j < our[0].length;j++) { float sum = 0; for (int k = 0;k < our.length;k++) { sum += our[i][k] * object.o2scalars[k][j]; } newScalars[i][j] = sum; } } return new Matrix(newScalars); }
Matrix m10 = new Matrix(new float[][] { {1, 2}, {3, 4}, }); Matrix m11 = new Matrix(new float[][] { {5, 6}, {7, 8}, }); System.out.println(m10.matrixMultiplication(m11));
結果
2 x 2 | 19.0| 22.0| | 43.0| 50.0|
積算の実装については以上です。
「線形代数の基礎」をJavaで実装してみる
「線形代数の基礎」はこちらのページです。 https://tutorials.chainer.org/ja/05_Basics_of_Linear_Algebra.html
テンソル
public class Tensor { protected final int order; public Tensor(int order) { this.order = order; } }
public class LinearAlgebraTest { public static void main(String[] args) { Tensor o1 = new Tensor(1); // 1階のテンソル Tensor o2 = new Tensor(2); // 2階のテンソル } }
order は N階のテンソルを表します。 Java的にはN次元の配列ということになります。
ベクトル
ベクトルクラスを定義します。
public class Vector extends Tensor { private final float[] scalars; public Vector(float[] scalars) { super(1); // ベクトルは1階のテンソル this.scalars = scalars; } }
ベクトルはテンソルを継承して、order は 1
固定です。
加算を実装します。
Vectorに以下を追加
public Vector add(Vector object) { float[] scalars = new float[this.scalars.length]; for (int i = 0;i < this.scalars.length;i++) { scalars[i] = this.scalars[i] + object.scalars[i]; } return new Vector(scalars); } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append("["); if (scalars.length > 0) { sb.append(scalars[0]); for (int i = 1;i < scalars.length;i++) { sb.append(", ").append(scalars[i]); } } sb.append("]"); return sb.toString(); }
呼び出し
Vector v1 = new Vector(new float[]{1, 2, 3}); Vector v2 = new Vector(new float[]{4, 5, 6}); System.out.println(v1); System.out.println(v2); Vector v3 = v1.add(v2); System.out.println(v3);
実行結果
[1.0, 2.0, 3.0] [4.0, 5.0, 6.0] [5.0, 7.0, 9.0]
行列
こんな感じの実装にしてみます。
public class Matrix extends Tensor { private final float[][] o2scalars; public Matrix(float[][] o2scalars) { super(2); // 行列は2階のテンソル this.o2scalars = o2scalars; } }
加算を実装してみます。
@Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append(o2scalars.length).append(" x ").append(o2scalars[0].length).append("\n"); for (int i = 0;i < o2scalars.length;i++) { sb.append("|"); for (int j = 0;j < o2scalars[i].length;j++) { sb.append(String.format("% 5.1f|", o2scalars[i][j])); } sb.append("\n"); } return sb.toString(); } public Matrix add(Matrix object) { float[][] o2scalars = new float[this.o2scalars.length][this.o2scalars[0].length]; for (int i = 0;i < this.o2scalars.length;i++) { for (int j = 0;j < this.o2scalars[i].length;j++) { o2scalars[i][j] = this.o2scalars[i][j] + object.o2scalars[i][j]; } } return new Matrix(o2scalars); }
Matrix m1 = new Matrix(new float[][] { {1,2,3}, {4,5,6}, }); Matrix m2 = new Matrix(new float[][] { {7,8,9}, {10,11,12}, }); System.out.println(m1); System.out.println(m2); Matrix m3 = m1.add(m2); System.out.println(m3);
実行結果です。
2 x 3 | 1.0| 2.0| 3.0| | 4.0| 5.0| 6.0| 2 x 3 | 7.0| 8.0| 9.0| | 10.0| 11.0| 12.0| 2 x 3 | 8.0| 10.0| 12.0| | 14.0| 16.0| 18.0|
次回は行列の積を実装してみたいと思います。
近似計算の比較
計算をするときいくつかの処理を高速化のために近似値計算で済ます方法があります。 こちらのブログで計算方法が紹介されているので拝借します。 https://martin.ankerl.com/2007/10/04/optimized-pow-approximation-for-java-and-c-c/
package math; import org.apache.commons.math3.util.FastMath; public class PowTest1 { public static void main(String[] args) { for (int i = 0;i < 3;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += Math.pow(j, 2); } System.out.format("%-13s sum= %8.0f %dms\n", "Math.pow", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 3;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += FastMath.pow(j, 2); } System.out.format("%-13s sum= %8.0f %dms\n", "FastMath.pow", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 3;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += pow1(j, 2); } System.out.format("%-13s sum= %8.0f %dms\n", "pow1", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 3;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += pow2(j, 2); } System.out.format("%-13s sum= %8.0f %dms\n", "pow2", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 3;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += pow3(j, 2); } System.out.format("%-13s sum= %8.0f %dms\n", "pow3", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 3;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += pow4(j, 2); } System.out.format("%-13s sum= %8.0f %dms\n", "pow4", sum, (System.currentTimeMillis() - start)); } } public static double pow1(final double a, final double b) { final int x = (int) (Double.doubleToLongBits(a) >> 32); final int y = (int) (b * (x - 1072632447) + 1072632447); return Double.longBitsToDouble(((long) y) << 32); } public static double pow2(final double a, final double b) { final long tmp = Double.doubleToLongBits(a); final long tmp2 = (long)(b * (tmp - 4606921280493453312L)) + 4606921280493453312L; return Double.longBitsToDouble(tmp2); } public static double pow3(final double a, final double b) { final double x = (Double.doubleToLongBits(a) >> 32); final long tmp2 = (long) (1512775 * (x - 1072632447) / 1512775 * b + (1072693248 - 60801)); return Double.longBitsToDouble(tmp2 << 32); } public static double pow4(final double a, final double b) { final int tmp = (int) (Double.doubleToLongBits(a) >> 32); final int tmp2 = (int) (b * (tmp - 1072632447) + 1072632447); return Double.longBitsToDouble(((long) tmp2) << 32); } }
macOS Oracle JDK 12.0.2で実行しました。
Math.pow sum= 333332833333127550 7ms Math.pow sum= 333332833333127550 5ms Math.pow sum= 333332833333127550 4ms FastMath.pow sum= 333332833333127550 165ms FastMath.pow sum= 333332833333127550 148ms FastMath.pow sum= 333332833333127550 144ms pow1 sum= 332595653188566270 10ms pow1 sum= 332595653188566270 9ms pow1 sum= 332595653188566270 9ms pow2 sum= 332595653188566270 9ms pow2 sum= 332595653188566270 7ms pow2 sum= 332595653188566270 8ms pow3 sum= 332595653188566270 16ms pow3 sum= 332595653188566270 16ms pow3 sum= 332595653188566270 15ms pow4 sum= 332595653188566270 10ms pow4 sum= 332595653188566270 10ms pow4 sum= 332595653188566270 10ms
Math.powが最速でした。FastMath.powはかなり遅いです。 pow1 ~ pow4は計算結果は変わらないですが、速度に差が出ています。
次はexpです。
package math; import org.apache.commons.math3.util.FastMath; public class ExpTest1 { public static void main(String[] args) { for (int i = 0;i < 5;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += Math.exp(j * 0.0001f); } System.out.format("%-13s sum= %8.0f %dms\n", "Math.exp", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 5;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += FastMath.exp(j * 0.0001f); } System.out.format("%-13s sum= %8.0f %dms\n", "FastMath.exp", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 5;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += exp1(j * 0.0001f); } System.out.format("%-13s sum= %8.0f %dms\n", "exp1", sum, (System.currentTimeMillis() - start)); } } public static double exp1(double val) { final long tmp = (long) (1512775 * val + (1072693248 - 60801)); return Double.longBitsToDouble(tmp << 32); } }
結果です。
Math.exp sum= 268797601492947300000000000000000000000000000000 15ms Math.exp sum= 268797601492947300000000000000000000000000000000 24ms Math.exp sum= 268797601492947300000000000000000000000000000000 23ms Math.exp sum= 268797601492947300000000000000000000000000000000 23ms Math.exp sum= 268797601492947300000000000000000000000000000000 23ms FastMath.exp sum= 268797601492947300000000000000000000000000000000 45ms FastMath.exp sum= 268797601492947300000000000000000000000000000000 30ms FastMath.exp sum= 268797601492947300000000000000000000000000000000 30ms FastMath.exp sum= 268797601492947300000000000000000000000000000000 31ms FastMath.exp sum= 268797601492947300000000000000000000000000000000 31ms exp1 sum= 267998235537123060000000000000000000000000000000 9ms exp1 sum= 267998235537123060000000000000000000000000000000 9ms exp1 sum= 267998235537123060000000000000000000000000000000 8ms exp1 sum= 267998235537123060000000000000000000000000000000 8ms exp1 sum= 267998235537123060000000000000000000000000000000 9ms
今回はexp1が最速でした。FastMath.expの結果は悪く無いですが、一番遅かったです。
最後はsqrtです
package math; import org.apache.commons.math3.util.FastMath; public class SqrtTest1 { public static void main(String[] args) { for (int i = 0;i < 4;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += Math.sqrt(j); } System.out.format("%-13s sum= %8.0f %dms\n", "Math.exp", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 4;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += FastMath.sqrt(j); } System.out.format("%-13s sum= %8.0f %dms\n", "FastMath.sqrt", sum, (System.currentTimeMillis() - start)); } for (int i = 0;i < 4;i++) { long start; double sum = 0; start = System.currentTimeMillis(); for (int j = 0;j < 1000000;j++) { sum += sqrt1(j); } System.out.format("%-13s sum= %8.0f %dms\n", "sqrt1", sum, (System.currentTimeMillis() - start)); } } public static double sqrt1(final double a) { final long x = Double.doubleToLongBits(a) >> 32; double y = Double.longBitsToDouble((x + 1072632448) << 31); return y; } }
結果です。
Math.exp sum= 666666166 7ms Math.exp sum= 666666166 9ms Math.exp sum= 666666166 5ms Math.exp sum= 666666166 3ms FastMath.sqrt sum= 666666166 10ms FastMath.sqrt sum= 666666166 8ms FastMath.sqrt sum= 666666166 7ms FastMath.sqrt sum= 666666166 4ms sqrt1 sum= 666888545 4ms sqrt1 sum= 666888545 4ms sqrt1 sum= 666888545 4ms sqrt1 sum= 666888545 4ms
最速はMath.expでした。 アルゴリズム的には似通ったものになるのか差はほとんどありませんでした。
GraalVMのnative-imageを試す
https://www.graalvm.org/docs/getting-started/
こちらの手順に沿ってダウンロードして、エイリアスの設定
alias java8=~/graalvm-ce-19.2.0.1/Contents/Home/bin/java alias javac8=~/graalvm-ce-19.2.0.1/Contents/Home/bin/javac
以下のソースで実験
// VectorTest3.java public class VectorTest3 { private final static int NUM = 100 * 1000 * 1000; private static float dotProduct(float[] vec_a, float[] vec_b) { float sum = 0; for (int i = 0; i < vec_a.length; i++) { sum += vec_a[i] * vec_b[i]; } return sum; } private static void bench1() { float[] vec_a = new float[NUM]; float[] vec_b = new float[NUM]; for (int i = 0;i < NUM;i++) { vec_a[i] = (float)Math.random(); vec_b[i] = (float)Math.random(); } long start = System.currentTimeMillis(); float sum = dotProduct(vec_a, vec_b); System.out.format("bench1 - %d ms\n", (System.currentTimeMillis() - start)); } public static void main(String[] args) { for(int i = 0;i < 10;i++) { bench1(); } } }
前回のpanamaビルドで実行
$ javac14 src/main/java/VectorTest3.java \ -d out/ $ java14 -cp out/ VectorTest3 bench1 - 137 ms bench1 - 124 ms bench1 - 130 ms bench1 - 112 ms bench1 - 121 ms bench1 - 130 ms bench1 - 120 ms bench1 - 120 ms bench1 - 122 ms bench1 - 117 ms
最速で117msです。 GraalVMの方では
$ javac8 src/main/java/VectorTest3.java \ -d out/ $ java8 -cp out/ VectorTest3 bench1 - 131 ms bench1 - # # A fatal error has been detected by the Java Runtime Environment: # # Internal Error (deoptimization.cpp:808), pid=62160, tid=0x0000000000001603 # fatal error: java/lang/Long$LongCache must be initialized
クラッシュ...
native-imageをします。
# インストール $ ~/graalvm-ce-19.2.0.1/Contents/Home/bin/gu install native-image $ ~/graalvm-ce-19.2.0.1/Contents/Home/bin/native-image -cp out/ VectorTest3 Build on Server(pid: 44023, port: 50514) [vectortest3:44023] classlist: 143.75 ms [vectortest3:44023] (cap): 2,228.31 ms [vectortest3:44023] setup: 3,358.35 ms [vectortest3:44023] (typeflow): 2,793.14 ms [vectortest3:44023] (objects): 2,172.78 ms [vectortest3:44023] (features): 170.70 ms [vectortest3:44023] analysis: 5,226.55 ms [vectortest3:44023] (clinit): 96.69 ms [vectortest3:44023] universe: 324.64 ms [vectortest3:44023] (parse): 432.65 ms [vectortest3:44023] (inline): 969.11 ms [vectortest3:44023] (compile): 4,320.70 ms [vectortest3:44023] compile: 6,017.41 ms [vectortest3:44023] image: 490.40 ms [vectortest3:44023] write: 184.50 ms [vectortest3:44023] [total]: 15,869.97 ms $ ls -la vectortest3 -rwxr-xr-x 1 tak staff 4496880 Oct 5 11:00 vectortest3
約4MBの実行ファイルが生成されました。 実行します。
$ ./vectortest3 bench1 - 127 ms bench1 - 130 ms bench1 - 125 ms bench1 - 126 ms bench1 - 126 ms bench1 - 127 ms bench1 - 136 ms bench1 - 126 ms bench1 - 129 ms bench1 - 129 ms
なんか遅くなっている気がします。
$ otool -L vectortest3 vectortest3: /usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1281.0.0) /System/Library/Frameworks/CoreFoundation.framework/Versions/A/CoreFoundation (compatibility version 150.0.0, current version 1670.10.0) /usr/lib/libz.1.dylib (compatibility version 1.0.0, current version 1.2.11)
静的リンクはlibzが必要そうです。
objdump -d vectortest3 | grep vfmadd
vfmaddでgrepしましたが特に使われていないのでベクトルができていない模様
Chainer 6.4.0 リリース
6.4.0
https://github.com/chainer/chainer/releases/tag/v6.4.0
現在は7.0.0の開発がメインのようで主に7.0.0からのバックポートとなっています。
Java Vector API を試す
https://nowokay.hatenablog.com/entry/2019/09/05/015537
こちらの記事を参考にビルドします。
CのコードとJavaのコードを比較していきます。
# vector.h float dotProduct512(float* vec1, float* vec2, int num); float dotProduct256(float* vec1, float* vec2, int num); float dotProduct(float* vec1, float* vec2, int num);
#vector.c #include "vector.h" #include <immintrin.h> float dotProduct512(float* vec1, float* vec2, int num) { __m512 avx_sum = _mm512_setzero_ps(); int limit = num - num % 16; for (int i = 0;i < limit;i += 16) { const __m512 a512 = _mm512_loadu_ps((float*)&vec1[i]); const __m512 b512 = _mm512_loadu_ps((float*)&vec2[i]); avx_sum = _mm512_fmadd_ps(a512, b512, avx_sum); } float __attribute__((aligned(32))) out[16] = {}; _mm512_storeu_ps(out, avx_sum); float sum = 0; for (int i = 0;i < 16;i++) { sum += out[i]; } for (int i = limit;i < num;i++) { sum += vec1[i] * vec2[i]; } return sum; } float dotProduct256(float* vec1, float* vec2, int num) { __m256 avx_sum = _mm256_setzero_ps(); int limit = num - num % 8; for (int i = 0;i < limit;i += 8) { const __m256 a256 = _mm256_loadu_ps((float*)&vec1[i]); const __m256 b256 = _mm256_loadu_ps((float*)&vec2[i]); avx_sum = _mm256_fmadd_ps(a256, b256, avx_sum); } float __attribute__((aligned(32))) out[16] = {}; _mm256_store_ps(out, avx_sum); float sum = 0; for (int i = 0;i < 8;i++) { sum += out[i]; } for (int i = limit;i < num;i++) { sum += vec1[i] * vec2[i]; } return sum; } float dotProduct(float* vec1, float* vec2, int num) { float sum = 0; for (int i = 0;i < num;i++) { sum += vec1[i] * vec2[i]; } return sum; }
# vec_test.c #include <stdio.h> #include <stdlib.h> #include <sys/time.h> #include "vector.h" #define NUM 100 * 1000 * 1000 unsigned long getMicroSec() { struct timespec time1; clock_gettime(CLOCK_REALTIME,&time1); unsigned long micros = time1.tv_sec * 1000000; micros += time1.tv_nsec / 1000; return micros; } void bench1() { float *vec_a, *vec_b; vec_a = (float*)malloc(sizeof(float) * NUM); vec_b = (float*)malloc(sizeof(float) * NUM); for(int i = 0;i < NUM;i++) { vec_a[i] = ((float)rand()) / NUM; vec_b[i] = ((float)rand()) / NUM; } unsigned long start = getMicroSec(); float sum = dotProduct(vec_a, vec_b, NUM); printf("bench1 - %lu ms\n", (getMicroSec() - start) / 1000); free(vec_a); free(vec_b); } void bench2() { float *vec_a, *vec_b; vec_a = (float*)malloc(sizeof(float) * NUM); vec_b = (float*)malloc(sizeof(float) * NUM); for(int i = 0;i < NUM;i++) { vec_a[i] = ((float)rand()) / NUM; vec_b[i] = ((float)rand()) / NUM; } unsigned long start = getMicroSec(); float sum = dotProduct256(vec_a, vec_b, NUM); printf("bench2 - %lu ms\n", (getMicroSec() - start) / 1000); free(vec_a); free(vec_b); } int main(void) { bench1(); bench2(); return 0; }
以下のコマンドでビルドします。
gcc -O2 -mavx512f vec_test.c -o vec_test vector gcc -O2 vec_test.c -o vec_test vector
実行結果です。
$ ./vec_test bench1 - 135 ms bench2 - 52 ms
// VectorTest2.java import jdk.incubator.vector.*; public class VectorTest2 { private final static int NUM = 100 * 1000 * 1000; private static float dotProduct256(float[] vec_a, float[] vec_b) { var SP = FloatVector.SPECIES_256; int limit = vec_a.length - vec_a.length % 8; var fv_sum = FloatVector.fromValues(SP, 0, 0, 0, 0, 0, 0, 0, 0); for (int i = 0; i < limit; i += 8) { var fv_a = FloatVector.fromArray(SP, vec_a, i); var fv_b = FloatVector.fromArray(SP, vec_b, i); fv_sum = fv_a.fma(fv_b, fv_sum); } float[] outArray = new float[8]; fv_sum.intoArray(outArray, 0); float sum = 0; for (float f: outArray) { sum += f; } for (int i = limit; i < vec_a.length; i += 1) { sum += vec_a[i] * vec_b[i]; } return sum; } private static float dotProduct(float[] vec_a, float[] vec_b) { float sum = 0; for (int i = 0; i < vec_a.length; i++) { sum += vec_a[i] * vec_b[i]; } return sum; } private static void bench1() { float[] vec_a = new float[NUM]; float[] vec_b = new float[NUM]; for (int i = 0;i < NUM;i++) { vec_a[i] = (float)Math.random(); vec_b[i] = (float)Math.random(); } long start = System.currentTimeMillis(); float sum = dotProduct(vec_a, vec_b); System.out.format("bench1 - %d ms\n", (System.currentTimeMillis() - start)); } private static void bench2() { float[] vec_a = new float[NUM]; float[] vec_b = new float[NUM]; for (int i = 0;i < NUM;i++) { vec_a[i] = (float)Math.random(); vec_b[i] = (float)Math.random(); } long start = System.currentTimeMillis(); float sum = dotProduct256(vec_a, vec_b); System.out.format("bench2 - %d ms\n", (System.currentTimeMillis() - start)); } public static void main(String[] args) { for(int i = 0;i < 20;i++) { bench1(); } for(int i = 0;i < 20;i++) { bench2(); } } }
ビルドして実行します。
$ javac14 src/main/java/VectorTest2.java \ --add-modules jdk.incubator.vector \ -d out/ $ java14 -cp out/ VectorTest2 bench1 - 114 ms bench1 - 127 ms bench1 - 116 ms bench1 - 128 ms bench1 - 108 ms bench1 - 140 ms bench1 - 117 ms bench1 - 114 ms bench1 - 112 ms bench1 - 109 ms bench1 - 144 ms bench1 - 120 ms bench1 - 120 ms bench1 - 110 ms bench1 - 117 ms bench1 - 107 ms bench1 - 107 ms bench1 - 110 ms bench1 - 107 ms bench1 - 109 ms bench2 - 447 ms bench2 - 125 ms bench2 - 163 ms bench2 - 110 ms bench2 - 117 ms bench2 - 117 ms bench2 - 117 ms bench2 - 116 ms bench2 - 123 ms bench2 - 116 ms bench2 - 129 ms bench2 - 116 ms bench2 - 116 ms bench2 - 116 ms bench2 - 117 ms bench2 - 115 ms bench2 - 117 ms bench2 - 116 ms bench2 - 118 ms bench2 - 116 ms
VectorAPIを使用しない場合107msが最速ですが、使用した場合116msが最速です。 なぜ使用すると遅くなるのかは不明ですが、まだまだ最適化が必要なのでしょう。
AVX512をEC2で試す
今回はEC2のc5.largeインスタンスを使います。
CPU情報は
$ cat /proc/cpuinfo processor : 0 vendor_id : GenuineIntel cpu family : 6 model : 85 model name : Intel(R) Xeon(R) Platinum 8124M CPU @ 3.00GHz stepping : 4 microcode : 0x200005e cpu MHz : 3408.548 cache size : 25344 KB physical id : 0 siblings : 2 core id : 0 cpu cores : 1 apicid : 0 initial apicid : 0 fpu : yes fpu_exception : yes cpuid level : 13 wp : yes flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds bogomips : 5999.99 clflush size : 64 cache_alignment : 64 address sizes : 46 bits physical, 48 bits virtual power management: processor : 1 vendor_id : GenuineIntel cpu family : 6 model : 85 model name : Intel(R) Xeon(R) Platinum 8124M CPU @ 3.00GHz stepping : 4 microcode : 0x200005e cpu MHz : 3411.384 cache size : 25344 KB physical id : 0 siblings : 2 core id : 0 cpu cores : 1 apicid : 1 initial apicid : 1 fpu : yes fpu_exception : yes cpuid level : 13 wp : yes flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds bogomips : 5999.99 clflush size : 64 cache_alignment : 64 address sizes : 46 bits physical, 48 bits virtual power management:
c5の上位インスタンスではIntel DL Boostに対応していますが、下位インスタンスでは対応していません。 2019/7/24のIntel AIのブログです。
実行のコードです。
vector.h
float sumProduct(float* vec1, float* vec2, int num);
vector.c
#include "vector.h" #include <immintrin.h> float sumProduct(float* vec1, float* vec2, int num) { __m512 avx_sum = _mm512_setzero_ps(); for (int i = 0;i < num;i += 16) { const __m512 a512 = _mm512_loadu_ps((double*)&vec1[i]); const __m512 b512 = _mm512_loadu_ps((double*)&vec2[i]); avx_sum = _mm512_fmadd_ps(a512, b512, avx_sum); } float __attribute__((aligned(32))) out[16] = {}; _mm512_storeu_ps(out, avx_sum); float sum = 0; for (int i = 0;i < 16;i++) { sum += out[i]; } return sum; }
test.c
#include<stdio.h> #include "vector.h" int main(void) { float a[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; float b[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; printf("%f\n", sumProduct(a, b, 16)); return 0; }
以下のコマンドでビルドします。
gcc -O2 -mavx512f vector.c -c -o vector gcc -O0 -mavx512f test.c -o test vector ./test 408.000000