This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 912908316a [SYSTEMDS-3474] Lineage-based reuse of prefetch instruction
912908316a is described below

commit 912908316a29dbbecfd89c121117cfc32f740a2a
Author: Arnab Phani <[email protected]>
AuthorDate: Fri Dec 2 20:27:45 2022 +0100

    [SYSTEMDS-3474] Lineage-based reuse of prefetch instruction
    
    This patch enables caching and reusing prefetch instruction
    outputs. This is the first step towards reusing asynchronous
    operators.
    
    Closes #1746
---
 .../lops/compile/linearization/ILinearize.java     |  3 +-
 .../instructions/cp/PrefetchCPInstruction.java     |  7 ++-
 .../instructions/cp/TriggerPrefetchTask.java       | 15 ++++++
 .../apache/sysds/runtime/lineage/LineageCache.java | 52 ++++++++++++++++++
 .../sysds/runtime/lineage/LineageCacheConfig.java  |  2 +-
 .../functions/async/LineageReuseSparkTest.java     | 63 +++++++++++++++++-----
 .../scripts/functions/async/LineageReuseSpark2.dml | 40 ++++++++++++++
 7 files changed, 166 insertions(+), 16 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java 
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index 7eee970e2b..70ab1533df 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -44,6 +44,7 @@ import org.apache.sysds.lops.CSVReBlock;
 import org.apache.sysds.lops.CentralMoment;
 import org.apache.sysds.lops.Checkpoint;
 import org.apache.sysds.lops.CoVariance;
+import org.apache.sysds.lops.DataGen;
 import org.apache.sysds.lops.GroupedAggregate;
 import org.apache.sysds.lops.GroupedAggregateM;
 import org.apache.sysds.lops.Lop;
@@ -359,7 +360,7 @@ public interface ILinearize {
                                && !(lop instanceof CoVariance)
                                // Not qualified for prefetching
                                && !(lop instanceof Checkpoint) && !(lop 
instanceof ReBlock)
-                               && !(lop instanceof CSVReBlock)
+                               && !(lop instanceof CSVReBlock) && !(lop 
instanceof DataGen)
                                // Cannot filter Transformation cases from 
Actions (FIXME)
                                && !(lop instanceof MMTSJ) && !(lop instanceof 
UAggOuterChain)
                                && !(lop instanceof ParameterizedBuiltin) && 
!(lop instanceof SpoofFused);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index db9bbb1b84..192e165391 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -23,6 +23,8 @@ import java.util.concurrent.Executors;
 
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 
@@ -42,8 +44,9 @@ public class PrefetchCPInstruction extends UnaryCPInstruction 
{
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               //TODO: handle non-matrix objects
+               // TODO: handle non-matrix objects
                ec.setVariable(output.getName(), ec.getMatrixObject(input1));
+               LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? 
this.getLineageItem(ec).getValue() : null;
 
                // Note, a Prefetch instruction doesn't guarantee an 
asynchronous execution.
                // If the next instruction which takes this output as an input 
comes before
@@ -51,6 +54,6 @@ public class PrefetchCPInstruction extends UnaryCPInstruction 
{
                // In that case this Prefetch instruction will act like a NOOP. 
                if (CommonThreadPool.triggerRemoteOPsPool == null)
                        CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
-               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerPrefetchTask(ec.getMatrixObject(output)));
+               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerPrefetchTask(ec.getMatrixObject(output), li));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
index 26a1c8e8bb..b7c69d01f5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
@@ -22,18 +22,28 @@ package org.apache.sysds.runtime.instructions.cp;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
+import org.apache.sysds.runtime.lineage.LineageCache;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.utils.stats.SparkStatistics;
 
 public class TriggerPrefetchTask implements Runnable {
        MatrixObject _prefetchMO;
+       LineageItem _inputLi;
 
        public TriggerPrefetchTask(MatrixObject mo) {
                _prefetchMO = mo;
+               _inputLi = null;
+       }
+
+       public TriggerPrefetchTask(MatrixObject mo, LineageItem li) {
+               _prefetchMO = mo;
+               _inputLi = li;
        }
 
        @Override
        public void run() {
                boolean prefetched = false;
+               long t1 = System.nanoTime();
                synchronized (_prefetchMO) {
                        // Having this check inside the critical section
                        // safeguards against concurrent rmVar.
@@ -44,6 +54,11 @@ public class TriggerPrefetchTask implements Runnable {
                                prefetched = true;
                        }
                }
+
+               // Save the collected intermediate in the lineage cache
+               if (_inputLi != null)
+                       LineageCache.putValueAsyncOp(_inputLi, _prefetchMO, 
prefetched, t1);
+
                if (DMLScript.STATISTICS && prefetched) {
                        if (_prefetchMO.isFederated())
                                FederatedStatistics.incAsyncPrefetchCount(1);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index ecd734c588..8e8d962199 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -38,6 +38,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.instructions.CPInstructionParser;
 import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.BroadcastCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
@@ -45,6 +46,7 @@ import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.PrefetchCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
 import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
@@ -579,6 +581,10 @@ public class LineageCache
                                        continue;
                                }
 
+                               if (inst instanceof PrefetchCPInstruction || 
inst instanceof BroadcastCPInstruction)
+                                       // For the async. instructions, caching 
is handled separately by the tasks
+                                       continue;
+
                                if (data instanceof MatrixObject && 
((MatrixObject) data).hasRDDHandle()) {
                                        // Avoid triggering pre-matured Spark 
instruction chains
                                        removePlaceholder(item);
@@ -637,6 +643,52 @@ public class LineageCache
                }
        }
 
+       public static void putValueAsyncOp(LineageItem instLI, Data data, 
boolean prefetched, long starttime)
+       {
+               if (ReuseCacheType.isNone())
+                       return;
+               if (!prefetched) //prefetching was not successful
+                       return;
+
+               synchronized( _cache )
+               {
+                       if (!probe(instLI))
+                               return;
+
+                       long computetime = System.nanoTime() - starttime;
+                       LineageCacheEntry centry = _cache.get(instLI);
+                       if(!(data instanceof MatrixObject) && !(data instanceof 
ScalarObject)) {
+                               // Reusable instructions can return a frame 
(rightIndex). Remove placeholders.
+                               removePlaceholder(instLI);
+                               return;
+                       }
+
+                       MatrixBlock mb = (data instanceof MatrixObject) ?
+                               ((MatrixObject)data).acquireReadAndRelease() : 
null;
+                       long size = mb != null ? mb.getInMemorySize() : 
((ScalarObject)data).getSize();
+
+                       // remove the placeholder if the entry is bigger than 
the cache.
+                       if (size > LineageCacheEviction.getCacheLimit()) {
+                               removePlaceholder(instLI);
+                               return;
+                       }
+
+                       // place the data
+                       if (data instanceof MatrixObject)
+                               centry.setValue(mb, computetime);
+                       else if (data instanceof ScalarObject)
+                               centry.setValue((ScalarObject)data, 
computetime);
+
+                       if (DMLScript.STATISTICS && 
LineageCacheEviction._removelist.containsKey(centry._key)) {
+                               // Add to missed compute time
+                               
LineageCacheStatistics.incrementMissedComputeTime(centry._computeTime);
+                       }
+
+                       //maintain order for eviction
+                       LineageCacheEviction.addEntry(centry);
+               }
+       }
+
        public static void putValue(List<DataIdentifier> outputs,
                LineageItem[] liInputs, String name, ExecutionContext ec, long 
computetime)
        {
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 5a7c46dfe7..72ea3835a2 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -54,7 +54,7 @@ public class LineageCacheConfig
                "^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand", 
"replace",
                "^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", 
"qsort", 
                "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", 
"lowertri",
-               "mapmm", "cpmm"
+               "mapmm", "cpmm", "prefetch"
                //TODO: Reuse everything. 
        };
        private static String[] REUSE_OPCODES  = new String[] {};
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index e958093122..57d7892b3b 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -28,6 +28,7 @@ package org.apache.sysds.test.functions.async;
        import org.apache.sysds.hops.recompile.Recompiler;
        import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
        import org.apache.sysds.runtime.lineage.Lineage;
+       import org.apache.sysds.runtime.lineage.LineageCacheConfig;
        import org.apache.sysds.runtime.matrix.data.MatrixValue;
        import org.apache.sysds.test.AutomatedTestBase;
        import org.apache.sysds.test.TestConfiguration;
@@ -40,7 +41,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 
        protected static final String TEST_DIR = "functions/async/";
        protected static final String TEST_NAME = "LineageReuseSpark";
-       protected static final int TEST_VARIANTS = 1;
+       protected static final int TEST_VARIANTS = 2;
        protected static String TEST_CLASS_DIR = TEST_DIR + 
LineageReuseSparkTest.class.getSimpleName() + "/";
 
        @Override
@@ -52,16 +53,21 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
 
        @Test
        public void testlmdsHB() {
-               runTest(TEST_NAME+"1", ExecMode.HYBRID);
+               runTest(TEST_NAME+"1", ExecMode.HYBRID, 1);
        }
 
        @Test
        public void testlmdsSP() {
                // Only reuse the actions
-               runTest(TEST_NAME+"1", ExecMode.SPARK);
+               runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
        }
 
-       public void runTest(String testname, ExecMode execMode) {
+       @Test
+       public void testReusePrefetch() {
+               runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
+       }
+
+       public void runTest(String testname, ExecMode execMode, int testId) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
                boolean old_trans_exec_type = 
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE;
@@ -85,31 +91,52 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        programArgs = proArgs.toArray(new 
String[proArgs.size()]);
 
                        Lineage.resetInternalState();
+                       if (testId == 2) enablePrefetch();
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+                       disablePrefetch();
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
-                       long numTsmm = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
-                       long numMapmm = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
-
+                       long numTsmm = 0;
+                       long numMapmm = 0;
+                       if (testId == 1) {
+                               numTsmm = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
+                               numMapmm = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
+                       }
+                       long numPrefetch = 0;
+                       if (testId == 2) numPrefetch = 
Statistics.getCPHeavyHitterCount("prefetch");
+
+                       proArgs.clear();
                        proArgs.add("-explain");
                        proArgs.add("-stats");
                        proArgs.add("-lineage");
-                       proArgs.add("reuse_hybrid");
+                       
proArgs.add(LineageCacheConfig.ReuseCacheType.REUSE_FULL.name().toLowerCase());
                        proArgs.add("-args");
                        proArgs.add(output("R"));
                        programArgs = proArgs.toArray(new 
String[proArgs.size()]);
 
                        Lineage.resetInternalState();
+                       if (testId == 2) enablePrefetch();
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+                       disablePrefetch();
                        HashMap<MatrixValue.CellIndex, Double> R_reused = 
readDMLScalarFromOutputDir("R");
-                       long numTsmm_r = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
-                       long numMapmm_r= 
Statistics.getCPHeavyHitterCount("sp_mapmm");
+                       long numTsmm_r = 0;
+                       long numMapmm_r = 0;
+                       if (testId == 1) {
+                               numTsmm_r = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
+                               numMapmm_r = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
+                       }
+                       long numPrefetch_r = 0;
+                       if (testId == 2) numPrefetch_r = 
Statistics.getCPHeavyHitterCount("prefetch");
 
                        //compare matrices
                        boolean matchVal = TestUtils.compareMatrices(R, 
R_reused, 1e-6, "Origin", "withPrefetch");
                        if (!matchVal)
                                System.out.println("Value w/o reuse "+R+" w/ 
reuse "+R_reused);
-                       Assert.assertTrue("Violated sp_tsmm: reuse count: 
"+numTsmm_r+" < "+numTsmm, numTsmm_r < numTsmm);
-                       Assert.assertTrue("Violated sp_mapmm: reuse count: 
"+numMapmm_r+" < "+numMapmm, numMapmm_r < numMapmm);
+                       if (testId == 1) {
+                               Assert.assertTrue("Violated sp_tsmm reuse 
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
+                               Assert.assertTrue("Violated sp_mapmm reuse 
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
+                       }
+                       if (testId == 2)
+                               Assert.assertTrue("Violated prefetch reuse 
count: " + numPrefetch_r + " < " + numPrefetch, numPrefetch_r<numPrefetch);
 
                } finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
old_simplification;
@@ -120,4 +147,16 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        Recompiler.reinitRecompiler();
                }
        }
+
+       private void enablePrefetch() {
+               OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
+               OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
+               OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+       }
+
+       private void disablePrefetch() {
+               OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
+               OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
+               OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+       }
 }
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml 
b/src/test/scripts/functions/async/LineageReuseSpark2.dml
new file mode 100644
index 0000000000..63792332b4
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark2.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+
+# Spark transformation operations 
+for (i in 1:10) {
+  while(FALSE){}
+  sp1 = X + ceil(X);
+  sp2 = sp1 %*% v; #output fits in local
+  # Place a prefetch after mapmm and reuse
+
+  # CP instructions
+  v2 = ((v + v) * 1 - v) / (1+1);
+  v2 = ((v + v) * 2 - v) / (2+1);
+
+  # CP binary triggers the DAG of SP operations
+  cp = sp2 + sum(v2);
+  R = sum(cp);
+}
+
+write(R, $1, format="text");

Reply via email to