This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 8a85d52  [SYSTEMDS-2650] Fix capturing inputs to a loop body
8a85d52 is described below

commit 8a85d529d3c6670dc04915e0e672aae14c343307
Author: arnabp <[email protected]>
AuthorDate: Mon Sep 7 16:48:33 2020 +0200

    [SYSTEMDS-2650] Fix capturing inputs to a loop body
---
 .../org/apache/sysds/parser/ForStatementBlock.java | 11 ++++++
 .../apache/sysds/parser/WhileStatementBlock.java   | 11 ++++++
 .../org/apache/sysds/runtime/lineage/Lineage.java  | 11 ------
 .../apache/sysds/runtime/lineage/LineageMap.java   |  4 +--
 .../functions/lineage/LineageTraceDedupTest.java   |  8 ++++-
 .../functions/lineage/LineageTraceDedup11.dml      | 39 ++++++++++++++++++++++
 6 files changed, 70 insertions(+), 14 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
index 03368d3..092fbb7 100644
--- a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
@@ -278,6 +278,17 @@ public class ForStatementBlock extends StatementBlock
        public Lop getToLops()        { return _toLops; }
        public Lop getIncrementLops() { return _incrementLops; }
 
+       @Override
+       public ArrayList<String> getInputstoSB() {
+               // Collect inputs from the loop body only, so variables read solely by
+               // the for predicate (from/to/increment) but never in the body are excluded.
+               ArrayList<String> inputs = new ArrayList<>();
+               ForStatement fstmt = (ForStatement)_statements.get(0);
+               for (StatementBlock sb : fstmt.getBody())
+                       inputs.addAll(sb.getInputstoSB());
+               return inputs;
+       }
+
        @Override
        public VariableSet analyze(VariableSet loPassed) {
                
diff --git a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
index 18a47ad..89f9261 100644
--- a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
@@ -255,6 +255,17 @@ public class WhileStatementBlock extends StatementBlock
                _predicateLops = predicateLops;
        }
        
+       @Override
+       public ArrayList<String> getInputstoSB() {
+               // Collect inputs from the loop body only, so variables read solely by
+               // the while predicate but never used in the body are excluded.
+               ArrayList<String> inputs = new ArrayList<>();
+               WhileStatement wstmt = (WhileStatement)_statements.get(0);
+               for (StatementBlock sb : wstmt.getBody())
+                       inputs.addAll(sb.getInputstoSB());
+               return inputs;
+       }
+       
        @Override
        public VariableSet analyze(VariableSet loPassed) {
                VariableSet predVars = new VariableSet();
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java 
b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
index 8ee2e2b..5a8a922 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
@@ -125,17 +125,6 @@ public class Lineage {
                _initDedupBlock = ldb;
        }
        
-       public void computeDedupBlock(ProgramBlock pb, ExecutionContext ec) {
-               if( !(pb instanceof ForProgramBlock || pb instanceof 
WhileProgramBlock) )
-                       throw new DMLRuntimeException("Invalid deduplication 
block: "+ pb.getClass().getSimpleName());
-               if (!_dedupBlocks.containsKey(pb)) {
-                       boolean valid = LineageDedupUtils.isValidDedupBlock(pb, 
false);
-                       _dedupBlocks.put(pb, valid?
-                               LineageDedupUtils.computeDedupBlock(pb, ec) : 
null);
-               }
-               _activeDedupBlock = _dedupBlocks.get(pb); //null if invalid
-       }
-
        public void initializeDedupBlock(ProgramBlock pb, ExecutionContext ec) {
                if( !(pb instanceof ForProgramBlock || pb instanceof 
WhileProgramBlock) )
                        throw new DMLRuntimeException("Invalid deduplication 
block: "+ pb.getClass().getSimpleName());
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
index 1558beb..47af4f5 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
@@ -75,8 +75,8 @@ public class LineageMap {
        public void processDedupItem(LineageMap lm, Long path, LineageItem[] 
liinputs, String name) {
                String delim = LineageDedupUtils.DEDUP_DELIM;
                for (Map.Entry<String, LineageItem> entry : 
lm._traces.entrySet()) {
-                       // Encode everything in the opcode needed by the 
deserialization logic
-                       // to map this lineage item to the right patch.
+                       // Encode everything needed by the recomputation logic 
in the
+                       // opcode to map this lineage item to the right patch.
                        String opcode = LineageItem.dedupItemOpcode + delim + 
entry.getKey()
                                + delim + name + delim + path.toString();
                        LineageItem li = new LineageItem(opcode, liinputs);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
index 3b1ae65..90bd1f2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
@@ -51,6 +51,7 @@ public class LineageTraceDedupTest extends AutomatedTestBase
        protected static final String TEST_NAME7 = "LineageTraceDedup7"; 
//nested if-else branches
        protected static final String TEST_NAME8 = "LineageTraceDedup8"; 
//while loop
        protected static final String TEST_NAME9 = "LineageTraceDedup9"; 
//while loop w/ if
+       protected static final String TEST_NAME11 = "LineageTraceDedup11"; 
//mini-batch
        
        protected String TEST_CLASS_DIR = TEST_DIR + 
LineageTraceDedupTest.class.getSimpleName() + "/";
        
@@ -61,7 +62,7 @@ public class LineageTraceDedupTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<11; i++)
+               for(int i=1; i<=11; i++)
                        addTestConfiguration(TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
        }
        
@@ -114,6 +115,11 @@ public class LineageTraceDedupTest extends 
AutomatedTestBase
        public void testLineageTrace9() {
                testLineageTrace(TEST_NAME9);
        }
+
+       @Test
+       public void testLineageTrace11() {
+               testLineageTrace(TEST_NAME11); // for loop whose body repeats one statement ten times
+       }
        
        public void testLineageTrace(String testname) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
diff --git a/src/test/scripts/functions/lineage/LineageTraceDedup11.dml 
b/src/test/scripts/functions/lineage/LineageTraceDedup11.dml
new file mode 100644
index 0000000..722af6b
--- /dev/null
+++ b/src/test/scripts/functions/lineage/LineageTraceDedup11.dml
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+M = 8;
+lim = 100;
+X = rand(rows=M, cols=784, seed=42)  # fixed seed keeps the input reproducible
+
+for(i in 1:lim) {  # each iteration applies the same statement to X ten times
+  X = ((X + X) * i - X) / (i+1)  # = X*(2i-1)/(i+1); NOTE(review): magnitude grows fast, mind overflow if lim is raised
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+}
+write(X, $1, format="text");
+

Reply via email to