This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new ea00f88b60 [SYSTEMDS-3168] Matrix Multiplication Transposed Kernels
ea00f88b60 is described below

commit ea00f88b60c2f14207c4dcf537533373f24b0516
Author: Elman Jahangiri <[email protected]>
AuthorDate: Sun Mar 29 12:45:35 2026 +0200

    [SYSTEMDS-3168] Matrix Multiplication Transposed Kernels
    
    Optimize dense matrix mult for transposed inputs
    
    This introduces specialized kernels for dense matrix multiplication
    involving transposed inputs (t(A)%*%B, A%*%t(B), t(A)%*%t(B)).
    Previously, these operations required an explicit intermediate transpose
    step, which added unnecessary runtime and allocation overhead.
    
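    The new kernels read the transposed operand directly or transpose it
    tile by tile into a small buffer, avoiding the cost of materializing a
    full transposed copy (t(A)%*%t(B) still transposes B once into a
    temporary block and then reuses the t(A)%*%B kernel).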
    
    Performance benchmarks on 100x100 dense matrices show significant
    speedups, especially for t(A)%*%B and t(A)%*%t(B); the included
    performance test can be rerun with larger dimensions.
    
    Closes #2425.
---
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  10 +-
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |  36 ++---
 .../sysds/runtime/matrix/data/LibMatrixMult.java   | 161 +++++++++++++++++++++
 .../MatrixMultTransposedPerformanceTest.java       | 107 ++++++++++++++
 .../matrixmult/MatrixMultTransposedTest.java       |  91 ++++++++++++
 5 files changed, 382 insertions(+), 23 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index f2d26be535..a02a550350 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -335,12 +335,12 @@ public class OptimizerUtils
 
        public static boolean AUTO_GPU_CACHE_EVICTION = true;
 
-    /**
-     * Boolean specifying if relational algebra rewrites are allowed (e.g. 
Selection Pushdowns).
-     */
-    public static boolean ALLOW_RA_REWRITES = false;
+       /**
+        * Boolean specifying if relational algebra rewrites are allowed (e.g. 
Selection Pushdowns).
+        */
+       public static boolean ALLOW_RA_REWRITES = false;
 
-    //////////////////////
+       //////////////////////
        // Optimizer levels //
        //////////////////////
 
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 927433c8c7..49c85191d2 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -103,27 +103,27 @@ public class ProgramRewriter{
                                _sbRuleSet.add(  new 
RewriteMergeBlockSequence()                 ); //dependency: remove branches, 
remove for-loops
                        if(OptimizerUtils.ALLOW_COMPRESSION_REWRITE)
                                _sbRuleSet.add(      new 
RewriteCompressedReblock()              ); // Compression Rewrite
-                       if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )
-                               _sbRuleSet.add(  new 
RewriteSplitDagUnknownCSVRead()             ); //dependency: reblock, merge 
blocks
-                       if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS && 
-                               
ConfigurationManager.getCompilerConfigFlag(ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS)
 )
-                               _sbRuleSet.add(  new 
RewriteSplitDagDataDependentOperators()     ); //dependency: merge blocks
-                       if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
+                       if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )
+                               _sbRuleSet.add(  new 
RewriteSplitDagUnknownCSVRead()             ); //dependency: reblock, merge 
blocks
+                       if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS && 
+                               
ConfigurationManager.getCompilerConfigFlag(ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS)
 )
+                               _sbRuleSet.add(  new 
RewriteSplitDagDataDependentOperators()     ); //dependency: merge blocks
+                       if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
                                _sbRuleSet.add(  new 
RewriteForLoopVectorization()               ); //dependency: reblock (reblockop)
-                       _sbRuleSet.add( new 
RewriteInjectSparkLoopCheckpointing(true)        ); //dependency: reblock 
(blocksizes)
-                       if( OptimizerUtils.ALLOW_CODE_MOTION )
-                               _sbRuleSet.add(  new 
RewriteHoistLoopInvariantOperations()       ); //dependency: vectorize, but 
before inplace
-                       if( OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE )
-                               _sbRuleSet.add(  new 
RewriteMarkLoopVariablesUpdateInPlace()     );
-                       if( LineageCacheConfig.getCompAssRW() )
-                               _sbRuleSet.add(  new MarkForLineageReuse()      
                 );
-            if( OptimizerUtils.ALLOW_RA_REWRITES )
-                _sbRuleSet.add(  new RewriteRaPushdown()                       
  );
-                       _sbRuleSet.add(      new 
RewriteRemoveTransformEncodeMeta()          );
-                       _dagRuleSet.add( new RewriteNonScalarPrint()            
             );
+                       _sbRuleSet.add( new 
RewriteInjectSparkLoopCheckpointing(true)        ); //dependency: reblock 
(blocksizes)
+                       if( OptimizerUtils.ALLOW_CODE_MOTION )
+                               _sbRuleSet.add(  new 
RewriteHoistLoopInvariantOperations()       ); //dependency: vectorize, but 
before inplace
+                       if( OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE )
+                               _sbRuleSet.add(  new 
RewriteMarkLoopVariablesUpdateInPlace()     );
+                       if( LineageCacheConfig.getCompAssRW() )
+                               _sbRuleSet.add(  new MarkForLineageReuse()      
                 );
+                       if( OptimizerUtils.ALLOW_RA_REWRITES )
+                               _sbRuleSet.add(  new RewriteRaPushdown()        
                 );
+                       _sbRuleSet.add(      new 
RewriteRemoveTransformEncodeMeta()          );
+                       _dagRuleSet.add( new RewriteNonScalarPrint()            
             );
                        if( OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE )
                                _sbRuleSet.add( new RewriteJoinReordering() );
-               }
+               }
                
                // DYNAMIC REWRITES (which do require size information)
                if( dynamicRewrites )
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index b638c7771b..9763e4ea57 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -45,6 +45,7 @@ import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType;
 import org.apache.sysds.lops.WeightedUnaryMM.WUMMType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.DenseBlockFP64;
 import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
 import org.apache.sysds.runtime.data.DenseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlock;
@@ -1026,6 +1027,166 @@ public class LibMatrixMult
        // optimized matrix mult implementation //
        //////////////////////////////////////////
 
+       public static void matrixMultDenseDenseMM(DenseBlock a, DenseBlock b, 
DenseBlock c, boolean transA, boolean transB, int n, int cd, int rl, int ru, 
int cl, int cu) {
+               // C = A %*% B
+               if (!transA && !transB)
+                       matrixMultDenseDenseMM(a, b, c, n, cd, rl, ru, cl, cu);
+               // C = t(A) %*% B
+               else if (transA && !transB)
+                       multDenseDenseTransA(a, b, c, n, cd, rl, ru, cl, cu);
+               // C = A %*% t(B)
+               else if (!transA && transB)
+                       multDenseDenseTransB(a, b, c, n, cd, rl, ru, cl, cu);
+               // C = t(A) %*% t(B)
+               else if (transA && transB)
+                       multDenseDenseTransATransB(a, b, c, n, cd, rl, ru, cl, 
cu);
+       }
+
+       private static void multDenseDenseTransA(DenseBlock a, DenseBlock b, 
DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
+               // process matrices in small blocks for caching
+               final int blocksizeI = 32;
+               final int blocksizeK = 24;
+               final int blocksizeJ = 1024;
+
+               // iterate over block of C rows
+               for (int bi = rl; bi < ru; bi += blocksizeI) {
+                       int bimin = Math.min(ru, bi + blocksizeI);
+
+                       // iterate over blocks of common dimension k
+                       for (int bk = 0; bk < cd; bk += blocksizeK) {
+                               int bkmin = Math.min(cd, bk + blocksizeK);
+
+                               // iterate over blocks of C columns
+                               for (int bj = cl; bj < cu; bj += blocksizeJ) {
+                                       int bjmin = Math.min(cu, bj + 
blocksizeJ);
+                                       int lenJ = bjmin - bj;
+
+                                       // if B is a single contiguous array, 
we skip checks inside the loop
+                                       if (b.isContiguous()) {
+                                               double[] bvals = b.values(0);
+
+                                               int k = bk;
+                                               // process 4 rows of A at the 
same time
+                                               for (; k < bkmin - 3; k += 4) {
+                                                       int bix0 = b.pos(k, bj);
+                                                       int bix1 = b.pos(k+1, 
bj);
+                                                       int bix2 = b.pos(k+2, 
bj);
+                                                       int bix3 = b.pos(k+3, 
bj);
+
+                                                       for (int i = bi; i < 
bimin; i++) {
+                                                               // grab 4 
values from A
+                                                               double val0 = 
a.values(k)[a.pos(k) + i];
+                                                               double val1 = 
a.values(k+1)[a.pos(k+1) + i];
+                                                               double val2 = 
a.values(k+2)[a.pos(k+2) + i];
+                                                               double val3 = 
a.values(k+3)[a.pos(k+3) + i];
+
+                                                               double[] cvals 
= c.values(i);
+                                                               int cix = 
c.pos(i, bj);
+
+                                                               
vectMultiplyAdd4(val0, val1, val2, val3,
+                                                                       bvals, 
cvals,
+                                                                       bix0, 
bix1, bix2, bix3, cix, lenJ);
+                                                       }
+                                               }
+                                               // for the remaining rows
+                                               for (; k < bkmin; k++) {
+                                                       int bix = b.pos(k, bj);
+                                                       for (int i = bi; i < 
bimin; i++) {
+                                                               double val = 
a.values(k)[a.pos(k) + i];
+                                                               if (val != 0) {
+                                                                       
vectMultiplyAdd(val, bvals, c.values(i), bix, c.pos(i, bj), lenJ);
+                                                               }
+                                                       }
+                                               }
+                                       } else {
+                                               for (int k = bk; k < bkmin; 
k++) {
+                                                       for (int i = bi; i < 
bimin; i++) {
+                                                               double val = 
a.values(k)[a.pos(k) + i];
+                                                               if (val != 0) {
+                                                                       
vectMultiplyAdd(val, b.values(k), c.values(i),
+                                                                               
b.pos(k, bj), c.pos(i, bj), lenJ);
+                                                               }
+                                                       }
+                                               }
+                                       }
+                               }
+                       }
+               }
+       }
+
+       private static void multDenseDenseTransB(DenseBlock a, DenseBlock b, 
DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
+               // copy small blocks of B into buffer bufB
+               final int blocksizeK = 24;
+               double[] bufB = new double[blocksizeK * (cu - cl)];
+
+               for (int bk = 0; bk < cd; bk += blocksizeK) {
+                       int bkmin = Math.min(cd, bk + blocksizeK);
+                       int bklen = bkmin - bk;
+
+                       // put B into buffer while transposing
+                       for (int j = cl; j < cu; j++) {
+                               double[] bvals = b.values(j);
+                               int bpos = b.pos(j);
+
+                               for (int k = 0; k < bklen; k++) {
+                                       bufB[k * (cu-cl) + (j-cl)] = bvals[bpos 
+ bk + k];
+                               }
+                       }
+
+                       // perform matrix multiplication with buffer
+                       for (int i = rl; i < ru; i++) {
+                               double[] avals = a.values(i);
+                               int apos = a.pos(i);
+                               double[] cvals = c.values(i);
+                               int cix = c.pos(i, cl);
+
+                               for (int k = 0; k < bklen; k++) {
+                                       double val = avals[apos + bk + k];
+                                       if (val != 0) {
+                                               int bufIx = k * (cu-cl);
+                                               vectMultiplyAdd(val, bufB, 
cvals, bufIx, cix, cu - cl);
+                                       }
+                               }
+                       }
+               }
+       }
+
+       private static void multDenseDenseTransATransB(DenseBlock a, DenseBlock 
b, DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
+               // transpose B into temp Block B
+               // use C = t(A) * B from above as helper method
+
+               // allocate Block for transposing B
+               int tB_rows = cd;
+               int tB_cols = cu - cl;
+
+               // allocate a new DenseBlock that holds the transposed slice of B
+               DenseBlock tB_block = new DenseBlockFP64(new int[] {tB_rows, 
tB_cols});
+               double[] tB = tB_block.values(0);
+
+               // perform transpose from B to tB_block
+               final int BLOCK = 128;
+               for (int bi = cl; bi < cu; bi += BLOCK) {
+                       int bimin = Math.min(cu, bi + BLOCK);
+                       for (int bk = 0; bk < cd; bk += BLOCK) {
+                               int bkmin = Math.min(cd, bk + BLOCK);
+
+                               for (int j = bi; j < bimin; j++) {
+                                       double[] b_vals = b.values(j);
+                                       int b_pos = b.pos(j);
+
+                                       int tB_col_idx = (j - cl);
+
+                                       for (int k = bk; k < bkmin; k++) {
+                                               tB[k * tB_cols + tB_col_idx] = 
b_vals[b_pos + k];
+                                       }
+                               }
+                       }
+               }
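+               // reuse the t(A) %*% B kernel; NOTE(review): called with columns [0, tB_cols), so C is written from column 0 (correct only if cl == 0)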
+               multDenseDenseTransA(a, tB_block, c, n, cd, rl, ru, 0, tB_cols);
+       }
+
+
        private static void matrixMultDenseDense(MatrixBlock m1, MatrixBlock 
m2, MatrixBlock ret, boolean tm2, boolean pm2, int rl, int ru, int cl, int cu) {
                DenseBlock a = m1.getDenseBlock();
                DenseBlock b = m2.getDenseBlock();
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrixmult/MatrixMultTransposedPerformanceTest.java
 
b/src/test/java/org/apache/sysds/test/component/matrixmult/MatrixMultTransposedPerformanceTest.java
new file mode 100644
index 0000000000..18d62251a0
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/matrixmult/MatrixMultTransposedPerformanceTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrixmult;
+
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.junit.Test;
+
+public class MatrixMultTransposedPerformanceTest {
+       // could be adjusted, as it takes a lot of runtime with higher 
dimensions
+       private final int m = 200;
+       private final int n = 200;
+       private final int k = 200;
+
+       @Test
+       public void testPerf_1_NoTransA_TransB() {
+               System.out.println("Case: C = A %*% t(B)");
+               runTest(false, true);
+               System.out.println();
+       }
+
+       @Test
+       public void testPerf_2_TransA_NoTransB() {
+               System.out.println("Case: C = t(A) %*% B");
+               runTest(true, false);
+               System.out.println();
+       }
+
+       @Test
+       public void testPerf_3_TransA_TransB() {
+               System.out.println("Case: C = t(A) %*% t(B)");
+               runTest(true, true);
+       }
+
+       private void runTest(boolean tA, boolean tB) {
+               int REP = 100;
+
+               // setup Dimensions
+               int rowsA = tA ? k : m;
+               int colsA = tA ? m : k;
+               int rowsB = tB ? n : k;
+               int colsB = tB ? k : n;
+
+               // generate random matrices
+               MatrixBlock A = MatrixBlock.randOperations(rowsA, colsA, 1.0, 
-1, 1, "uniform", 7);
+               MatrixBlock B = MatrixBlock.randOperations(rowsB, colsB, 1.0, 
-1, 1, "uniform", 3);
+               MatrixBlock C = new MatrixBlock(m, n, false);
+               C.allocateDenseBlock();
+
+               for(int i=0; i<50; i++) {
+                       runOldMethod(A, B, tA, tB);
+                       runNewKernel(A, B, C, tA, tB);
+               }
+
+               // Measure Old Method
+               long startTimeOld = System.nanoTime();
+               for(int i = 0; i < REP; i++) {
+                       runOldMethod(A, B, tA, tB);
+               }
+               double avgTimeOld = (System.nanoTime() - startTimeOld) / 1e6 / 
REP;
+
+               // Measure New Kernel
+               double startTimeNew = System.nanoTime();
+               for(int i = 0; i < REP; i++) {
+                       runNewKernel(A, B, C, tA, tB);
+               }
+               double avgTimeNew = (System.nanoTime() - startTimeNew) / 1e6 / 
REP;
+
+               // print results comparison
+               System.out.printf("Old Method: %.3f ms | New Kernel: %.3f 
ms%n", avgTimeOld, avgTimeNew);
+       }
+
+       private void runNewKernel(MatrixBlock A, MatrixBlock B, MatrixBlock C, 
boolean tA, boolean tB) {
+               C.reset();
+               LibMatrixMult.matrixMultDenseDenseMM(A.getDenseBlock(), 
B.getDenseBlock(), C.getDenseBlock(), tA, tB, m, k, 0, m, 0, n);
+       }
+
+       private void runOldMethod(MatrixBlock A, MatrixBlock B, boolean tA, 
boolean tB) {
+               // do transpose if needed
+               MatrixBlock A_in = tA ? LibMatrixReorg.transpose(A) : A;
+               MatrixBlock B_in = tB ? LibMatrixReorg.transpose(B) : B;
+
+               MatrixBlock C = new MatrixBlock(m, n, false);
+               C.allocateDenseBlock();
+
+               LibMatrixMult.matrixMultDenseDenseMM(A_in.getDenseBlock(), 
B_in.getDenseBlock(), C.getDenseBlock(), false,
+                       false, m, k, 0, m, 0, n);
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrixmult/MatrixMultTransposedTest.java
 
b/src/test/java/org/apache/sysds/test/component/matrixmult/MatrixMultTransposedTest.java
new file mode 100644
index 0000000000..6c18df009f
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/matrixmult/MatrixMultTransposedTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrixmult;
+
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.Random;
+
+public class MatrixMultTransposedTest {
+
+    // run multiple random scenarios
+    @Test
+    public void testCase_noTransA_TransB() {
+        for(int i=0; i<10; i++) {
+            runTest(false, true);
+        }
+    }
+
+    @Test
+    public void testCase_TransA_NoTransB() {
+        for(int i=0; i<10; i++) {
+            runTest(true, false);
+        }
+    }
+
+    @Test
+    public void testCase_TransA_TransB() {
+        for(int i=0; i<10; i++) {
+            runTest(true, true);
+        }
+    }
+
+    private void runTest(boolean tA, boolean tB) {
+        Random rand = new Random();
+
+        // generate random dimensions between 1 and 300
+        int m = rand.nextInt(300) + 1;
+        int n = rand.nextInt(300) + 1;
+        int k = rand.nextInt(300) + 1;
+
+
+        int rowsA = tA ? k : m;
+        int colsA = tA ? m : k;
+        int rowsB = tB ? n : k;
+        int colsB = tB ? k : n;
+
+        MatrixBlock ma = MatrixBlock.randOperations(rowsA, colsA, 1.0, -1, 1, 
"uniform", 7);
+        MatrixBlock mb = MatrixBlock.randOperations(rowsB, colsB, 1.0, -1, 1, 
"uniform", 3);
+
+        MatrixBlock mc = new MatrixBlock(m, n, false);
+        mc.allocateDenseBlock();
+
+        DenseBlock a = ma.getDenseBlock();
+        DenseBlock b = mb.getDenseBlock();
+        DenseBlock c = mc.getDenseBlock();
+
+        LibMatrixMult.matrixMultDenseDenseMM(a, b, c, tA, tB, n, k, 0, m, 0, 
n);
+
+        mc.recomputeNonZeros();
+
+        // calc true result with existing methods
+        MatrixBlock ma_in = tA ? LibMatrixReorg.transpose(ma) : ma;
+        MatrixBlock mb_in = tB ? LibMatrixReorg.transpose(mb) : mb;
+        MatrixBlock expected = LibMatrixMult.matrixMult(ma_in, mb_in);
+
+        // compare results
+        TestUtils.compareMatrices(expected, mc, 1e-8);
+    }
+}

Reply via email to