Description
I've noticed that EJML matrix multiplication is more than 20 times slower than PyTorch or NumPy. Below is some test code that reproduces the situation. Am I doing something wrong, and what can I do to get closer to the Python results in Java?
import java.util.concurrent.ThreadLocalRandom;
import org.ejml.simple.SimpleMatrix;

public static void main(String[] args) {
    var mat1 = fill(new double[15][32]);
    var mat2 = fill(new double[32][25600]);
    var mask2dMat = new SimpleMatrix(mat1);
    var proto2dMat = new SimpleMatrix(mat2);
    var ts = System.currentTimeMillis();
    mask2dMat.mult(proto2dMat); // multiply
    System.out.println(System.currentTimeMillis() - ts); // 20 ms on my PC
}

private static double[][] fill(double[][] fMat) {
    // fill with uniform random values in [0, 1)
    for (double[] row : fMat) {
        for (int i = 0; i < row.length; i++) {
            row[i] = ThreadLocalRandom.current().nextFloat();
        }
    }
    return fMat;
}
Versus the PyTorch equivalent:
import time
import torch

mat1 = torch.randn(15, 32)
mat2 = torch.randn(32, 25600)
timestamp = int(time.time() * 1000)
mat1 @ mat2  # multiply
print(int(time.time() * 1000) - timestamp)  # 0 ms on the same PC
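In case it is relevant to the answer: I know EJML also has a lower-level row-major API where the output matrix can be preallocated and reused, so each multiplication avoids allocating a fresh result. This is only a sketch of what I would try there (same shapes as above, and it would still need the warm-up/averaging harness from the Java sketch); I have not confirmed whether it changes the picture:

import java.util.Random;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;

public class LowLevelMult {
    public static void main(String[] args) {
        var rnd = new Random(42);
        var a = randomMatrix(rnd, 15, 32);
        var b = randomMatrix(rnd, 32, 25600);
        var c = new DMatrixRMaj(15, 25600); // preallocated output, reused across calls

        CommonOps_DDRM.mult(a, b, c); // c = a * b, written into the existing buffer
        System.out.println(c.get(0, 0));
    }

    private static DMatrixRMaj randomMatrix(Random rnd, int rows, int cols) {
        var m = new DMatrixRMaj(rows, cols);
        for (int r = 0; r < rows; r++) {
            for (int col = 0; col < cols; col++) {
                m.set(r, col, rnd.nextDouble());
            }
        }
        return m;
    }
}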
Thank you.