This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push: new c54213d [SYSTEMDS-2745] Fix indexed addition assignment (accumulation) c54213d is described below commit c54213df08b259fc3b8c96d4c3ffe6b0ea6b1eb1 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sat Dec 19 19:08:51 2020 +0100 [SYSTEMDS-2745] Fix indexed addition assignment (accumulation) This patch adds the missing support for addition assignments in left indexing expressions for both scalars and matrices as well as scalar and matrix indexed ranges. Thanks to Rene Haubitzer for catching this issue. --- .../org/apache/sysds/parser/DMLTranslator.java | 133 +++++++++------------ .../indexing/IndexedAdditionAssignmentTest.java | 91 ++++++++++++++ .../functions/indexing/LeftIndexingScalarTest.java | 38 ++---- .../functions/indexing/IndexedAdditionTest.dml | 31 +++++ 4 files changed, 187 insertions(+), 106 deletions(-) diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index ff41df6..aab0d22 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -1137,11 +1137,8 @@ public class DMLTranslator if (!(target instanceof IndexedIdentifier)) { //process right hand side and accumulation Hop ae = processExpression(source, target, ids); - if( ((AssignmentStatement)current).isAccumulator() ) { - DataIdentifier accum = liveIn.getVariable(target.getName()); - if( accum == null ) - throw new LanguageException("Invalid accumulator assignment " - + "to non-existing variable "+target.getName()+"."); + if( as.isAccumulator() ) { + DataIdentifier accum = getAccumulatorData(liveIn, target.getName()); ae = HopRewriteUtils.createBinary(ids.get(target.getName()), ae, OpOp2.PLUS); target.setProperties(accum.getOutput()); } @@ -1170,6 +1167,15 @@ public class DMLTranslator else { Hop ae = processLeftIndexedExpression(source, (IndexedIdentifier)target, ids); + 
if( as.isAccumulator() ) { + DataIdentifier accum = getAccumulatorData(liveIn, target.getName()); + Hop rix = processIndexingExpression((IndexedIdentifier)target, null, ids); + Hop rhs = processExpression(source, null, ids); + Hop binary = HopRewriteUtils.createBinary(rix, rhs, OpOp2.PLUS); + HopRewriteUtils.replaceChildReference(ae, ae.getInput(1), binary); + target.setProperties(accum.getOutput()); + } + ids.put(target.getName(), ae); // obtain origDim values BEFORE they are potentially updated during setProperties call @@ -1298,7 +1304,14 @@ public class DMLTranslator } sb.updateLiveVariablesOut(updatedLiveOut); sb.setHops(output); - + } + + private static DataIdentifier getAccumulatorData(VariableSet liveIn, String varname) { + DataIdentifier accum = liveIn.getVariable(varname); + if( accum == null ) + throw new LanguageException("Invalid accumulator assignment " + + "to non-existing variable "+varname+"."); + return accum; } private void appendDefaultArguments(FunctionStatement fstmt, List<String> inputNames, List<Hop> inputs, HashMap<String, Hop> ids) { @@ -1630,41 +1643,9 @@ public class DMLTranslator return processExpression(source, tmpOut, hops ); } - private Hop processLeftIndexedExpression(Expression source, IndexedIdentifier target, HashMap<String, Hop> hops) - { + private Hop processLeftIndexedExpression(Expression source, IndexedIdentifier target, HashMap<String, Hop> hops) { // process target indexed expressions - Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = null, colUpperHops = null; - - if (target.getRowLowerBound() != null) - rowLowerHops = processExpression(target.getRowLowerBound(),null,hops); - else - rowLowerHops = new LiteralOp(1); - - if (target.getRowUpperBound() != null) - rowUpperHops = processExpression(target.getRowUpperBound(),null,hops); - else - { - if ( target.getDim1() != -1 ) - rowUpperHops = new LiteralOp(target.getOrigDim1()); - else { - rowUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, 
ValueType.INT64, OpOp1.NROW, hops.get(target.getName())); - rowUpperHops.setParseInfo(target); - } - } - if (target.getColLowerBound() != null) - colLowerHops = processExpression(target.getColLowerBound(),null,hops); - else - colLowerHops = new LiteralOp(1); - - if (target.getColUpperBound() != null) - colUpperHops = processExpression(target.getColUpperBound(),null,hops); - else - { - if ( target.getDim2() != -1 ) - colUpperHops = new LiteralOp(target.getOrigDim2()); - else - colUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(target.getName())); - } + Hop[] ixRange = getIndexingBounds(target, hops, true); // process the source expression to get source Hops Hop sourceOp = processExpression(source, target, hops); @@ -1678,12 +1659,11 @@ public class DMLTranslator if( sourceOp.getDataType().isMatrix() && source.getOutput().getDataType().isScalar() ) sourceOp.setDataType(DataType.SCALAR); - Hop leftIndexOp = new LeftIndexingOp(target.getName(), target.getDataType(), ValueType.FP64, - targetOp, sourceOp, rowLowerHops, rowUpperHops, colLowerHops, colUpperHops, - target.getRowLowerEqualsUpper(), target.getColLowerEqualsUpper()); + Hop leftIndexOp = new LeftIndexingOp(target.getName(), target.getDataType(), + ValueType.FP64, targetOp, sourceOp, ixRange[0], ixRange[1], ixRange[2], ixRange[3], + target.getRowLowerEqualsUpper(), target.getColLowerEqualsUpper()); setIdentifierParams(leftIndexOp, target); - leftIndexOp.setParseInfo(target); leftIndexOp.setDim1(target.getOrigDim1()); leftIndexOp.setDim2(target.getOrigDim2()); @@ -1694,38 +1674,7 @@ public class DMLTranslator private Hop processIndexingExpression(IndexedIdentifier source, DataIdentifier target, HashMap<String, Hop> hops) { // process Hops for indexes (for source) - Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = null, colUpperHops = null; - - if (source.getRowLowerBound() != null) - rowLowerHops = processExpression(source.getRowLowerBound(),null,hops); 
- else - rowLowerHops = new LiteralOp(1); - - if (source.getRowUpperBound() != null) - rowUpperHops = processExpression(source.getRowUpperBound(),null,hops); - else - { - if ( source.getOrigDim1() != -1 ) - rowUpperHops = new LiteralOp(source.getOrigDim1()); - else { - rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(source.getName())); - rowUpperHops.setParseInfo(source); - } - } - if (source.getColLowerBound() != null) - colLowerHops = processExpression(source.getColLowerBound(),null,hops); - else - colLowerHops = new LiteralOp(1); - - if (source.getColUpperBound() != null) - colUpperHops = processExpression(source.getColUpperBound(),null,hops); - else - { - if ( source.getOrigDim2() != -1 ) - colUpperHops = new LiteralOp(source.getOrigDim2()); - else - colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(source.getName())); - } + Hop[] ixRange = getIndexingBounds(source, hops, false); if (target == null) { target = createTarget(source); @@ -1735,8 +1684,8 @@ public class DMLTranslator target.setNnz(-1); Hop indexOp = new IndexingOp(target.getName(), target.getDataType(), target.getValueType(), - hops.get(source.getName()), rowLowerHops, rowUpperHops, colLowerHops, colUpperHops, - source.getRowLowerEqualsUpper(), source.getColLowerEqualsUpper()); + hops.get(source.getName()), ixRange[0], ixRange[1], ixRange[2], ixRange[3], + source.getRowLowerEqualsUpper(), source.getColLowerEqualsUpper()); indexOp.setParseInfo(target); setIdentifierParams(indexOp, target); @@ -1744,6 +1693,34 @@ public class DMLTranslator return indexOp; } + private Hop[] getIndexingBounds(IndexedIdentifier ix, HashMap<String, Hop> hops, boolean lix) { + Hop rowLowerHops = (ix.getRowLowerBound() != null) ? + processExpression(ix.getRowLowerBound(),null, hops) : new LiteralOp(1); + Hop colLowerHops = (ix.getColLowerBound() != null) ? 
+ processExpression(ix.getColLowerBound(),null, hops) : new LiteralOp(1); + + Hop rowUpperHops = null, colUpperHops = null; + if (ix.getRowUpperBound() != null) + rowUpperHops = processExpression(ix.getRowUpperBound(),null,hops); + else { + rowUpperHops = ((lix ? ix.getDim1() : ix.getOrigDim1()) != -1) ? + new LiteralOp(ix.getOrigDim1()) : + new UnaryOp(ix.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(ix.getName())); + rowUpperHops.setParseInfo(ix); + } + + if (ix.getColUpperBound() != null) + colUpperHops = processExpression(ix.getColUpperBound(),null,hops); + else { + colUpperHops = ((lix ? ix.getDim2() : ix.getOrigDim2()) != -1) ? + new LiteralOp(ix.getOrigDim2()) : + new UnaryOp(ix.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(ix.getName())); + colUpperHops.setParseInfo(ix); + } + + return new Hop[] {rowLowerHops, rowUpperHops, colLowerHops, colUpperHops}; + } + /** * Construct Hops from parse tree : Process Binary Expression in an diff --git a/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java b/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java new file mode 100644 index 0000000..3db2535 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/indexing/IndexedAdditionAssignmentTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.indexing; + + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.lops.LopProperties.ExecType; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; + +public class IndexedAdditionAssignmentTest extends AutomatedTestBase +{ + private final static String TEST_DIR = "functions/indexing/"; + private final static String TEST_NAME = "IndexedAdditionTest"; + + private final static String TEST_CLASS_DIR = TEST_DIR + IndexedAdditionAssignmentTest.class.getSimpleName() + "/"; + + private final static int rows = 1279; + private final static int cols = 1050; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A"})); + } + + @Test + public void testIndexedAssignmentAddScalarCP() { + runIndexedAdditionAssignment(true, ExecType.CP); + } + + @Test + public void testIndexedAssignmentAddMatrixCP() { + runIndexedAdditionAssignment(false, ExecType.CP); + } + + @Test + public void testIndexedAssignmentAddScalarSpark() { + runIndexedAdditionAssignment(true, ExecType.SPARK); + } + + @Test + public void testIndexedAssignmentAddMatrixSpark() { + runIndexedAdditionAssignment(false, ExecType.SPARK); + } + + private void runIndexedAdditionAssignment(boolean scalar, ExecType instType) { + ExecMode platformOld = setExecMode(instType); + + try { + TestConfiguration 
config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + //test is adding 7 to area 1x1 or 10x10 + //of an initially constant (3) matrix and sums it up + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{"-explain" , "-args", + Long.toString(rows), Long.toString(cols), + String.valueOf(scalar).toUpperCase(), output("A")}; + + runTest(true, false, null, -1); + + Double ret = readDMLMatrixFromOutputDir("A").get(new CellIndex(1,1)); + Assert.assertEquals(new Double(3*rows*cols + 7*(scalar?1:100)), ret); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java b/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java index b5ea0aa..68fbc37 100644 --- a/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/indexing/LeftIndexingScalarTest.java @@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.indexing; import java.util.HashMap; import org.junit.Test; -import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.lops.LopProperties.ExecType; import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; @@ -33,7 +32,6 @@ import org.apache.sysds.test.TestUtils; public class LeftIndexingScalarTest extends AutomatedTestBase { - private final static String TEST_DIR = "functions/indexing/"; private final static String TEST_NAME = "LeftIndexingScalarTest"; private final static String TEST_CLASS_DIR = TEST_DIR + LeftIndexingScalarTest.class.getSimpleName() + "/"; @@ -52,31 +50,18 @@ public class LeftIndexingScalarTest extends AutomatedTestBase } @Test - public void testLeftIndexingScalarCP() - { + public void testLeftIndexingScalarCP() { runLeftIndexingTest(ExecType.CP); } @Test - public void
testLeftIndexingScalarSP() - { + public void testLeftIndexingScalarSP() { runLeftIndexingTest(ExecType.SPARK); } private void runLeftIndexingTest( ExecType instType ) - { - //rtplatform for MR - ExecMode platformOld = rtplatform; - if(instType == ExecType.SPARK) { - rtplatform = ExecMode.SPARK; - } - else { - rtplatform = ExecMode.HYBRID; - } - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if( rtplatform == ExecMode.SPARK ) - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - + { + ExecMode platformOld = setExecMode(instType); try { @@ -91,10 +76,10 @@ public class LeftIndexingScalarTest extends AutomatedTestBase fullRScriptName = HOME + TEST_NAME + ".R"; rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); - double[][] A = getRandomMatrix(rows, cols, min, max, sparsity, System.currentTimeMillis()); - writeInputMatrix("A", A, true); - - runTest(true, false, null, -1); + double[][] A = getRandomMatrix(rows, cols, min, max, sparsity, System.currentTimeMillis()); + writeInputMatrix("A", A, true); + + runTest(true, false, null, -1); runRScript(true); HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("A"); @@ -102,11 +87,8 @@ public class LeftIndexingScalarTest extends AutomatedTestBase TestUtils.compareMatrices(dmlfile, rfile, epsilon, "A-DML", "A-R"); checkDMLMetaDataFile("A", new MatrixCharacteristics(rows,cols,1,1)); } - finally - { - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + finally { + resetExecMode(platformOld); } } } - diff --git a/src/test/scripts/functions/indexing/IndexedAdditionTest.dml b/src/test/scripts/functions/indexing/IndexedAdditionTest.dml new file mode 100644 index 0000000..415a795 --- /dev/null +++ b/src/test/scripts/functions/indexing/IndexedAdditionTest.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +A = matrix(3, $1, $2); + +if( $3 ) + A[10,20] += 7; +else + A[10:19,20:29] += 7; + +R = as.matrix(sum(A)) +write(R, $4, format="text")