This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new ceb50a2d21 [SYSTEMDS-3665] New rewrite for mmult-add expressions
ceb50a2d21 is described below

commit ceb50a2d2175390267796e0cfd8620ca251c1e3d
Author: ReneEnjilian <[email protected]>
AuthorDate: Sat Jan 20 01:00:12 2024 +0100

    [SYSTEMDS-3665] New rewrite for mmult-add expressions
    
    A%*%B + A%*%C -> A%*%(B+C) iff A, B, and C are dense and the target
    expression reduces the estimated cost (floating point operations plus memory I/O).
    
    Closes #1986.
---
 .../RewriteAlgebraicSimplificationDynamic.java     |  68 +++++++++----
 .../rewrite/RewriteDistributiveMatrixMultTest.java | 107 +++++++++++++++++++++
 .../rewrite/RewriteDistributiveMatrixMult.R        |  41 ++++++++
 .../rewrite/RewriteDistributiveMatrixMult.dml      |  32 ++++++
 4 files changed, 231 insertions(+), 17 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index e181c60a78..3e1c498f01 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -176,6 +176,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                        hi = simplifyScalarMatrixMult(hop, hi, i);        
//e.g., X%*%y -> X*as.scalar(y), if y is a 1-1 matrix
                        hi = simplifyMatrixMultDiag(hop, hi, i);          
//e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1 
                        hi = simplifyDiagMatrixMult(hop, hi, i);          
//e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
+                       hi = simplifyDistributiveMatrixMult(hop, hi, i);  
//e.g., (A%*%B)+(A%*%C) -> A%*%(B+C)
                        hi = simplifySumDiagToTrace(hi);                  
//e.g., sum(diag(X)) -> trace(X); if col vector
                        hi = simplifyLowerTriExtraction(hop, hi, i);      
//e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
                        hi = simplifyConstantCumsum(hop, hi, i);          
//e.g., cumsum(matrix(1/n,n,1)) -> seq(1/n, 1, 1/n)
@@ -1137,46 +1138,79 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                
                return hi;
        }
-       
-       private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos)
-       {
-               if( hi instanceof ReorgOp && 
((ReorgOp)hi).getOp()==ReOrgOp.DIAG && hi.getDim2()==1 ) //diagM2V
+
+       /**
+        * Rewrites diag(X%*%Y), where the diag result is a column vector
+        * (ncol == 1), into rowSums(X * t(Y)).
+        *
+        * @param parent parent hop whose input at {@code pos} is {@code hi}
+        * @param hi     candidate hop (expected to be a diag reorg over a matrix multiply)
+        * @param pos    input position of {@code hi} within {@code parent}
+        * @return the new rowSums hop if the rewrite was applied, otherwise {@code hi}
+        */
+       private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
+               if(hi instanceof ReorgOp && ((ReorgOp) hi).getOp() == 
ReOrgOp.DIAG && hi.getDim2() == 1) //diagM2V
                {
                        Hop hi2 = hi.getInput().get(0);
-                       if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y
+                       if(HopRewriteUtils.isMatrixMultiply(hi2)) //X%*%Y
                        {
                                Hop left = hi2.getInput().get(0);
                                Hop right = hi2.getInput().get(1);
-                               
+
                                //create new operators (incl refresh size 
inside for transpose)
                                ReorgOp trans = 
HopRewriteUtils.createTranspose(right);
                                BinaryOp mult = 
HopRewriteUtils.createBinary(left, trans, OpOp2.MULT);
                                AggUnaryOp rowSum = 
HopRewriteUtils.createAggUnaryOp(mult, AggOp.SUM, Direction.Row);
-                               
+
                                //rehang new subdag under parent node
                                HopRewriteUtils.replaceChildReference(parent, 
hi, rowSum, pos);
                                HopRewriteUtils.cleanupUnreferenced(hi, hi2);
-                               
+
                                hi = rowSum;
                                LOG.debug("Applied simplifyDiagMatrixMult");
-                       }       
+                       }
                }
-               
+
                return hi;
        }
-       
-       private static Hop simplifySumDiagToTrace(Hop hi)
-       {
-               if( hi instanceof AggUnaryOp ) 
+
+       /**
+        * Applies the distributive law to a sum of two matrix products that share
+        * their left operand: {@code (A%*%B) + (A%*%C) -> A%*%(B+C)}.
+        * <p>
+        * The rewrite is only applied if both products have no other consumers,
+        * A, B, and C are dense with known dimensions, a parent exists to rehang
+        * the new subdag, and a simple cost model (floating point operations plus
+        * reads/writes of inputs and intermediates) estimates a lower cost for the
+        * rewritten expression.
+        *
+        * @param parent parent hop whose input at {@code pos} is {@code hi}; the
+        *               rewrite is skipped if null
+        * @param hi     candidate hop
+        * @param pos    input position of {@code hi} within {@code parent}
+        * @return the new matrix multiply if the rewrite was applied, otherwise {@code hi}
+        */
+       private static Hop simplifyDistributiveMatrixMult(Hop parent, Hop hi, int pos) {
+               // A%*%B + A%*%C -> A%*%(B+C)
+               if(HopRewriteUtils.isBinary(hi, OpOp2.PLUS)
+                       && HopRewriteUtils.isMatrixMultiply(hi.getInput(0))
+                       && HopRewriteUtils.isMatrixMultiply(hi.getInput(1))
+                       && hi.getInput(0).getParent().size() == 1 //single consumer
+                       && hi.getInput(1).getParent().size() == 1 //single consumer
+                       && hi.getInput(0).getInput(0) == hi.getInput(1).getInput(0)) //common A
+               {
+                       Hop A = hi.getInput(0).getInput(0);
+                       Hop B = hi.getInput(0).getInput(1);
+                       Hop C = hi.getInput(1).getInput(1);
+                       boolean dense = HopRewriteUtils.isDense(A)
+                               && HopRewriteUtils.isDense(B) && HopRewriteUtils.isDense(C);
+                       //the cost model below is meaningless for unknown (-1) dimensions
+                       boolean dimsKnown = A.dimsKnown() && B.dimsKnown() && C.dimsKnown();
+                       //only build new hops once we know they will be used; creating them
+                       //eagerly would leave dangling parent references on A, B, and C
+                       if(dense && dimsKnown && parent != null) {
+                               //compute floating point and mem bandwidth requirements (in double
+                               //to avoid long overflow for large dimensions), accounting for
+                               //special cases where B or C might be a column vector
+                               double m = A.getDim1(), n = A.getDim2(), l = B.getDim2(), o = C.getDim2();
+                               double k = Math.max(l, o); //number of output columns
+                               double costOriginal = m*n*l + m*n*o + m*k   //FLOP
+                                       + m*n + n*l + n*o + m*l + m*o + m*k;  //I/O ABC+intermediates
+                               double costRewrite = n*k + m*n*k            //FLOP
+                                       + m*n + n*l + n*o + n*k + m*k;        //I/O ABC+intermediates
+                               //apply the rewrite only if it reduces the estimated cost
+                               if(costRewrite < costOriginal) {
+                                       Hop BplusC = HopRewriteUtils.createBinary(B, C, OpOp2.PLUS);
+                                       Hop newHop = HopRewriteUtils.createMatrixMultiply(A, BplusC);
+                                       int line = hi.getBeginLine(); //line of the replaced expression
+                                       HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
+                                       HopRewriteUtils.cleanupUnreferenced(hi);
+                                       hi = newHop;
+                                       LOG.debug("Applied simplifyDistributiveMatrixMult (line " + line + ")");
+                               }
+                       }
+               }
+               return hi;
+       }
+
+       private static Hop simplifySumDiagToTrace(Hop hi) {
+               if(hi instanceof AggUnaryOp) {
                        AggUnaryOp au = (AggUnaryOp) hi;
-                       if( au.getOp()==AggOp.SUM && 
au.getDirection()==Direction.RowCol )      //sum
+                       if(au.getOp() == AggOp.SUM && au.getDirection() == 
Direction.RowCol)    //sum
                        {
                                Hop hi2 = au.getInput().get(0);
-                               if( hi2 instanceof ReorgOp && 
((ReorgOp)hi2).getOp()==ReOrgOp.DIAG && hi2.getDim2()==1 ) //diagM2V
+                               if(hi2 instanceof ReorgOp && ((ReorgOp) 
hi2).getOp() == ReOrgOp.DIAG && hi2.getDim2() == 1) //diagM2V
                                {
                                        Hop hi3 = hi2.getInput().get(0);
-                                       
+
                                        //remove diag operator
                                        
HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);
                                        
HopRewriteUtils.cleanupUnreferenced(hi2);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
new file mode 100644
index 0000000000..7f40a2bef3
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+/**
+ * Tests the dynamic algebraic rewrite {@code (A%*%B) + (A%*%C) -> A%*%(B+C)}
+ * by comparing DML results against R and checking the number of executed
+ * CP matrix multiplications with and without algebraic simplification.
+ */
+public class RewriteDistributiveMatrixMultTest extends AutomatedTestBase {
+       private static final String TEST_NAME1 = "RewriteDistributiveMatrixMult";
+       private static final String TEST_DIR = "functions/rewrite/";
+       //use this test's own class name so its output dir is not shared with other tests
+       private static final String TEST_CLASS_DIR =
+               TEST_DIR + RewriteDistributiveMatrixMultTest.class.getSimpleName() + "/";
+
+       private static final int rows = 500;
+       private static final int cols = 500;
+       private static final double eps = Math.pow(10, -10);
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+       }
+
+       @Test
+       public void testDistributiveMatrixMultNoRewrite() {
+               testRewriteDistributiveMatrixMult(TEST_NAME1, false);
+       }
+
+       @Test
+       public void testDistributiveMatrixMultRewrite() {
+               testRewriteDistributiveMatrixMult(TEST_NAME1, true);
+       }
+
+       /**
+        * Runs the DML and R scripts on random inputs, compares the results, and
+        * checks how many matrix multiplications were executed.
+        *
+        * @param testname name of the test configuration and scripts
+        * @param rewrites whether algebraic simplification rewrites are enabled
+        */
+       private void testRewriteDistributiveMatrixMult(String testname, boolean rewrites) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       TestConfiguration config = getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[] {"-stats", "-args", input("A"), input("B"), input("C"), output("R")};
+
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+                       //create dense (sparsity 0.7) matrices so that the rewrite is applicable
+                       double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7);
+                       double[][] B = getRandomMatrix(rows, cols, -1, 1, 0.70d, 6);
+                       double[][] C = getRandomMatrix(rows, cols, -1, 1, 0.70d, 3);
+                       //NOTE(review): the hard-coded values below presumably are the nnz counts
+                       //for these seeds; update them if rows/cols/sparsity/seeds change
+                       writeInputMatrixWithMTD("A", A, 174522, true);
+                       writeInputMatrixWithMTD("B", B, 174935, true);
+                       writeInputMatrixWithMTD("C", C, 174848, true);
+
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       //compare matrices
+                       HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
+                       HashMap<CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+                       //check number of executed matrix multiplications (ba+*):
+                       //A%*%(B+C) with the rewrite, A%*%B and A%*%C without it
+                       long numMatMul = Statistics.getCPHeavyHitterCount("ba+*");
+                       Assert.assertEquals(rewrites ? 1L : 2L, numMatMul);
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.R 
b/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.R
new file mode 100644
index 0000000000..7c7a623fbe
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.R
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read command line arguments (args[1]: input dir, args[2]: output dir)
+args <- commandArgs(TRUE)
+
+# Set options for numeric precision
+options(digits=22)
+
+# Load required libraries (readMM/writeMM are provided by Matrix)
+library("Matrix")
+
+# Read matrices A, B, and C from Matrix Market format files
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
+C = as.matrix(readMM(paste(args[1], "C.mtx", sep="")))
+
+# Compute the reference result of the original (non-rewritten) expression
+R = (A %*% B) + (A %*% C)
+
+# Write the result matrix R in Matrix Market format
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
diff --git 
a/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.dml 
b/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.dml
new file mode 100644
index 0000000000..fc3a3a8cf2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Inputs: matrices A, B, and C read from the paths passed as $1, $2, $3
+A = read($1)
+B = read($2)
+C = read($3)
+
+# Sum of two products sharing the left operand A; with algebraic
+# simplification enabled, this may be rewritten to A %*% (B + C)
+R = (A %*% B) + (A %*% C)
+
+# Output: the result matrix is written to the path passed as $4
+write(R, $4)

Reply via email to