This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit e607544ea6ab1f2ec1c9f2c8370c38c86c346170
Author: baunsgaard <[email protected]>
AuthorDate: Tue Sep 7 19:14:56 2021 +0200

    [SYSTEMDS-3123] Rewrite c bind 0 Matrix Multiplication
    
    ```
     cbind((X %*% Y), matrix(0, nrow(X), 1))
    
     ->
    
     X %*% (cbind(Y, matrix(0, nrow(Y), 1)))
    ```
    
    This commit contains a rewrite that changes the order of the operations
    as shown above, if the number of rows in X is more than 2x larger than
    the number of columns of the cbind output.
    
    This rewrite affects MLogReg at line 215, so that it no longer forces
    allocation of the large X twice.
---
 .../RewriteAlgebraicSimplificationDynamic.java     |  54 +++++---
 .../compress/workload/WorkloadAnalyzer.java        |  30 ++---
 .../compress/workload/WorkloadAlgorithmTest.java   |  34 ++---
 .../rewrite/RewriteMMCBindZeroVector.java          | 145 +++++++++++++++++++++
 .../compress/workload/WorkloadAnalysisMLogReg.dml  |  13 +-
 .../RewritMMCBindZeroVectorOp.dml}                 |  26 +---
 6 files changed, 229 insertions(+), 73 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 63a05a4..0b91e5d 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -26,6 +26,20 @@ import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 
+import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
+import org.apache.sysds.common.Types.OpOp4;
+import org.apache.sysds.common.Types.OpOpDG;
+import org.apache.sysds.common.Types.OpOpN;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
+import org.apache.sysds.common.Types.ReOrgOp;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.AggUnaryOp;
 import org.apache.sysds.hops.BinaryOp;
@@ -41,22 +55,8 @@ import org.apache.sysds.hops.QuaternaryOp;
 import org.apache.sysds.hops.ReorgOp;
 import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.common.Types.AggOp;
-import org.apache.sysds.common.Types.Direction;
-import org.apache.sysds.common.Types.OpOp1;
-import org.apache.sysds.common.Types.OpOp2;
-import org.apache.sysds.common.Types.OpOp3;
-import org.apache.sysds.common.Types.OpOp4;
-import org.apache.sysds.common.Types.OpOpDG;
-import org.apache.sysds.common.Types.OpOpN;
-import org.apache.sysds.common.Types.ParamBuiltinOp;
-import org.apache.sysds.common.Types.ReOrgOp;
 import org.apache.sysds.lops.MapMultChain.ChainType;
 import org.apache.sysds.parser.DataExpression;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
-import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.conf.DMLConfig;
 
 /**
  * Rule: Algebraic Simplifications. Simplifies binary expressions
@@ -109,7 +109,6 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
        public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
                if( root == null )
                        return root;
-               
                //one pass rewrite-descend (rewrite created pattern)
                rule_AlgebraicSimplification( root, false );
                
@@ -197,6 +196,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                        hi = simplifyNnzComputation(hop, hi, i);          
//e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
                        hi = simplifyNrowNcolComputation(hop, hi, i);     
//e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency
                        hi = simplifyTableSeqExpand(hop, hi, i);          
//e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, 
ignore=false, cast=true)
+                       hi = simplyfyMMCBindZeroVector(hop, hi, i);       
//e.g.. cbind((X %*% Y), matrix (0, nrow(X), 1)) -> X %*% (cbind(Y, matrix(0, 
nrow(Y), 1))) if nRows of x is larger than nCols of y
                        if( OptimizerUtils.ALLOW_OPERATOR_FUSION )
                                foldMultipleMinMaxOperations(hi);             
//e.g., min(X,min(min(3,7),Y)) -> min(X,3,7,Y)
                        
@@ -2796,4 +2796,28 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                
                return hi;
        }
+
+       private static Hop simplyfyMMCBindZeroVector(Hop parent, Hop hi, int 
pos) {
+
+               // cbind((X %*% Y), matrix(0, nrow(X), 1)) ->
+               // X %*% (cbind(Y, matrix(0, nrow(Y), 1)))
+               // if nRows of x is larger than nCols of y
+               // rewrite used in MLogReg first level loop.
+               
+               if(HopRewriteUtils.isBinary(hi, OpOp2.CBIND) && 
HopRewriteUtils.isMatrixMultiply(hi.getInput(0)) &&
+                       
HopRewriteUtils.isDataGenOpWithConstantValue(hi.getInput(1), 0) && hi.getDim1() 
> hi.getDim2() * 2) {
+                       final Hop oldGen = hi.getInput(1);
+                       final Hop y = hi.getInput(0).getInput(1);
+                       final Hop x = hi.getInput(0).getInput(0);
+                       final Hop newGen = HopRewriteUtils.createDataGenOp(y, 
oldGen, 0);
+                       final Hop newCBind = HopRewriteUtils.createBinary(y, 
newGen, OpOp2.CBIND);
+                       final Hop newMM = 
HopRewriteUtils.createMatrixMultiply(x, newCBind);
+
+                       HopRewriteUtils.replaceChildReference(parent, hi, 
newMM, pos);
+                       LOG.debug("Applied MMCBind Zero algebraic 
simplification (line " +hi.getBeginLine()+")." );
+                       return newMM;
+
+               }
+               return hi;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
 
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index 31b3714..c865507 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -381,21 +381,21 @@ public class WorkloadAnalyzer {
                                        
transientCompressed.contains(in.get(1).getName());
                                OpSided ret = new OpSided(hop, left, right, 
transposedLeft, transposedRight);
                                if(ret.isRightMM()) {
-                                       HashSet<Long> overlapping2 = new 
HashSet<>();
-                                       overlapping2.add(hop.getHopID());
-                                       WorkloadAnalyzer overlappingAnalysis = 
new WorkloadAnalyzer(prog, overlapping2);
-                                       WTreeRoot r = 
overlappingAnalysis.createWorkloadTree(hop);
-
-                                       CostEstimatorBuilder b = new 
CostEstimatorBuilder(r);
-                                       if(LOG.isTraceEnabled())
-                                               LOG.trace("Workload for 
overlapping: " + r + "\n" + b);
-
-                                       if(b.shouldUseOverlap())
-                                               overlapping.add(hop.getHopID());
-                                       else {
-                                               decompressHops.add(hop);
-                                               
ret.setOverlappingDecompression(true);
-                                       }
+                                       // HashSet<Long> overlapping2 = new 
HashSet<>();
+                                       // overlapping2.add(hop.getHopID());
+                                       // WorkloadAnalyzer overlappingAnalysis 
= new WorkloadAnalyzer(prog, overlapping2);
+                                       // WTreeRoot r = 
overlappingAnalysis.createWorkloadTree(hop);
+
+                                       // CostEstimatorBuilder b = new 
CostEstimatorBuilder(r);
+                                       // if(LOG.isTraceEnabled())
+                                       //      LOG.trace("Workload for 
overlapping: " + r + "\n" + b);
+
+                                       // if(b.shouldUseOverlap())
+                                       overlapping.add(hop.getHopID());
+                                       // else {
+                                       //      decompressHops.add(hop);
+                                       //      
ret.setOverlappingDecompression(true);
+                                       // }
                                }
 
                                return ret;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
 
b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index 5de8880..af05bdc 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -83,7 +83,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 
        @Test
        public void testLmCP() {
-               runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false);
+               runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SINGLE_NODE, 2, 
false);
        }
 
        @Test
@@ -93,7 +93,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
 
        @Test
        public void testLmDSCP() {
-               runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false);
+               runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SINGLE_NODE, 2, 
false);
        }
 
        @Test
@@ -103,41 +103,42 @@ public class WorkloadAlgorithmTest extends 
AutomatedTestBase {
 
        @Test
        public void testPCACP() {
-               runWorkloadAnalysisTest(TEST_NAME3, ExecMode.HYBRID, 1, false);
+               runWorkloadAnalysisTest(TEST_NAME3, ExecMode.SINGLE_NODE, 1, 
false);
        }
 
        @Test
        public void testSliceLineCP1() {
-               runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 0, false);
+               runWorkloadAnalysisTest(TEST_NAME4, ExecMode.SINGLE_NODE, 0, 
false);
        }
 
        @Test
        public void testSliceLineCP2() {
-               runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 2, true);
+               runWorkloadAnalysisTest(TEST_NAME4, ExecMode.SINGLE_NODE, 2, 
true);
        }
 
        @Test
        public void testLmCGSP() {
                runWorkloadAnalysisTest(TEST_NAME6, ExecMode.SPARK, 2, false);
        }
-       
+
        @Test
        public void testLmCGCP() {
-               runWorkloadAnalysisTest(TEST_NAME6, ExecMode.HYBRID, 2, false);
+               runWorkloadAnalysisTest(TEST_NAME6, ExecMode.SINGLE_NODE, 2, 
false);
        }
-       
+
        // private void runWorkloadAnalysisTest(String testname, ExecMode mode, 
int compressionCount) {
        private void runWorkloadAnalysisTest(String testname, ExecMode mode, 
int compressionCount, boolean intermediates) {
                ExecMode oldPlatform = setExecMode(mode);
                boolean oldIntermediates = 
WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES;
-               
+
                try {
                        loadTestConfiguration(getTestConfiguration(testname));
                        WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES = 
intermediates;
-                       
+
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[] {"-stats", "20", "-args", 
input("X"), input("y"), output("B")};
+                       programArgs = new String[] {"-stats", "20", "-args", 
input("X"), input("y"),
+                               output("B")};
 
                        writeInputMatrixWithMTD("X", X, false);
                        writeInputMatrixWithMTD("y", y, false);
@@ -149,11 +150,12 @@ public class WorkloadAlgorithmTest extends 
AutomatedTestBase {
                        long actualCompressionCount = (mode == ExecMode.HYBRID 
|| mode == ExecMode.SINGLE_NODE) ? Statistics
                                .getCPHeavyHitterCount("compress") : 
Statistics.getCPHeavyHitterCount("sp_compress");
 
-                       Assert.assertEquals(compressionCount, 
actualCompressionCount);
-                       if( compressionCount > 0 )
-                               Assert.assertTrue( mode == ExecMode.HYBRID ?
-                                       heavyHittersContainsString("compress") 
: heavyHittersContainsString("sp_compress"));
-                       if( !testname.equals(TEST_NAME4) )
+                       Assert.assertEquals("Assert that the compression counts 
expeted matches actual: " + compressionCount
+                               + " vs " + actualCompressionCount, 
compressionCount, actualCompressionCount);
+                       if(compressionCount > 0)
+                               Assert.assertTrue(mode == ExecMode.SINGLE_NODE 
|| mode == ExecMode.HYBRID ?  heavyHittersContainsString(
+                                       "compress") : 
heavyHittersContainsString("sp_compress"));
+                       if(!testname.equals(TEST_NAME4))
                                
Assert.assertFalse(heavyHittersContainsString("m_scale"));
 
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java
new file mode 100644
index 0000000..1cc1cca
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import static org.junit.Assert.fail;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * from:
+ * 
+ * res = cbind((X %*% Y), matrix (0, nrow(X), 1));
+ * 
+ * to:
+ * 
+ * res = X %*% (cbind(Y, matrix(0, nrow(Y), 1)))
+ * 
+ * 
+ * if the X has many rows, the allocation of x is expensive, to cbind. the 
case where this is applicable is mLogReg.
+ * 
+ */
+public class RewriteMMCBindZeroVector extends AutomatedTestBase {
+       // private static final Log LOG = 
LogFactory.getLog(RewriteMMCBindZeroVector.class.getName());
+
+       private static final String TEST_NAME1 = "RewritMMCBindZeroVectorOp";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteMMCBindZeroVector.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+       }
+
+       @Test
+       public void testNoRewritesCP() {
+               testRewrite(TEST_NAME1, false, ExecType.CP, 100, 3, 10);
+       }
+
+       @Test
+       public void testNoRewritesSP() {
+               testRewrite(TEST_NAME1, false, ExecType.SPARK, 100, 3, 10);
+       }
+
+       @Test
+       public void testRewritesCP() {
+               testRewrite(TEST_NAME1, true, ExecType.CP, 100, 3, 10);
+       }
+
+       @Test
+       public void testRewritesSP() {
+               testRewrite(TEST_NAME1, true, ExecType.SPARK, 100, 3, 10);
+       }
+
+       private void testRewrite(String testname, boolean rewrites, ExecType 
et, int leftRows, int rightCols, int shared) {
+               ExecMode platformOld = rtplatform;
+               switch(et) {
+                       case SPARK:
+                               rtplatform = ExecMode.SPARK;
+                               break;
+                       default:
+                               rtplatform = ExecMode.HYBRID;
+                               break;
+               }
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if(rtplatform == ExecMode.SPARK || rtplatform == 
ExecMode.HYBRID)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               boolean rewritesOld = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+               try {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[] {"-explain", "-stats", 
"-args", input("X"), input("Y"),
+                               output("R")};
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       double[][] X = getRandomMatrix(leftRows, shared, -1, 1, 
0.97d, 7);
+                       double[][] Y = getRandomMatrix(shared, rightCols, -1, 
1, 0.9d, 3);
+                       writeInputMatrixWithMTD("X", X, false);
+                       writeInputMatrixWithMTD("Y", Y, false);
+
+                       // execute tests
+                       String out = runTest(null).toString();
+
+                       for(String line : out.split("\n")) {
+                               if(rewrites) {
+                                       if(line.contains("append"))
+                                               break;
+                                       else if(line.contains("ba+*"))
+                                               fail(
+                                                       "invalid execution 
matrix multiplication is done before append, therefore the rewrite did not 
tricker.\n\n"
+                                                               + out);
+                               }
+                               else {
+                                       if(line.contains("ba+*"))
+                                               break;
+                                       else if(line.contains("append"))
+                                               fail(
+                                                       "invalid execution 
append was done before multiplication, therefore the rewrite did tricker when 
not allowed.\n\n"
+                                                               + out);
+                               }
+
+                       }
+                       // compare matrices
+                       // HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
+
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewritesOld;
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+}
diff --git 
a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml 
b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
index 12d9dd5..d427506 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
+++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
@@ -22,21 +22,20 @@
 X = read($1);
 Y = read($2);
 
-
 print("")
 print("MLogReg")
 
 X = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2);
+B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2, icpt=0);
 
 [nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y)
-
 [nn, C] = confusionMatrix(P, Y)
-print("Confusion: ")
-print(toString(C))
 
+print("Confusion:")
+print(toString(C))
+print("")
 print(acc)
 
-if(acc < 50){
+if(acc < 50)
     stop("MLogReg Accuracy achieved is not high enough")
-}
+
diff --git 
a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml 
b/src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
similarity index 71%
copy from 
src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
copy to src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
index 12d9dd5..e6b0498 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
+++ b/src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,24 +19,10 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-Y = read($2);
-
-
-print("")
-print("MLogReg")
-
-X = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2);
-
-[nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y)
 
-[nn, C] = confusionMatrix(P, Y)
-print("Confusion: ")
-print(toString(C))
+X = read($1)
+Y = read($2)
 
-print(acc)
+res = cbind((X %*% Y), matrix (0, nrow(X), 1));
 
-if(acc < 50){
-    stop("MLogReg Accuracy achieved is not high enough")
-}
+print(sum(res))

Reply via email to