This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new c54213d [SYSTEMDS-2745] Fix indexed addition assignment (accumulation)
c54213d is described below
commit c54213df08b259fc3b8c96d4c3ffe6b0ea6b1eb1
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Dec 19 19:08:51 2020 +0100
[SYSTEMDS-2745] Fix indexed addition assignment (accumulation)
This patch adds the missing support for addition assignments in left
indexing expressions for both scalars and matrices as well as scalar and
matrix indexed ranges.
Thanks to Rene Haubitzer for catching this issue.
---
.../org/apache/sysds/parser/DMLTranslator.java | 133 +++++++++------------
.../indexing/IndexedAdditionAssignmentTest.java | 91 ++++++++++++++
.../functions/indexing/LeftIndexingScalarTest.java | 38 ++----
.../functions/indexing/IndexedAdditionTest.dml | 31 +++++
4 files changed, 187 insertions(+), 106 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index ff41df6..aab0d22 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1137,11 +1137,8 @@ public class DMLTranslator
if (!(target instanceof
IndexedIdentifier)) {
//process right hand side and
accumulation
Hop ae =
processExpression(source, target, ids);
- if(
((AssignmentStatement)current).isAccumulator() ) {
- DataIdentifier accum =
liveIn.getVariable(target.getName());
- if( accum == null )
- throw new
LanguageException("Invalid accumulator assignment "
- + "to
non-existing variable "+target.getName()+".");
+ if( as.isAccumulator() ) {
+ DataIdentifier accum =
getAccumulatorData(liveIn, target.getName());
ae =
HopRewriteUtils.createBinary(ids.get(target.getName()), ae, OpOp2.PLUS);
target.setProperties(accum.getOutput());
}
@@ -1170,6 +1167,15 @@ public class DMLTranslator
else {
Hop ae =
processLeftIndexedExpression(source, (IndexedIdentifier)target, ids);
+ if( as.isAccumulator() ) {
+ DataIdentifier accum =
getAccumulatorData(liveIn, target.getName());
+ Hop rix =
processIndexingExpression((IndexedIdentifier)target, null, ids);
+ Hop rhs =
processExpression(source, null, ids);
+ Hop binary =
HopRewriteUtils.createBinary(rix, rhs, OpOp2.PLUS);
+
HopRewriteUtils.replaceChildReference(ae, ae.getInput(1), binary);
+
target.setProperties(accum.getOutput());
+ }
+
ids.put(target.getName(), ae);
// obtain origDim values BEFORE
they are potentially updated during setProperties call
@@ -1298,7 +1304,14 @@ public class DMLTranslator
}
sb.updateLiveVariablesOut(updatedLiveOut);
sb.setHops(output);
-
+ }
+
+ private static DataIdentifier getAccumulatorData(VariableSet liveIn,
String varname) {
+ DataIdentifier accum = liveIn.getVariable(varname);
+ if( accum == null )
+ throw new LanguageException("Invalid accumulator
assignment "
+ + "to non-existing variable "+varname+".");
+ return accum;
}
private void appendDefaultArguments(FunctionStatement fstmt,
List<String> inputNames, List<Hop> inputs, HashMap<String, Hop> ids) {
@@ -1630,41 +1643,9 @@ public class DMLTranslator
return processExpression(source, tmpOut, hops );
}
- private Hop processLeftIndexedExpression(Expression source,
IndexedIdentifier target, HashMap<String, Hop> hops)
- {
+ private Hop processLeftIndexedExpression(Expression source,
IndexedIdentifier target, HashMap<String, Hop> hops) {
// process target indexed expressions
- Hop rowLowerHops = null, rowUpperHops = null, colLowerHops =
null, colUpperHops = null;
-
- if (target.getRowLowerBound() != null)
- rowLowerHops =
processExpression(target.getRowLowerBound(),null,hops);
- else
- rowLowerHops = new LiteralOp(1);
-
- if (target.getRowUpperBound() != null)
- rowUpperHops =
processExpression(target.getRowUpperBound(),null,hops);
- else
- {
- if ( target.getDim1() != -1 )
- rowUpperHops = new
LiteralOp(target.getOrigDim1());
- else {
- rowUpperHops = new UnaryOp(target.getName(),
DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(target.getName()));
- rowUpperHops.setParseInfo(target);
- }
- }
- if (target.getColLowerBound() != null)
- colLowerHops =
processExpression(target.getColLowerBound(),null,hops);
- else
- colLowerHops = new LiteralOp(1);
-
- if (target.getColUpperBound() != null)
- colUpperHops =
processExpression(target.getColUpperBound(),null,hops);
- else
- {
- if ( target.getDim2() != -1 )
- colUpperHops = new
LiteralOp(target.getOrigDim2());
- else
- colUpperHops = new UnaryOp(target.getName(),
DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(target.getName()));
- }
+ Hop[] ixRange = getIndexingBounds(target, hops, true);
// process the source expression to get source Hops
Hop sourceOp = processExpression(source, target, hops);
@@ -1678,12 +1659,11 @@ public class DMLTranslator
if( sourceOp.getDataType().isMatrix() &&
source.getOutput().getDataType().isScalar() )
sourceOp.setDataType(DataType.SCALAR);
- Hop leftIndexOp = new LeftIndexingOp(target.getName(),
target.getDataType(), ValueType.FP64,
- targetOp, sourceOp, rowLowerHops, rowUpperHops,
colLowerHops, colUpperHops,
- target.getRowLowerEqualsUpper(),
target.getColLowerEqualsUpper());
+ Hop leftIndexOp = new LeftIndexingOp(target.getName(),
target.getDataType(),
+ ValueType.FP64, targetOp, sourceOp, ixRange[0],
ixRange[1], ixRange[2], ixRange[3],
+ target.getRowLowerEqualsUpper(),
target.getColLowerEqualsUpper());
setIdentifierParams(leftIndexOp, target);
-
leftIndexOp.setParseInfo(target);
leftIndexOp.setDim1(target.getOrigDim1());
leftIndexOp.setDim2(target.getOrigDim2());
@@ -1694,38 +1674,7 @@ public class DMLTranslator
private Hop processIndexingExpression(IndexedIdentifier source,
DataIdentifier target, HashMap<String, Hop> hops) {
// process Hops for indexes (for source)
- Hop rowLowerHops = null, rowUpperHops = null, colLowerHops =
null, colUpperHops = null;
-
- if (source.getRowLowerBound() != null)
- rowLowerHops =
processExpression(source.getRowLowerBound(),null,hops);
- else
- rowLowerHops = new LiteralOp(1);
-
- if (source.getRowUpperBound() != null)
- rowUpperHops =
processExpression(source.getRowUpperBound(),null,hops);
- else
- {
- if ( source.getOrigDim1() != -1 )
- rowUpperHops = new
LiteralOp(source.getOrigDim1());
- else {
- rowUpperHops = new UnaryOp(source.getName(),
DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(source.getName()));
- rowUpperHops.setParseInfo(source);
- }
- }
- if (source.getColLowerBound() != null)
- colLowerHops =
processExpression(source.getColLowerBound(),null,hops);
- else
- colLowerHops = new LiteralOp(1);
-
- if (source.getColUpperBound() != null)
- colUpperHops =
processExpression(source.getColUpperBound(),null,hops);
- else
- {
- if ( source.getOrigDim2() != -1 )
- colUpperHops = new
LiteralOp(source.getOrigDim2());
- else
- colUpperHops = new UnaryOp(source.getName(),
DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(source.getName()));
- }
+ Hop[] ixRange = getIndexingBounds(source, hops, false);
if (target == null) {
target = createTarget(source);
@@ -1735,8 +1684,8 @@ public class DMLTranslator
target.setNnz(-1);
Hop indexOp = new IndexingOp(target.getName(),
target.getDataType(), target.getValueType(),
- hops.get(source.getName()), rowLowerHops,
rowUpperHops, colLowerHops, colUpperHops,
- source.getRowLowerEqualsUpper(),
source.getColLowerEqualsUpper());
+ hops.get(source.getName()), ixRange[0], ixRange[1],
ixRange[2], ixRange[3],
+ source.getRowLowerEqualsUpper(),
source.getColLowerEqualsUpper());
indexOp.setParseInfo(target);
setIdentifierParams(indexOp, target);
@@ -1744,6 +1693,34 @@ public class DMLTranslator
return indexOp;
}
+ private Hop[] getIndexingBounds(IndexedIdentifier ix, HashMap<String,
Hop> hops, boolean lix) {
+ Hop rowLowerHops = (ix.getRowLowerBound() != null) ?
+ processExpression(ix.getRowLowerBound(),null, hops) :
new LiteralOp(1);
+ Hop colLowerHops = (ix.getColLowerBound() != null) ?
+ processExpression(ix.getColLowerBound(),null, hops) :
new LiteralOp(1);
+
+ Hop rowUpperHops = null, colUpperHops = null;
+ if (ix.getRowUpperBound() != null)
+ rowUpperHops =
processExpression(ix.getRowUpperBound(),null,hops);
+ else {
+ rowUpperHops = ((lix ? ix.getDim1() : ix.getOrigDim1())
!= -1) ?
+ new LiteralOp(ix.getOrigDim1()) :
+ new UnaryOp(ix.getName(), DataType.SCALAR,
ValueType.INT64, OpOp1.NROW, hops.get(ix.getName()));
+ rowUpperHops.setParseInfo(ix);
+ }
+
+ if (ix.getColUpperBound() != null)
+ colUpperHops =
processExpression(ix.getColUpperBound(),null,hops);
+ else {
+ colUpperHops = ((lix ? ix.getDim2() : ix.getOrigDim2())
!= -1) ?
+ new LiteralOp(ix.getOrigDim2()) :
+ new UnaryOp(ix.getName(), DataType.SCALAR,
ValueType.INT64, OpOp1.NCOL, hops.get(ix.getName()));
+ colUpperHops.setParseInfo(ix);
+ }
+
+ return new Hop[] {rowLowerHops, rowUpperHops, colLowerHops,
colUpperHops};
+ }
+
/**
* Construct Hops from parse tree : Process Binary Expression in an
diff --git
a/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java
b/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java
new file mode 100644
index 0000000..3db2535
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.indexing;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+
+public class IndexedAdditionAssignmentTest extends AutomatedTestBase
+{
+ private final static String TEST_DIR = "functions/indexing/";
+ private final static String TEST_NAME = "IndexedAdditionTest";
+
+ private final static String TEST_CLASS_DIR = TEST_DIR +
IndexedAdditionAssignmentTest.class.getSimpleName() + "/";
+
+ private final static int rows = 1279;
+ private final static int cols = 1050;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A"}));
+ }
+
+ @Test
+ public void testIndexedAssignmentAddScalarCP() {
+ runIndexedAdditionAssignment(true, ExecType.CP);
+ }
+
+ @Test
+ public void testIndexedAssignmentAddMatrixCP() {
+ runIndexedAdditionAssignment(false, ExecType.CP);
+ }
+
+ @Test
+ public void testIndexedAssignmentAddScalarSpark() {
+ runIndexedAdditionAssignment(true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testIndexedAssignmentAddMatrixSpark() {
+ runIndexedAdditionAssignment(false, ExecType.SPARK);
+ }
+
+ private void runIndexedAdditionAssignment(boolean scalar, ExecType
instType) {
+ ExecMode platformOld = setExecMode(instType);
+
+ try {
+ TestConfiguration config =
getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config);
+
+ //test adds 7 to a 1x1 or 10x10 area
+ //of an initially constant (3) matrix and sums it up
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-explain" , "-args",
+ Long.toString(rows), Long.toString(cols),
+ String.valueOf(scalar).toUpperCase(),
output("A")};
+
+ runTest(true, false, null, -1);
+
+ Double ret = readDMLMatrixFromOutputDir("A").get(new
CellIndex(1,1));
+ Assert.assertEquals(new Double(3*rows*cols +
7*(scalar?1:100)), ret);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java
b/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java
index b5ea0aa..68fbc37 100644
---
a/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.indexing;
import java.util.HashMap;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -33,7 +32,6 @@ import org.apache.sysds.test.TestUtils;
public class LeftIndexingScalarTest extends AutomatedTestBase
{
-
private final static String TEST_DIR = "functions/indexing/";
private final static String TEST_NAME = "LeftIndexingScalarTest";
private final static String TEST_CLASS_DIR = TEST_DIR +
LeftIndexingScalarTest.class.getSimpleName() + "/";
@@ -52,31 +50,18 @@ public class LeftIndexingScalarTest extends
AutomatedTestBase
}
@Test
- public void testLeftIndexingScalarCP()
- {
+ public void testLeftIndexingScalarCP() {
runLeftIndexingTest(ExecType.CP);
}
@Test
- public void testLeftIndexingScalarSP()
- {
+ public void testLeftIndexingScalarSP() {
runLeftIndexingTest(ExecType.SPARK);
}
private void runLeftIndexingTest( ExecType instType )
- {
- //rtplatform for MR
- ExecMode platformOld = rtplatform;
- if(instType == ExecType.SPARK) {
- rtplatform = ExecMode.SPARK;
- }
- else {
- rtplatform = ExecMode.HYBRID;
- }
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+ {
+ ExecMode platformOld = setExecMode(instType);
try
{
@@ -91,10 +76,10 @@ public class LeftIndexingScalarTest extends
AutomatedTestBase
fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
- double[][] A = getRandomMatrix(rows, cols, min, max, sparsity,
System.currentTimeMillis());
- writeInputMatrix("A", A, true);
-
- runTest(true, false, null, -1);
+ double[][] A = getRandomMatrix(rows, cols, min, max,
sparsity, System.currentTimeMillis());
+ writeInputMatrix("A", A, true);
+
+ runTest(true, false, null, -1);
runRScript(true);
HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("A");
@@ -102,11 +87,8 @@ public class LeftIndexingScalarTest extends
AutomatedTestBase
TestUtils.compareMatrices(dmlfile, rfile, epsilon,
"A-DML", "A-R");
checkDMLMetaDataFile("A", new
MatrixCharacteristics(rows,cols,1,1));
}
- finally
- {
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ finally {
+ resetExecMode(platformOld);
}
}
}
-
diff --git a/src/test/scripts/functions/indexing/IndexedAdditionTest.dml
b/src/test/scripts/functions/indexing/IndexedAdditionTest.dml
new file mode 100644
index 0000000..415a795
--- /dev/null
+++ b/src/test/scripts/functions/indexing/IndexedAdditionTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = matrix(3, $1, $2);
+
+if( $3 )
+ A[10,20] += 7;
+else
+ A[10:19,20:29] += 7;
+
+R = as.matrix(sum(A))
+write(R, $4, format="text")