This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 00f649129a [SYSTEMDS-3470] Lineage-based local reuse of Spark actions
00f649129a is described below

commit 00f649129ae1568682ee7e0121e7aa0d75353271
Author: Arnab Phani <[email protected]>
AuthorDate: Mon Nov 28 00:02:51 2022 +0100

    [SYSTEMDS-3470] Lineage-based local reuse of Spark actions
    
    This patch extends the lineage cache framework to cache the results
    of Spark instructions that return intermediates back to the local
    driver. Caching only these actions avoids unnecessarily triggering
    and fetching all Spark intermediates. For now, we avoid caching
    future-based results.
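    
    For illustration only (not part of this patch): a minimal DML sketch of
    the access pattern this reuse targets, assuming X is large enough that
    the tsmm/mapmm operations compile to Spark instructions (as in the
    included LineageReuseSpark1.dml test). With lineage-based reuse enabled
    (e.g., -lineage reuse_hybrid), the loop-invariant Spark actions
    t(X) %*% X (sp_tsmm) and t(X) %*% y (sp_mapmm) execute once and are
    served from the local lineage cache in later iterations.
    
        # illustrative sketch only; mirrors the included test script
        X = rand(rows=10000, cols=200, seed=42);
        y = rand(rows=10000, cols=1, seed=43);
        R = matrix(0, rows=ncol(X), cols=10);
        for (i in 1:10) {
          # sp_tsmm action: identical lineage in every iteration -> reused
          A = (t(X) %*% X) + diag(matrix(0.0001 * i, rows=ncol(X), cols=1));
          # sp_mapmm action: identical lineage in every iteration -> reused
          b = t(X) %*% y;
          R[,i] = solve(A, b);
        }
        print(sum(R));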
    
    Closes #1739
---
 .../spark/AggregateUnarySPInstruction.java         |   4 +
 .../instructions/spark/CpmmSPInstruction.java      |   6 +-
 .../instructions/spark/MapmmSPInstruction.java     |   8 +-
 .../apache/sysds/runtime/lineage/LineageCache.java |  45 ++++++--
 .../sysds/runtime/lineage/LineageCacheConfig.java  |  45 +++++++-
 .../functions/async/LineageReuseSparkTest.java     | 116 +++++++++++++++++++++
 .../functions/rewrite/RewriteListTsmmCVTest.java   |   3 +-
 .../scripts/functions/async/LineageReuseSpark1.dml |  53 ++++++++++
 8 files changed, 264 insertions(+), 16 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 89f385c54e..48d41fd602 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -222,6 +222,10 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction {
                }
        }
 
+       public SparkAggType getAggType() {
+               return _aggtype;
+       }
+
        private static class RDDUAggFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
        {
                private static final long serialVersionUID = 2672082409287856038L;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index f9b3130ccc..0cbd4acfe0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -148,7 +148,11 @@ public class CpmmSPInstruction extends AggregateBinarySPInstruction {
                        }
                }
        }
-       
+
+       public SparkAggType getAggType() {
+               return _aggtype;
+       }
+
        private static int getPreferredParJoin(DataCharacteristics mc1, DataCharacteristics mc2, int numPar1, int numPar2) {
                int defPar = SparkExecutionContext.getDefaultParallelism(true);
                int maxParIn = Math.max(numPar1, numPar2);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index f50c46aceb..3a1a6c27d9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -126,8 +126,8 @@ public class MapmmSPInstruction extends AggregateBinarySPInstruction {
                }
                
                //get inputs
-               PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar); 
-               
+               PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar);
+
                //empty input block filter
                if( !_outputEmpty )
                        in1 = in1.filter(new FilterNonEmptyBlocksFunction());
@@ -176,6 +176,10 @@ public class MapmmSPInstruction extends AggregateBinarySPInstruction {
                }
        }
 
+       public SparkAggType getAggType() {
+               return _aggtype;
+       }
+
        private static boolean preservesPartitioning(DataCharacteristics mcIn, CacheType type )
        {
                if( type == CacheType.LEFT )
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index c1135cdb54..4c79b885b7 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -32,6 +32,7 @@ import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.MatrixObjectFuture;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
@@ -48,6 +49,7 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
 import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
+import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -91,11 +93,14 @@ public class LineageCache
                // will always fit in memory and hence can be pinned unconditionally
                if (LineageCacheConfig.isReusable(inst, ec)) {
                        ComputationCPInstruction cinst = inst instanceof ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
-                       ComputationFEDInstruction cfinst = inst instanceof ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null; 
+                       ComputationFEDInstruction cfinst = inst instanceof ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+                       ComputationSPInstruction cspinst = inst instanceof ComputationSPInstruction ? (ComputationSPInstruction)inst : null;
                        GPUInstruction gpuinst = inst instanceof GPUInstruction ? (GPUInstruction)inst : null;
+                       //TODO: Replace with generic type
                                
                        LineageItem instLI = (cinst != null) ? cinst.getLineageItem(ec).getValue()
-                                       : (cfinst != null) ? cfinst.getLineageItem(ec).getValue() 
+                                       : (cfinst != null) ? cfinst.getLineageItem(ec).getValue()
+                                       : (cspinst != null) ? cspinst.getLineageItem(ec).getValue()
                                        : gpuinst.getLineageItem(ec).getValue();
                        List<MutablePair<LineageItem, LineageCacheEntry>> liList = null;
                        if (inst instanceof MultiReturnBuiltinCPInstruction) {
@@ -120,10 +125,10 @@ public class LineageCache
                                                e = LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
                                        //TODO need to also move execution of compensation plan out of here
                                        //(create lazily evaluated entry)
-                                       if (e == null && LineageCacheConfig.getCacheType().isPartialReuse())
+                                       if (e == null && LineageCacheConfig.getCacheType().isPartialReuse() && cspinst == null)
                                                if( LineageRewriteReuse.executeRewrites(inst, ec) )
                                                        e = getIntern(item.getKey());
-                                       //TODO: MultiReturnBuiltin and partial rewrites
+                                       //TODO: Partial reuse for Spark instructions
                                        reuseAll &= (e != null);
                                        item.setValue(e);
                                        
@@ -134,6 +139,8 @@ public class LineageCache
                                                        putIntern(item.getKey(), cinst.output.getDataType(), null, null,  0);
                                                else if (cfinst != null)
                                                        putIntern(item.getKey(), cfinst.output.getDataType(), null, null,  0);
+                                               else if (cspinst != null)
+                                                       putIntern(item.getKey(), cspinst.output.getDataType(), null, null,  0);
                                                else if (gpuinst != null)
                                                        putIntern(item.getKey(), gpuinst._output.getDataType(), null, null,  0);
                                                //FIXME: different o/p datatypes for MultiReturnBuiltins.
@@ -155,6 +162,8 @@ public class LineageCache
                                                outName = cinst.output.getName();
                                        else if (inst instanceof ComputationFEDInstruction)
                                                outName = cfinst.output.getName();
+                                       else if (inst instanceof ComputationSPInstruction)
+                                               outName = cspinst.output.getName();
                                        else if (inst instanceof GPUInstruction)
                                                outName = gpuinst._output.getName();
                                        
@@ -483,9 +492,14 @@ public class LineageCache
                if (LineageCacheConfig.isReusable(inst, ec) ) {
                        LineageItem item = ((LineageTraceable) inst).getLineageItem(ec).getValue();
                        //This method is called only to put matrix value
-                       MatrixObject mo = inst instanceof ComputationCPInstruction ? 
-                                       ec.getMatrixObject(((ComputationCPInstruction) inst).output) :
-                                       ec.getMatrixObject(((ComputationFEDInstruction) inst).output);
+                       MatrixObject mo = null;
+                       if (inst instanceof ComputationCPInstruction)
+                               mo = ec.getMatrixObject(((ComputationCPInstruction) inst).output);
+                       else if (inst instanceof ComputationFEDInstruction)
+                               mo = ec.getMatrixObject(((ComputationFEDInstruction) inst).output);
+                       else if (inst instanceof ComputationSPInstruction)
+                               mo = ec.getMatrixObject(((ComputationSPInstruction) inst).output);
+
                        synchronized( _cache ) {
                                putIntern(item, DataType.MATRIX, mo.acquireReadAndRelease(), null, computetime);
                        }
@@ -527,9 +541,12 @@ public class LineageCache
                                        liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((GPUInstruction)inst)._output)));
                        }
                        else
-                               liData = inst instanceof ComputationCPInstruction ? 
-                                               Arrays.asList(Pair.of(instLI, ec.getVariable(((ComputationCPInstruction) inst).output))) :
-                                               Arrays.asList(Pair.of(instLI, ec.getVariable(((ComputationFEDInstruction) inst).output)));
+                               if (inst instanceof ComputationCPInstruction)
+                                       liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((ComputationCPInstruction) inst).output)));
+                               else if (inst instanceof ComputationFEDInstruction)
+                                       liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((ComputationFEDInstruction) inst).output)));
+                               else if (inst instanceof ComputationSPInstruction)
+                                       liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((ComputationSPInstruction) inst).output)));
 
                        if (liGpuObj == null)
                                putValueCPU(inst, liData, computetime);
@@ -556,6 +573,12 @@ public class LineageCache
                                        continue;
                                }
 
+                               if (data instanceof MatrixObjectFuture) {
+                                       // We don't want to call get() on the future immediately after the execution
+                                       removePlaceholder(item);
+                                       continue;
+                               }
+
                                if (LineageCacheConfig.isOutputFederated(inst, data)) {
                                        // Do not cache federated outputs (in the coordinator)
                                        // Cannot skip putting the placeholder as the above is only known after execution
@@ -867,10 +890,12 @@ public class LineageCache
                
                CPOperand output = inst instanceof ComputationCPInstruction ? ((ComputationCPInstruction)inst).output 
                                : inst instanceof ComputationFEDInstruction ? ((ComputationFEDInstruction)inst).output
+                               : inst instanceof ComputationSPInstruction ? ((ComputationSPInstruction)inst).output
                                : ((GPUInstruction)inst)._output;
                if (output.isMatrix()) {
                        MatrixObject mo = inst instanceof ComputationCPInstruction ? ec.getMatrixObject(((ComputationCPInstruction)inst).output) 
                                : inst instanceof ComputationFEDInstruction ? ec.getMatrixObject(((ComputationFEDInstruction)inst).output)
+                               : inst instanceof ComputationSPInstruction ? ec.getMatrixObject(((ComputationSPInstruction)inst).output)
                                : ec.getMatrixObject(((GPUInstruction)inst)._output);
                        //limit this to full reuse as partial reuse is applicable even for loop dependent operation
                        return !(LineageCacheConfig.getCacheType() == ReuseCacheType.REUSE_FULL  
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index e2027f467a..4259481444 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -21,8 +21,10 @@ package org.apache.sysds.runtime.lineage;
 
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.Instruction;
@@ -33,6 +35,11 @@ import org.apache.sysds.runtime.instructions.cp.ListIndexingCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
 import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
 import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
+import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.TsmmSPInstruction;
 
 import java.util.Comparator;
 
@@ -46,7 +53,8 @@ public class LineageCacheConfig
                "uamean", "max", "min", "ifelse", "-", "sqrt", ">", "uak+", 
"<=",
                "^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand", 
"replace",
                "^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", 
"qsort", 
-               "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", 
"lowertri"
+               "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", 
"lowertri",
+               "mapmm", "cpmm"
                //TODO: Reuse everything. 
        };
        private static String[] REUSE_OPCODES  = new String[] {};
@@ -197,7 +205,8 @@ public class LineageCacheConfig
        public static boolean isReusable (Instruction inst, ExecutionContext ec) {
                boolean insttype = (inst instanceof ComputationCPInstruction 
                        || inst instanceof ComputationFEDInstruction
-                       || inst instanceof GPUInstruction)
+                       || inst instanceof GPUInstruction
+                       || (inst instanceof ComputationSPInstruction && isRightSparkOp(inst)))
                        && !(inst instanceof ListIndexingCPInstruction);
                boolean rightop = (ArrayUtils.contains(REUSE_OPCODES, inst.getOpcode())
                        || (inst.getOpcode().equals("append") && isVectorAppend(inst, ec))
@@ -226,6 +235,14 @@ public class LineageCacheConfig
                        long c2 = ec.getMatrixObject(cpinst.input2).getNumColumns();
                        return(c1 == 1 || c2 == 1);
                }
+               if (inst instanceof ComputationSPInstruction) {
+                       ComputationSPInstruction fedinst = (ComputationSPInstruction) inst;
+                       if (!fedinst.input1.isMatrix() || !fedinst.input2.isMatrix())
+                               return false;
+                       long c1 = ec.getMatrixObject(fedinst.input1).getNumColumns();
+                       long c2 = ec.getMatrixObject(fedinst.input2).getNumColumns();
+                       return(c1 == 1 || c2 == 1);
+               }
                else { //GPUInstruction
                        GPUInstruction gpuinst = (GPUInstruction)inst;
                        if( !gpuinst._input1.isMatrix() || !gpuinst._input2.isMatrix() )
@@ -235,6 +252,30 @@ public class LineageCacheConfig
                        return(c1 == 1 || c2 == 1);
                }
        }
+
+       // Check if the Spark instruction returns its result back to the local driver
+       private static boolean isRightSparkOp(Instruction inst) {
+               if (!(inst instanceof ComputationSPInstruction))
+                       return false;
+
+               boolean spAction = false;
+               if (inst instanceof MapmmSPInstruction &&
+                       ((MapmmSPInstruction) inst).getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK)
+                       spAction = true;
+               else if (inst instanceof TsmmSPInstruction)
+                       spAction = true;
+               else if (inst instanceof AggregateUnarySPInstruction &&
+                       ((AggregateUnarySPInstruction) inst).getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK)
+                       spAction = true;
+               else if (inst instanceof CpmmSPInstruction &&
+                       ((CpmmSPInstruction) inst).getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK)
+                       spAction = true;
+               else if (((ComputationSPInstruction) inst).output.getDataType() == Types.DataType.SCALAR)
+                       spAction = true;
+               //TODO: include other cases
+
+               return spAction;
+       }
        
        public static boolean isOutputFederated(Instruction inst, Data data) {
                if (!(inst instanceof ComputationFEDInstruction))
diff --git a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
new file mode 100644
index 0000000000..83992c5772
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.async;
+
+       import java.util.ArrayList;
+       import java.util.HashMap;
+       import java.util.List;
+
+       import org.apache.sysds.common.Types.ExecMode;
+       import org.apache.sysds.hops.OptimizerUtils;
+       import org.apache.sysds.hops.recompile.Recompiler;
+       import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+       import org.apache.sysds.runtime.lineage.Lineage;
+       import org.apache.sysds.runtime.matrix.data.MatrixValue;
+       import org.apache.sysds.test.AutomatedTestBase;
+       import org.apache.sysds.test.TestConfiguration;
+       import org.apache.sysds.test.TestUtils;
+       import org.apache.sysds.utils.Statistics;
+       import org.junit.Assert;
+       import org.junit.Test;
+
+public class LineageReuseSparkTest extends AutomatedTestBase {
+
+       protected static final String TEST_DIR = "functions/async/";
+       protected static final String TEST_NAME = "LineageReuseSpark";
+       protected static final int TEST_VARIANTS = 1;
+       protected static String TEST_CLASS_DIR = TEST_DIR + LineageReuseSparkTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               for(int i=1; i<=TEST_VARIANTS; i++)
+                       addTestConfiguration(TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
+       }
+
+       @Test
+       public void testlmds() {
+               runTest(TEST_NAME+"1");
+       }
+
+       public void runTest(String testname) {
+               boolean old_simplification = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               boolean old_sum_product = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+               boolean old_trans_exec_type = OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE;
+               ExecMode oldPlatform = setExecMode(ExecMode.HYBRID);
+
+               long oldmem = InfrastructureAnalyzer.getLocalMaxMemory();
+               long mem = 1024*1024*8;
+               InfrastructureAnalyzer.setLocalMaxMemory(mem);
+
+               try {
+                       getAndLoadTestConfiguration(testname);
+                       fullDMLScriptName = getScript();
+
+                       List<String> proArgs = new ArrayList<>();
+
+                       proArgs.add("-explain");
+                       proArgs.add("-stats");
+                       proArgs.add("-args");
+                       proArgs.add(output("R"));
+                       programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+                       Lineage.resetInternalState();
+                       runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+                       HashMap<MatrixValue.CellIndex, Double> R = readDMLScalarFromOutputDir("R");
+                       long numTsmm = Statistics.getCPHeavyHitterCount("sp_tsmm");
+                       long numMapmm = Statistics.getCPHeavyHitterCount("sp_mapmm");
+
+                       proArgs.add("-explain");
+                       proArgs.add("-stats");
+                       proArgs.add("-lineage");
+                       proArgs.add("reuse_hybrid");
+                       proArgs.add("-args");
+                       proArgs.add(output("R"));
+                       programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+                       Lineage.resetInternalState();
+                       runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+                       HashMap<MatrixValue.CellIndex, Double> R_reused = readDMLScalarFromOutputDir("R");
+                       long numTsmm_r = Statistics.getCPHeavyHitterCount("sp_tsmm");
+                       long numMapmm_r= Statistics.getCPHeavyHitterCount("sp_mapmm");
+
+                       //compare matrices
+                       boolean matchVal = TestUtils.compareMatrices(R, R_reused, 1e-6, "Origin", "withPrefetch");
+                       if (!matchVal)
+                               System.out.println("Value w/o reuse "+R+" w/ reuse "+R_reused);
+                       Assert.assertTrue("Violated sp_tsmm: reuse count: "+numTsmm_r+" < "+numTsmm, numTsmm_r < numTsmm);
+                       Assert.assertTrue("Violated sp_mapmm: reuse count: "+numMapmm_r+" < "+numMapmm, numMapmm_r < numMapmm);
+
+               } finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = old_simplification;
+                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = old_sum_product;
+                       OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = old_trans_exec_type;
+                       resetExecMode(oldPlatform);
+                       InfrastructureAnalyzer.setLocalMaxMemory(oldmem);
+                       Recompiler.reinitRecompiler();
+               }
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteListTsmmCVTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteListTsmmCVTest.java
index eef00bc5e3..7f0c986368 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteListTsmmCVTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteListTsmmCVTest.java
@@ -128,7 +128,8 @@ public class RewriteListTsmmCVTest extends AutomatedTestBase
                        if( instType == ExecType.CP )
                                Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
                        if( rewrites ) {
-                               boolean expectedReuse = lineage && instType == ExecType.CP;
+                               //boolean expectedReuse = lineage && instType == ExecType.CP;
+                               boolean expectedReuse = lineage;
                                String[] codes = (instType==ExecType.CP) ?
                                        new String[]{"rbind","tsmm","ba+*","n+"} :
                                        new String[]{"sp_append","sp_tsmm","sp_mapmm","sp_n+"};
diff --git a/src/test/scripts/functions/async/LineageReuseSpark1.dml b/src/test/scripts/functions/async/LineageReuseSpark1.dml
new file mode 100644
index 0000000000..f4675f69b9
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark1.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, Integer N) return (Matrix[double] beta)
+{
+  # Reuse sp_tsmm and sp_mapmm if not future-based
+  A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+  b = t(X) %*% y;
+  beta = solve(A, b);
+}
+
+no_lamda = 10;
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
+
+X = rand(rows=10000, cols=200, seed=42);
+y = rand(rows=10000, cols=1, seed=43);
+N = ncol(X);
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+while (lamda < lim)
+{
+  beta = SimlinRegDS(X, y, lamda, N);
+  #beta = lmDS(X=X, y=y, reg=lamda);
+  R[,i] = beta;
+  lamda = lamda + stp;
+  i = i + 1;
+}
+
+R = sum(R);
+write(R, $1, format="text");
+
