This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 734c569f52 [SYSTEMDS-2316] Fix data/value type propagation on
add-assign operators
734c569f52 is described below
commit 734c569f5233698228d13b2560db0b9ff0547137
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Jul 2 20:48:30 2023 +0200
[SYSTEMDS-2316] Fix data/value type propagation on add-assign operators
This patch fixes the data and value type propagation of mixed-type
add-assign operators during parsing. So far, X += i only worked
correctly for matrix-matrix and scalar-scalar, but not for matrix-scalar
operations. For scalar-scalar operations, the runtime tolerated incorrect
value types; these are now propagated correctly as well.
---
.../org/apache/sysds/parser/StatementBlock.java | 5 +-
.../functions/misc/AdditionAssignmentTest.java | 68 ++++++++++++++++++++++
src/test/scripts/functions/misc/AddAssign.dml | 36 ++++++++++++
3 files changed, 108 insertions(+), 1 deletion(-)
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 3deb6a8001..315f15e51a 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -974,7 +974,10 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
}
// CASE: target NOT indexed identifier
else if (!(target instanceof IndexedIdentifier)){
- target.setProperties(source.getOutput());
+ if( as.isAccumulator() &&
ids.containsVariable(target.getName()) )
+
target.setProperties(ids.getVariable(target.getName()));
+ else
+ target.setProperties(source.getOutput());
if (source.getOutput() instanceof IndexedIdentifier)
target.setDimensions(source.getOutput().getDim1(),
source.getOutput().getDim2());
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/misc/AdditionAssignmentTest.java
b/src/test/java/org/apache/sysds/test/functions/misc/AdditionAssignmentTest.java
new file mode 100644
index 0000000000..e4680e0fff
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/misc/AdditionAssignmentTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.misc;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+/**
+ * Regression tests for the add-assign operator ({@code +=}) in DML scripts,
+ * exercising both a scalar target and a matrix target with a scalar
+ * right-hand side (see AddAssign.dml).
+ */
+public class AdditionAssignmentTest extends AutomatedTestBase
+{
+	private final static String TEST_NAME1 = "AddAssign";
+	private final static String TEST_DIR = "functions/misc/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + AdditionAssignmentTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
+	}
+
+	/**
+	 * Matrix target: a 10x10 zero matrix receives {@code Y += i} for i in 1..10,
+	 * so every cell ends at 55 and sum(Y) = 100 * 55 = 5500.
+	 */
+	@Test
+	public void testMatrixScalarAddition() {
+		runAddAssignTest(TEST_NAME1, false, 5500);
+	}
+
+	/**
+	 * Scalar target: {@code X += i} for i in 1..10 starting from 0 yields 55.
+	 */
+	@Test
+	public void testScalarScalarAddition() {
+		runAddAssignTest(TEST_NAME1, true, 55);
+	}
+
+	/**
+	 * Runs the DML script for one target kind and checks the scalar it writes.
+	 *
+	 * @param testName     name of the registered test configuration, also the DML script's base name
+	 * @param scalarTarget true to exercise a scalar target, false for a matrix target
+	 * @param expected     expected value of the scalar written to output "R"
+	 */
+	private void runAddAssignTest(String testName, boolean scalarTarget, double expected) {
+		TestConfiguration config = getTestConfiguration(testName);
+		loadTestConfiguration(config);
+		String home = SCRIPT_DIR + TEST_DIR;
+		fullDMLScriptName = home + testName + ".dml";
+		// $1 is the TRUE/FALSE literal selecting the script branch, $2 the output path
+		programArgs = new String[]{"-args", String.valueOf(scalarTarget).toUpperCase(), output("R")};
+
+		// execute the DML script
+		runTest(true, false, null, -1);
+
+		// read back the written scalar; fail with a clear message instead of an
+		// unboxing NullPointerException if nothing was written
+		Double val = readDMLScalarFromOutputDir("R").get(new CellIndex(1,1));
+		Assert.assertNotNull("no scalar result found in output R", val);
+		Assert.assertEquals(expected, val, 1e-12);
+	}
+}
diff --git a/src/test/scripts/functions/misc/AddAssign.dml
b/src/test/scripts/functions/misc/AddAssign.dml
new file mode 100644
index 0000000000..c32f61d273
--- /dev/null
+++ b/src/test/scripts/functions/misc/AddAssign.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Regression script for add-assign (+=) data/value type propagation.
+# $1: boolean flag; true -> scalar target (expected 55),
+#     false -> matrix target (expected 5500)
+# $2: output path of the scalar result
+# NOTE: the brace-less for-bodies with += are the exact forms under test;
+# keep them as-is when editing.
+if( $1 ) {
+ # scalar-scalar addition
+ X = 0;
+ for(i in 1:10)
+ X += i;
+ # 1 + 2 + ... + 10
+ write(X, $2) #55
+}
+else {
+ # matrix-scalar addition: adds i to every cell of a 10x10 matrix
+ Y = matrix(0, 10, 10);
+ for(i in 1:10)
+ Y += i;
+ sY = sum(Y) # reduce to a scalar (original note: "TODO in write" - presumably compute sum inside write(); confirm)
+ # 100 cells * 55
+ write(sY, $2) #5500
+}