This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a30b5b0274 [SYSTEMDS-3525] Disable and add tests for binary inplace 
operations
a30b5b0274 is described below

commit a30b5b02747d98b1f513129d73a65a466a20261d
Author: Arnab Phani <[email protected]>
AuthorDate: Thu Apr 27 14:45:35 2023 +0200

    [SYSTEMDS-3525] Disable and add tests for binary inplace operations
    
    This patch adds a flag that controls update-in-place for binary operations.
    The flag is disabled by default. This patch also adds a test that
    exposes a bug in binary update-in-place: a binary in-place operation consuming
    another intermediate produced in place (e.g., by right indexing) leads to corruption.
    We also disable binary update-in-place if lineage-based reuse is enabled to
    avoid corrupting the cached intermediates.
    
    Closes #1814
---
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |  5 +-
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  9 ++-
 .../cp/BinaryMatrixMatrixCPInstruction.java        |  4 +
 .../sysds/runtime/lineage/LineageCacheConfig.java  |  3 +
 .../updateinplace/BinaryUpdateInPlaceTest.java     | 87 ++++++++++++++++++++++
 .../updateinplace/BinaryUpdateInplace.dml          | 66 ++++++++++++++++
 6 files changed, 172 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 04585d7dc4..74740f10ce 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -51,6 +51,7 @@ import org.apache.sysds.lops.PickByCount;
 import org.apache.sysds.lops.SortKeys;
 import org.apache.sysds.lops.Unary;
 import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
@@ -814,7 +815,8 @@ public class BinaryOp extends MultiThreadedHop {
                        _etype = ExecType.SPARK;
                }
 
-               if( transitive && _etypeForced != ExecType.SPARK && 
_etypeForced != ExecType.FED && //
+               if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE &&
+                       transitive && _etypeForced != ExecType.SPARK && 
_etypeForced != ExecType.FED &&
                        getDataType().isMatrix() // Output is a matrix
                        && op == OpOp2.DIV // Operation is division
                        && dt1.isMatrix() // Left hand side is a Matrix
@@ -823,6 +825,7 @@ public class BinaryOp extends MultiThreadedHop {
                        && memOfInputIsLessThanBudget() //
                        && getInput().get(0).getExecType() != ExecType.SPARK // 
Is not already a spark operation
                        && doesNotContainNanAndInf(getInput().get(1)) // 
Guaranteed not to densify the operation
+                       && LineageCacheConfig.ReuseCacheType.isNone() // 
Inplace update corrupts the already cached input matrix block
                ) {
                        inplace = true;
                        _etype = ExecType.CP;
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 580ddccf20..4ab7c33dbd 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -215,7 +215,14 @@ public class OptimizerUtils
         */
        //TODO enabling it by default requires modifications in lineage-based 
reuse
        public static boolean ALLOW_UNARY_UPDATE_IN_PLACE = false;
-       
+
+       /**
+        * Enables update-in-place for binary operators if the first input
+        * has no consumers. In this case we directly write the output
+        * values back to the first input block.
+        */
+       public static boolean ALLOW_BINARY_UPDATE_IN_PLACE = false;
+
        /**
         * Replace eval second-order function calls with normal function call
         * if the function name is a known string (after constant propagation).
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
index 20119ceacd..cff0650235 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
@@ -53,6 +53,10 @@ public class BinaryMatrixMatrixCPInstruction extends 
BinaryCPInstruction {
                        inplace = false;
        }
 
+       public boolean isInPlace() { // read by LineageCacheConfig: in-place results are treated as not reusable
+               return inplace;
+       }
+
        @Override
        public void processInstruction(ExecutionContext ec) {
                // Read input matrices
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index a483b6c21b..0ce6cf3a8e 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -27,6 +27,7 @@ import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.Instruction;
+import 
org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
@@ -228,6 +229,8 @@ public class LineageCacheConfig
                        || (inst instanceof DataGenCPInstruction) && 
((DataGenCPInstruction) inst).isMatrixCall());
                boolean updateInplace = (inst instanceof 
MatrixIndexingCPInstruction)
                        && 
ec.getMatrixObject(((ComputationCPInstruction)inst).input1).getUpdateType().isInPlace();
+               updateInplace = updateInplace || ((inst instanceof 
BinaryMatrixMatrixCPInstruction)
+                       && ((BinaryMatrixMatrixCPInstruction) 
inst).isInPlace());
                boolean federatedOutput = false;
                return insttype && rightop && !updateInplace && 
!federatedOutput;
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java
 
b/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java
new file mode 100644
index 0000000000..19bb2cdefa
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.updateinplace;
+
+       import java.util.ArrayList;
+       import java.util.HashMap;
+       import java.util.List;
+
+       import org.apache.sysds.common.Types;
+       import org.apache.sysds.hops.OptimizerUtils;
+       import org.apache.sysds.runtime.matrix.data.MatrixValue;
+       import org.apache.sysds.test.AutomatedTestBase;
+       import org.apache.sysds.test.TestConfiguration;
+       import org.apache.sysds.test.TestUtils;
+       import org.junit.Ignore;
+       import org.junit.Test;
+
+
+public class BinaryUpdateInPlaceTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "BinaryUpdateInplace";
+       private final static String TEST_DIR = "functions/updateinplace/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + BinaryUpdateInPlaceTest.class.getSimpleName() + "/";
+       private final static double eps = 1e-3;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+       }
+
+       @Ignore("SYSTEMDS-3525: binary update-in-place consuming an in-place indexed intermediate corrupts the input")
+       @Test
+       public void testInPlace() {
+               runInPlaceTest(Types.ExecType.CP);
+       }
+
+       // Runs the script with binary update-in-place enabled, then disabled; both outputs must match.
+       private void runInPlaceTest(Types.ExecType instType) {
+               Types.ExecMode platformOld = setExecMode(instType);
+               boolean oldFlag = OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE;
+
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       List<String> proArgs = new ArrayList<>();
+                       proArgs.add("-args");
+                       proArgs.add(output("R"));
+                       programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+                       OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = true;
+                       runTest(true, false, null, -1);
+                       HashMap<MatrixValue.CellIndex, Double> R_inplace = readDMLMatrixFromOutputDir("R");
+                       OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = false;
+                       runTest(true, false, null, -1);
+                       HashMap<MatrixValue.CellIndex, Double> R = readDMLMatrixFromOutputDir("R");
+
+                       // a mismatch means the in-place binary op corrupted data (e.g., the per-iteration sum(D))
+                       TestUtils.compareMatrices(R_inplace,R,eps,"with-Inplace","no_Inplace");
+               }
+               catch(Exception e) {
+                       throw new RuntimeException(e); // surface failures instead of letting the test pass vacuously
+               }
+               finally {
+                       rtplatform = platformOld;
+                       OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = oldFlag;
+               }
+       }
+}
+
diff --git a/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml 
b/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml
new file mode 100644
index 0000000000..8283d0eec8
--- /dev/null
+++ b/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml
@@ -0,0 +1,66 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+D = rand(rows=32, cols=100, min=0, max=20, seed=42) # input dataset; the loop never reassigns it
+bs = 32; # batch size (equals nrow(D), so every batch is all of D)
+ep = 3; # number of epochs
+iter_ep = ceil(nrow(D)/bs); # iterations per epoch
+maxiter = ep * iter_ep; # total number of iterations
+beg = 1;
+iter = 0;
+i = 1;
+R = matrix(0, rows=1, cols=maxiter+1); # sum(D) per iteration, sum(X) in the last column
+
+while (iter < maxiter) {
+  end = beg + bs - 1;
+  if (end>nrow(D))
+    end = nrow(D);
+  X = D[beg:end,] # right indexing; may be executed update-in-place
+
+  # in-place binary after in-place indexing corrupts D; a changing sum(D) exposes it
+  R[1,iter+1] = sum(D);
+
+  # op reusable across epochs (the same batch is read every epoch)
+  X = scale(X, FALSE, TRUE);
+  # pollute the cache with non-reusable ops (they depend on the loop counter i)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+
+  iter = iter + 1;
+  if (end == nrow(D))
+    beg = 1;
+  else
+    beg = end + 1;
+  i = i + 1;
+
+}
+# last column: sum of the final transformed batch X
+R[1,maxiter+1] = sum(X);
+write(R, $1, format="text");
+

Reply via email to