This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new c54213d  [SYSTEMDS-2745] Fix indexed addition assignment (accumulation)
c54213d is described below

commit c54213df08b259fc3b8c96d4c3ffe6b0ea6b1eb1
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sat Dec 19 19:08:51 2020 +0100

    [SYSTEMDS-2745] Fix indexed addition assignment (accumulation)
    
    This patch adds the missing support for addition assignments in left
    indexing expressions for both scalars and matrices as well as scalar and
    matrix indexed ranges.
    
    Thanks to Rene Haubitzer for catching this issue.
---
 .../org/apache/sysds/parser/DMLTranslator.java     | 133 +++++++++------------
 .../indexing/IndexedAdditionAssignmentTest.java    |  91 ++++++++++++++
 .../functions/indexing/LeftIndexingScalarTest.java |  38 ++----
 .../functions/indexing/IndexedAdditionTest.dml     |  31 +++++
 4 files changed, 187 insertions(+), 106 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index ff41df6..aab0d22 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1137,11 +1137,8 @@ public class DMLTranslator
                                        if (!(target instanceof 
IndexedIdentifier)) {
                                                //process right hand side and 
accumulation
                                                Hop ae = 
processExpression(source, target, ids);
-                                               if( 
((AssignmentStatement)current).isAccumulator() ) {
-                                                       DataIdentifier accum = 
liveIn.getVariable(target.getName());
-                                                       if( accum == null )
-                                                               throw new 
LanguageException("Invalid accumulator assignment "
-                                                                       + "to 
non-existing variable "+target.getName()+".");
+                                               if( as.isAccumulator() ) {
+                                                       DataIdentifier accum = 
getAccumulatorData(liveIn, target.getName());
                                                        ae = 
HopRewriteUtils.createBinary(ids.get(target.getName()), ae, OpOp2.PLUS);
                                                        
target.setProperties(accum.getOutput());
                                                }
@@ -1170,6 +1167,15 @@ public class DMLTranslator
                                        else {
                                                Hop ae = 
processLeftIndexedExpression(source, (IndexedIdentifier)target, ids);
                                                
+                                               if( as.isAccumulator() ) {
+                                                       DataIdentifier accum = 
getAccumulatorData(liveIn, target.getName());
+                                                       Hop rix = 
processIndexingExpression((IndexedIdentifier)target, null, ids);
+                                                       Hop rhs = 
processExpression(source, null, ids);
+                                                       Hop binary = 
HopRewriteUtils.createBinary(rix, rhs, OpOp2.PLUS);
+                                                       
HopRewriteUtils.replaceChildReference(ae, ae.getInput(1), binary);
+                                                       
target.setProperties(accum.getOutput());
+                                               }
+                                               
                                                ids.put(target.getName(), ae);
                                                
                                                // obtain origDim values BEFORE 
they are potentially updated during setProperties call
@@ -1298,7 +1304,14 @@ public class DMLTranslator
                }
                sb.updateLiveVariablesOut(updatedLiveOut);
                sb.setHops(output);
-
+       }
+       
+       private static DataIdentifier getAccumulatorData(VariableSet liveIn, 
String varname) {
+               DataIdentifier accum = liveIn.getVariable(varname);
+               if( accum == null )
+                       throw new LanguageException("Invalid accumulator 
assignment "
+                               + "to non-existing variable "+varname+".");
+               return accum;
        }
        
        private void appendDefaultArguments(FunctionStatement fstmt, 
List<String> inputNames, List<Hop> inputs, HashMap<String, Hop> ids) {
@@ -1630,41 +1643,9 @@ public class DMLTranslator
                return processExpression(source, tmpOut, hops );
        }
        
-       private Hop processLeftIndexedExpression(Expression source, 
IndexedIdentifier target, HashMap<String, Hop> hops)  
-       {
+       private Hop processLeftIndexedExpression(Expression source, 
IndexedIdentifier target, HashMap<String, Hop> hops) {
                // process target indexed expressions
-               Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = 
null, colUpperHops = null;
-               
-               if (target.getRowLowerBound() != null)
-                       rowLowerHops = 
processExpression(target.getRowLowerBound(),null,hops);
-               else
-                       rowLowerHops = new LiteralOp(1);
-               
-               if (target.getRowUpperBound() != null)
-                       rowUpperHops = 
processExpression(target.getRowUpperBound(),null,hops);
-               else
-               {
-                       if ( target.getDim1() != -1 ) 
-                               rowUpperHops = new 
LiteralOp(target.getOrigDim1());
-                       else {
-                               rowUpperHops = new UnaryOp(target.getName(), 
DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(target.getName()));
-                               rowUpperHops.setParseInfo(target);
-                       }
-               }
-               if (target.getColLowerBound() != null)
-                       colLowerHops = 
processExpression(target.getColLowerBound(),null,hops);
-               else
-                       colLowerHops = new LiteralOp(1);
-               
-               if (target.getColUpperBound() != null)
-                       colUpperHops = 
processExpression(target.getColUpperBound(),null,hops);
-               else
-               {
-                       if ( target.getDim2() != -1 ) 
-                               colUpperHops = new 
LiteralOp(target.getOrigDim2());
-                       else
-                               colUpperHops = new UnaryOp(target.getName(), 
DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(target.getName()));
-               }
+               Hop[] ixRange = getIndexingBounds(target, hops, true);
                
                // process the source expression to get source Hops
                Hop sourceOp = processExpression(source, target, hops);
@@ -1678,12 +1659,11 @@ public class DMLTranslator
                if( sourceOp.getDataType().isMatrix() && 
source.getOutput().getDataType().isScalar() )
                        sourceOp.setDataType(DataType.SCALAR);
                
-               Hop leftIndexOp = new LeftIndexingOp(target.getName(), 
target.getDataType(), ValueType.FP64, 
-                               targetOp, sourceOp, rowLowerHops, rowUpperHops, 
colLowerHops, colUpperHops, 
-                               target.getRowLowerEqualsUpper(), 
target.getColLowerEqualsUpper());
+               Hop leftIndexOp = new LeftIndexingOp(target.getName(), 
target.getDataType(),
+                       ValueType.FP64, targetOp, sourceOp, ixRange[0], 
ixRange[1], ixRange[2], ixRange[3],
+                       target.getRowLowerEqualsUpper(), 
target.getColLowerEqualsUpper());
                
                setIdentifierParams(leftIndexOp, target);
-       
                leftIndexOp.setParseInfo(target);
                leftIndexOp.setDim1(target.getOrigDim1());
                leftIndexOp.setDim2(target.getOrigDim2());
@@ -1694,38 +1674,7 @@ public class DMLTranslator
        
        private Hop processIndexingExpression(IndexedIdentifier source, 
DataIdentifier target, HashMap<String, Hop> hops) {
                // process Hops for indexes (for source)
-               Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = 
null, colUpperHops = null;
-               
-               if (source.getRowLowerBound() != null)
-                       rowLowerHops = 
processExpression(source.getRowLowerBound(),null,hops);
-               else
-                       rowLowerHops = new LiteralOp(1);
-               
-               if (source.getRowUpperBound() != null)
-                       rowUpperHops = 
processExpression(source.getRowUpperBound(),null,hops);
-               else
-               {
-                       if ( source.getOrigDim1() != -1 ) 
-                               rowUpperHops = new 
LiteralOp(source.getOrigDim1());
-                       else {
-                               rowUpperHops = new UnaryOp(source.getName(), 
DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(source.getName()));
-                               rowUpperHops.setParseInfo(source);
-                       }
-               }
-               if (source.getColLowerBound() != null)
-                       colLowerHops = 
processExpression(source.getColLowerBound(),null,hops);
-               else
-                       colLowerHops = new LiteralOp(1);
-               
-               if (source.getColUpperBound() != null)
-                       colUpperHops = 
processExpression(source.getColUpperBound(),null,hops);
-               else
-               {
-                       if ( source.getOrigDim2() != -1 ) 
-                               colUpperHops = new 
LiteralOp(source.getOrigDim2());
-                       else
-                               colUpperHops = new UnaryOp(source.getName(), 
DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(source.getName()));
-               }
+               Hop[] ixRange = getIndexingBounds(source, hops, false);
                
                if (target == null) {
                        target = createTarget(source);
@@ -1735,8 +1684,8 @@ public class DMLTranslator
                target.setNnz(-1); 
                
                Hop indexOp = new IndexingOp(target.getName(), 
target.getDataType(), target.getValueType(),
-                               hops.get(source.getName()), rowLowerHops, 
rowUpperHops, colLowerHops, colUpperHops,
-                               source.getRowLowerEqualsUpper(), 
source.getColLowerEqualsUpper());
+                       hops.get(source.getName()), ixRange[0], ixRange[1], 
ixRange[2], ixRange[3],
+                       source.getRowLowerEqualsUpper(), 
source.getColLowerEqualsUpper());
        
                indexOp.setParseInfo(target);
                setIdentifierParams(indexOp, target);
@@ -1744,6 +1693,34 @@ public class DMLTranslator
                return indexOp;
        }
        
+       private Hop[] getIndexingBounds(IndexedIdentifier ix, HashMap<String, 
Hop> hops, boolean lix) {
+               Hop rowLowerHops = (ix.getRowLowerBound() != null) ?
+                       processExpression(ix.getRowLowerBound(),null, hops) : 
new LiteralOp(1);
+               Hop colLowerHops = (ix.getColLowerBound() != null) ?
+                       processExpression(ix.getColLowerBound(),null, hops) : 
new LiteralOp(1);
+               
+               Hop rowUpperHops = null, colUpperHops = null;
+               if (ix.getRowUpperBound() != null)
+                       rowUpperHops = 
processExpression(ix.getRowUpperBound(),null,hops);
+               else {
+                       rowUpperHops = ((lix ? ix.getDim1() : ix.getOrigDim1()) 
!= -1) ?
+                               new LiteralOp(ix.getOrigDim1()) :
+                               new UnaryOp(ix.getName(), DataType.SCALAR, 
ValueType.INT64, OpOp1.NROW, hops.get(ix.getName()));
+                       rowUpperHops.setParseInfo(ix);
+               }
+               
+               if (ix.getColUpperBound() != null)
+                       colUpperHops = 
processExpression(ix.getColUpperBound(),null,hops);
+               else {
+                       colUpperHops = ((lix ? ix.getDim2() : ix.getOrigDim2()) 
!= -1) ?
+                               new LiteralOp(ix.getOrigDim2()) :
+                               new UnaryOp(ix.getName(), DataType.SCALAR, 
ValueType.INT64, OpOp1.NCOL, hops.get(ix.getName()));
+                       colUpperHops.setParseInfo(ix);
+               }
+               
+               return new Hop[] {rowLowerHops, rowUpperHops, colLowerHops, 
colUpperHops};
+       }
+       
        
        /**
         * Construct Hops from parse tree : Process Binary Expression in an
diff --git a/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java b/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java
new file mode 100644
index 0000000..3db2535
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.indexing;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+
+public class IndexedAdditionAssignmentTest extends AutomatedTestBase
+{
+       private final static String TEST_DIR = "functions/indexing/";
+       private final static String TEST_NAME = "IndexedAdditionTest";
+       
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
IndexedAdditionAssignmentTest.class.getSimpleName() + "/";
+       
+       private final static int rows = 1279;
+       private final static int cols = 1050;
+       
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A"}));
+       }
+
+       @Test
+       public void testIndexedAssignmentAddScalarCP() {
+               runIndexedAdditionAssignment(true, ExecType.CP);
+       }
+       
+       @Test
+       public void testIndexedAssignmentAddMatrixCP() {
+               runIndexedAdditionAssignment(false, ExecType.CP);
+       }
+       
+       @Test
+       public void testIndexedAssignmentAddScalarSpark() {
+               runIndexedAdditionAssignment(true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testIndexedAssignmentAddMatrixSpark() {
+               runIndexedAdditionAssignment(false, ExecType.SPARK);
+       }
+       
+       private void runIndexedAdditionAssignment(boolean scalar, ExecType 
instType) {
+               ExecMode platformOld = setExecMode(instType);
+       
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+                       
+                       //test is adding or subtracting 7 to area 1x1 or 10x10
+                       //of an initially constraint (3) matrix and sums it up
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-explain" , "-args",
+                               Long.toString(rows), Long.toString(cols),
+                               String.valueOf(scalar).toUpperCase(), 
output("A")};
+                       
+                       runTest(true, false, null, -1);
+                       
+                       Double ret = readDMLMatrixFromOutputDir("A").get(new 
CellIndex(1,1));
+                       Assert.assertEquals(new Double(3*rows*cols + 
7*(scalar?1:100)),  ret);
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java b/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java
index b5ea0aa..68fbc37 100644
--- a/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.indexing;
 import java.util.HashMap;
 
 import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.lops.LopProperties.ExecType;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -33,7 +32,6 @@ import org.apache.sysds.test.TestUtils;
 
 public class LeftIndexingScalarTest extends AutomatedTestBase
 {
-
        private final static String TEST_DIR = "functions/indexing/";
        private final static String TEST_NAME = "LeftIndexingScalarTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
LeftIndexingScalarTest.class.getSimpleName() + "/";
@@ -52,31 +50,18 @@ public class LeftIndexingScalarTest extends AutomatedTestBase
        }
 
        @Test
-       public void testLeftIndexingScalarCP() 
-       {
+       public void testLeftIndexingScalarCP() {
                runLeftIndexingTest(ExecType.CP);
        }
        
        @Test
-       public void testLeftIndexingScalarSP() 
-       {
+       public void testLeftIndexingScalarSP() {
                runLeftIndexingTest(ExecType.SPARK);
        }
        
        private void runLeftIndexingTest( ExecType instType ) 
-       {               
-               //rtplatform for MR
-               ExecMode platformOld = rtplatform;
-               if(instType == ExecType.SPARK) {
-               rtplatform = ExecMode.SPARK;
-           }
-           else {
-                       rtplatform = ExecMode.HYBRID;
-           }
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == ExecMode.SPARK )
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-               
+       {
+               ExecMode platformOld = setExecMode(instType);
        
                try
                {
@@ -91,10 +76,10 @@ public class LeftIndexingScalarTest extends AutomatedTestBase
                        fullRScriptName = HOME + TEST_NAME + ".R";
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
inputDir() + " " + expectedDir();
                        
-               double[][] A = getRandomMatrix(rows, cols, min, max, sparsity, 
System.currentTimeMillis());
-               writeInputMatrix("A", A, true);
-              
-               runTest(true, false, null, -1);         
+                       double[][] A = getRandomMatrix(rows, cols, min, max, 
sparsity, System.currentTimeMillis());
+                       writeInputMatrix("A", A, true);
+
+                       runTest(true, false, null, -1);
                        runRScript(true);
                        
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("A");
@@ -102,11 +87,8 @@ public class LeftIndexingScalarTest extends AutomatedTestBase
                        TestUtils.compareMatrices(dmlfile, rfile, epsilon, 
"A-DML", "A-R");
                        checkDMLMetaDataFile("A", new 
MatrixCharacteristics(rows,cols,1,1));
                }
-               finally
-               {
-                       rtplatform = platformOld;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               finally {
+                       resetExecMode(platformOld);
                }
        }
 }
-
diff --git a/src/test/scripts/functions/indexing/IndexedAdditionTest.dml b/src/test/scripts/functions/indexing/IndexedAdditionTest.dml
new file mode 100644
index 0000000..415a795
--- /dev/null
+++ b/src/test/scripts/functions/indexing/IndexedAdditionTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = matrix(3, $1, $2);
+
+if( $3 )
+  A[10,20] += 7;
+else
+  A[10:19,20:29] += 7;
+
+R = as.matrix(sum(A))
+write(R, $4, format="text")

Reply via email to