This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit e607544ea6ab1f2ec1c9f2c8370c38c86c346170 Author: baunsgaard <[email protected]> AuthorDate: Tue Sep 7 19:14:56 2021 +0200 [SYSTEMDS-3123] Rewrite c bind 0 Matrix Multiplication ``` cbind((X %*% Y), matrix(0, nrow(X), 1)) -> X %*% (cbind(Y, matrix(0, nrow(Y), 1))) ``` This commit contains a rewrite that changes the order of the operations if the number of rows in X is 2x larger than Y. This rewrite affects MLogReg in line 215, so that it does not force allocation of the large X twice. --- .../RewriteAlgebraicSimplificationDynamic.java | 54 +++++--- .../compress/workload/WorkloadAnalyzer.java | 30 ++--- .../compress/workload/WorkloadAlgorithmTest.java | 34 ++--- .../rewrite/RewriteMMCBindZeroVector.java | 145 +++++++++++++++++++++ .../compress/workload/WorkloadAnalysisMLogReg.dml | 13 +- .../RewritMMCBindZeroVectorOp.dml} | 26 +--- 6 files changed, 229 insertions(+), 73 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 63a05a4..0b91e5d 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -26,6 +26,20 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; +import org.apache.sysds.common.Types.AggOp; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.Direction; +import org.apache.sysds.common.Types.OpOp1; +import org.apache.sysds.common.Types.OpOp2; +import org.apache.sysds.common.Types.OpOp3; +import org.apache.sysds.common.Types.OpOp4; +import org.apache.sysds.common.Types.OpOpDG; +import org.apache.sysds.common.Types.OpOpN; +import org.apache.sysds.common.Types.ParamBuiltinOp; +import org.apache.sysds.common.Types.ReOrgOp; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.conf.ConfigurationManager; +import 
org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.AggBinaryOp; import org.apache.sysds.hops.AggUnaryOp; import org.apache.sysds.hops.BinaryOp; @@ -41,22 +55,8 @@ import org.apache.sysds.hops.QuaternaryOp; import org.apache.sysds.hops.ReorgOp; import org.apache.sysds.hops.TernaryOp; import org.apache.sysds.hops.UnaryOp; -import org.apache.sysds.common.Types.AggOp; -import org.apache.sysds.common.Types.Direction; -import org.apache.sysds.common.Types.OpOp1; -import org.apache.sysds.common.Types.OpOp2; -import org.apache.sysds.common.Types.OpOp3; -import org.apache.sysds.common.Types.OpOp4; -import org.apache.sysds.common.Types.OpOpDG; -import org.apache.sysds.common.Types.OpOpN; -import org.apache.sysds.common.Types.ParamBuiltinOp; -import org.apache.sysds.common.Types.ReOrgOp; import org.apache.sysds.lops.MapMultChain.ChainType; import org.apache.sysds.parser.DataExpression; -import org.apache.sysds.common.Types.DataType; -import org.apache.sysds.common.Types.ValueType; -import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.conf.DMLConfig; /** * Rule: Algebraic Simplifications. Simplifies binary expressions @@ -109,7 +109,6 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { if( root == null ) return root; - //one pass rewrite-descend (rewrite created pattern) rule_AlgebraicSimplification( root, false ); @@ -197,6 +196,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule hi = simplifyNnzComputation(hop, hi, i); //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known hi = simplifyNrowNcolComputation(hop, hi, i); //e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) + hi = simplyfyMMCBindZeroVector(hop, hi, i); //e.g.. 
cbind((X %*% Y), matrix (0, nrow(X), 1)) -> X %*% (cbind(Y, matrix(0, nrow(Y), 1))) if nRows of x is larger than nCols of y if( OptimizerUtils.ALLOW_OPERATOR_FUSION ) foldMultipleMinMaxOperations(hi); //e.g., min(X,min(min(3,7),Y)) -> min(X,3,7,Y) @@ -2796,4 +2796,28 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule return hi; } + + private static Hop simplyfyMMCBindZeroVector(Hop parent, Hop hi, int pos) { + + // cbind((X %*% Y), matrix(0, nrow(X), 1)) -> + // X %*% (cbind(Y, matrix(0, nrow(Y), 1))) + // if nRows of x is larger than nCols of y + // rewrite used in MLogReg first level loop. + + if(HopRewriteUtils.isBinary(hi, OpOp2.CBIND) && HopRewriteUtils.isMatrixMultiply(hi.getInput(0)) && + HopRewriteUtils.isDataGenOpWithConstantValue(hi.getInput(1), 0) && hi.getDim1() > hi.getDim2() * 2) { + final Hop oldGen = hi.getInput(1); + final Hop y = hi.getInput(0).getInput(1); + final Hop x = hi.getInput(0).getInput(0); + final Hop newGen = HopRewriteUtils.createDataGenOp(y, oldGen, 0); + final Hop newCBind = HopRewriteUtils.createBinary(y, newGen, OpOp2.CBIND); + final Hop newMM = HopRewriteUtils.createMatrixMultiply(x, newCBind); + + HopRewriteUtils.replaceChildReference(parent, hi, newMM, pos); + LOG.debug("Applied MMCBind Zero algebraic simplification (line " +hi.getBeginLine()+")." 
); + return newMM; + + } + return hi; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java index 31b3714..c865507 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java @@ -381,21 +381,21 @@ public class WorkloadAnalyzer { transientCompressed.contains(in.get(1).getName()); OpSided ret = new OpSided(hop, left, right, transposedLeft, transposedRight); if(ret.isRightMM()) { - HashSet<Long> overlapping2 = new HashSet<>(); - overlapping2.add(hop.getHopID()); - WorkloadAnalyzer overlappingAnalysis = new WorkloadAnalyzer(prog, overlapping2); - WTreeRoot r = overlappingAnalysis.createWorkloadTree(hop); - - CostEstimatorBuilder b = new CostEstimatorBuilder(r); - if(LOG.isTraceEnabled()) - LOG.trace("Workload for overlapping: " + r + "\n" + b); - - if(b.shouldUseOverlap()) - overlapping.add(hop.getHopID()); - else { - decompressHops.add(hop); - ret.setOverlappingDecompression(true); - } + // HashSet<Long> overlapping2 = new HashSet<>(); + // overlapping2.add(hop.getHopID()); + // WorkloadAnalyzer overlappingAnalysis = new WorkloadAnalyzer(prog, overlapping2); + // WTreeRoot r = overlappingAnalysis.createWorkloadTree(hop); + + // CostEstimatorBuilder b = new CostEstimatorBuilder(r); + // if(LOG.isTraceEnabled()) + // LOG.trace("Workload for overlapping: " + r + "\n" + b); + + // if(b.shouldUseOverlap()) + overlapping.add(hop.getHopID()); + // else { + // decompressHops.add(hop); + // ret.setOverlappingDecompression(true); + // } } return ret; diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java index 5de8880..af05bdc 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java @@ -83,7 +83,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase { @Test public void testLmCP() { - runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false); + runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SINGLE_NODE, 2, false); } @Test @@ -93,7 +93,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase { @Test public void testLmDSCP() { - runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2, false); + runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SINGLE_NODE, 2, false); } @Test @@ -103,41 +103,42 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase { @Test public void testPCACP() { - runWorkloadAnalysisTest(TEST_NAME3, ExecMode.HYBRID, 1, false); + runWorkloadAnalysisTest(TEST_NAME3, ExecMode.SINGLE_NODE, 1, false); } @Test public void testSliceLineCP1() { - runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 0, false); + runWorkloadAnalysisTest(TEST_NAME4, ExecMode.SINGLE_NODE, 0, false); } @Test public void testSliceLineCP2() { - runWorkloadAnalysisTest(TEST_NAME4, ExecMode.HYBRID, 2, true); + runWorkloadAnalysisTest(TEST_NAME4, ExecMode.SINGLE_NODE, 2, true); } @Test public void testLmCGSP() { runWorkloadAnalysisTest(TEST_NAME6, ExecMode.SPARK, 2, false); } - + @Test public void testLmCGCP() { - runWorkloadAnalysisTest(TEST_NAME6, ExecMode.HYBRID, 2, false); + runWorkloadAnalysisTest(TEST_NAME6, ExecMode.SINGLE_NODE, 2, false); } - + // private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount) { private void runWorkloadAnalysisTest(String testname, ExecMode mode, int compressionCount, boolean intermediates) { ExecMode oldPlatform = setExecMode(mode); boolean oldIntermediates = WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES; - + try { loadTestConfiguration(getTestConfiguration(testname)); 
WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES = intermediates; - + String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B")}; + programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), + output("B")}; writeInputMatrixWithMTD("X", X, false); writeInputMatrixWithMTD("y", y, false); @@ -149,11 +150,12 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase { long actualCompressionCount = (mode == ExecMode.HYBRID || mode == ExecMode.SINGLE_NODE) ? Statistics .getCPHeavyHitterCount("compress") : Statistics.getCPHeavyHitterCount("sp_compress"); - Assert.assertEquals(compressionCount, actualCompressionCount); - if( compressionCount > 0 ) - Assert.assertTrue( mode == ExecMode.HYBRID ? - heavyHittersContainsString("compress") : heavyHittersContainsString("sp_compress")); - if( !testname.equals(TEST_NAME4) ) + Assert.assertEquals("Assert that the compression counts expeted matches actual: " + compressionCount + + " vs " + actualCompressionCount, compressionCount, actualCompressionCount); + if(compressionCount > 0) + Assert.assertTrue(mode == ExecMode.SINGLE_NODE || mode == ExecMode.HYBRID ? heavyHittersContainsString( + "compress") : heavyHittersContainsString("sp_compress")); + if(!testname.equals(TEST_NAME4)) Assert.assertFalse(heavyHittersContainsString("m_scale")); } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java new file mode 100644 index 0000000..1cc1cca --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMMCBindZeroVector.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import static org.junit.Assert.fail; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +/** + * from: + * + * res = cbind((X %*% Y), matrix (0, nrow(X), 1)); + * + * to: + * + * res = X %*% (cbind(Y, matrix(0, nrow(Y), 1))) + * + * + * if the X has many rows, the allocation of x is expensive, to cbind. the case where this is applicable is mLogReg. 
+ * + */ +public class RewriteMMCBindZeroVector extends AutomatedTestBase { + // private static final Log LOG = LogFactory.getLog(RewriteMMCBindZeroVector.class.getName()); + + private static final String TEST_NAME1 = "RewritMMCBindZeroVectorOp"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteMMCBindZeroVector.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"})); + } + + @Test + public void testNoRewritesCP() { + testRewrite(TEST_NAME1, false, ExecType.CP, 100, 3, 10); + } + + @Test + public void testNoRewritesSP() { + testRewrite(TEST_NAME1, false, ExecType.SPARK, 100, 3, 10); + } + + @Test + public void testRewritesCP() { + testRewrite(TEST_NAME1, true, ExecType.CP, 100, 3, 10); + } + + @Test + public void testRewritesSP() { + testRewrite(TEST_NAME1, true, ExecType.SPARK, 100, 3, 10); + } + + private void testRewrite(String testname, boolean rewrites, ExecType et, int leftRows, int rightCols, int shared) { + ExecMode platformOld = rtplatform; + switch(et) { + case SPARK: + rtplatform = ExecMode.SPARK; + break; + default: + rtplatform = ExecMode.HYBRID; + break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if(rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean rewritesOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-args", input("X"), input("Y"), + output("R")}; + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), 
expectedDir()); + + double[][] X = getRandomMatrix(leftRows, shared, -1, 1, 0.97d, 7); + double[][] Y = getRandomMatrix(shared, rightCols, -1, 1, 0.9d, 3); + writeInputMatrixWithMTD("X", X, false); + writeInputMatrixWithMTD("Y", Y, false); + + // execute tests + String out = runTest(null).toString(); + + for(String line : out.split("\n")) { + if(rewrites) { + if(line.contains("append")) + break; + else if(line.contains("ba+*")) + fail( + "invalid execution matrix multiplication is done before append, therefore the rewrite did not tricker.\n\n" + + out); + } + else { + if(line.contains("ba+*")) + break; + else if(line.contains("append")) + fail( + "invalid execution append was done before multiplication, therefore the rewrite did tricker when not allowed.\n\n" + + out); + } + + } + // compare matrices + // HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewritesOld; + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} diff --git a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml index 12d9dd5..d427506 100644 --- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml +++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml @@ -22,21 +22,20 @@ X = read($1); Y = read($2); - print("") print("MLogReg") X = scale(X=X, scale=TRUE, center=TRUE); -B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2); +B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2, icpt=0); [nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y) - [nn, C] = confusionMatrix(P, Y) -print("Confusion: ") -print(toString(C)) +print("Confusion:") +print(toString(C)) +print("") print(acc) -if(acc < 50){ +if(acc < 50) stop("MLogReg Accuracy achieved is not high enough") -} + diff --git 
a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml b/src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml similarity index 71% copy from src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml copy to src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml index 12d9dd5..e6b0498 100644 --- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml +++ b/src/test/scripts/functions/rewrite/RewritMMCBindZeroVectorOp.dml @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,24 +19,10 @@ # #------------------------------------------------------------- -X = read($1); -Y = read($2); - - -print("") -print("MLogReg") - -X = scale(X=X, scale=TRUE, center=TRUE); -B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2); - -[nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y) -[nn, C] = confusionMatrix(P, Y) -print("Confusion: ") -print(toString(C)) +X = read($1) +Y = read($2) -print(acc) +res = cbind((X %*% Y), matrix (0, nrow(X), 1)); -if(acc < 50){ - stop("MLogReg Accuracy achieved is not high enough") -} +print(sum(res))
