This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new ca8d20916c [SYSTEMDS-3894] New out-of-core binary scalar-matrix operations
ca8d20916c is described below

commit ca8d20916c2f6a5073f0e2026511908f53bb0904
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Tue Jul 15 17:58:05 2025 +0200

    [SYSTEMDS-3894] New out-of-core binary scalar-matrix operations
    
    This patch completes the selected example operations for the new
    out-of-core backend and the related test.
---
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |  3 -
 .../runtime/instructions/OOCInstructionParser.java |  6 +-
 .../instructions/ooc/BinaryOOCInstruction.java     | 95 ++++++++++++++++++++++
 .../functions/ooc/SumScalarMultiplicationTest.java | 29 +++++--
 4 files changed, 121 insertions(+), 12 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index f433931a52..a3ddb45ea6 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -854,9 +854,6 @@ public class BinaryOp extends MultiThreadedHop {
                        _etype = ExecType.CP;
                }
                
-               if( _etype == ExecType.OOC ) //TODO
-                       setExecType(ExecType.CP);
-               
                //mark for recompile (forever)
                setRequiresRecompileIfNecessary();
                
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index c437684d3b..0e5b3f1f51 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -24,6 +24,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.InstructionType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
 
@@ -50,10 +51,9 @@ public class OOCInstructionParser extends InstructionParser {
                                return 
ReblockOOCInstruction.parseInstruction(str);
                        case AggregateUnary:
                                return 
AggregateUnaryOOCInstruction.parseInstruction(str);
-                       
-                       // TODO:
                        case Binary:
-
+                               return 
BinaryOOCInstruction.parseInstruction(str);
+                       
                        default:
                                throw new DMLRuntimeException("Invalid OOC 
Instruction Type: " + ooctype);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
new file mode 100644
index 0000000000..fe76e60b9e
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import java.util.concurrent.ExecutorService;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+public class BinaryOOCInstruction extends ComputationOOCInstruction {
+       
+       protected BinaryOOCInstruction(OOCType type, Operator bop, 
+                       CPOperand in1, CPOperand in2, CPOperand out, String 
opcode, String istr) {
+               super(type, bop, in1, in2, out, opcode, istr);
+       }
+
+       public static BinaryOOCInstruction parseInstruction(String str) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               InstructionUtils.checkNumFields(parts, 3);
+               String opcode = parts[0];
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               CPOperand out = new CPOperand(parts[3]);
+               Operator bop = 
InstructionUtils.parseExtendedBinaryOrBuiltinOperator(opcode, in1, in2);
+               
+               return new BinaryOOCInstruction(
+                       OOCType.Binary, bop, in1, in2, out, opcode, str);
+       }
+       
+       @Override
+       public void processInstruction( ExecutionContext ec ) {
+               //TODO support all types, currently only binary matrix-scalar
+               
+               //get operator and scalar
+               CPOperand scalar = ( input1.getDataType() == DataType.MATRIX ) 
? input2 : input1;
+               ScalarObject constant = ec.getScalarInput(scalar);
+               ScalarOperator sc_op = 
((ScalarOperator)_optr).setConstant(constant.getDoubleValue());
+               
+               //create thread and process binary operation
+               MatrixObject min = ec.getMatrixObject(input1);
+               LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
+               LocalTaskQueue<IndexedMatrixValue> qOut = new 
LocalTaskQueue<>();
+               ec.getMatrixObject(output).setStreamHandle(qOut);
+               
+               ExecutorService pool = CommonThreadPool.get();
+               try {
+                       pool.submit(() -> {
+                               IndexedMatrixValue tmp = null;
+                               try {
+                                       while((tmp = qIn.dequeueTask()) != 
LocalTaskQueue.NO_MORE_TASKS) {
+                                               IndexedMatrixValue tmpOut = new 
IndexedMatrixValue();
+                                               tmpOut.set(tmp.getIndexes(),
+                                                       
tmp.getValue().scalarOperations(sc_op, new MatrixBlock()));
+                                               qOut.enqueueTask(tmpOut);
+                                       }
+                                       qOut.closeInput();
+                               }
+                               catch(Exception ex) {
+                                       throw new DMLRuntimeException(ex);
+                               }
+                       });
+               }
+               finally {
+                       pool.shutdown();
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
index 2272588bab..f0d9228a53 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
@@ -23,6 +23,7 @@ import org.apache.sysds.common.Opcodes;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.io.MatrixWriter;
 import org.apache.sysds.runtime.io.MatrixWriterFactory;
@@ -57,11 +58,26 @@ public class SumScalarMultiplicationTest extends 
AutomatedTestBase {
         * Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
         */
        @Test
-       public void testSumScalarMult() {
-
+       public void testSumScalarMultNoRewrite() {
+               testSumScalarMult(false);
+       }
+       
+       /**
+        * Test the sum of scalar multiplication, "sum(X)*7", with OOC backend.
+        */
+       @Test
+       public void testSumScalarMultRewrite() {
+               testSumScalarMult(true);
+       }
+       
+       
+       public void testSumScalarMult(boolean rewrite)
+       {
                Types.ExecMode platformOld = rtplatform;
                rtplatform = Types.ExecMode.SINGLE_NODE;
-
+               boolean oldRewrite = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+               
                try {
                        getAndLoadTestConfiguration(TEST_NAME);
                        String HOME = SCRIPT_DIR + TEST_DIR;
@@ -92,16 +108,17 @@ public class SumScalarMultiplicationTest extends 
AutomatedTestBase {
                        String prefix = Instruction.OOC_INST_PREFIX;
                        Assert.assertTrue("OOC wasn't used for RBLK",
                                heavyHittersContainsString(prefix + 
Opcodes.RBLK));
+                       if(!rewrite)
+                               Assert.assertTrue("OOC wasn't used for SUM",
+                                       heavyHittersContainsString(prefix + 
Opcodes.MULT));
                        Assert.assertTrue("OOC wasn't used for SUM",
                                heavyHittersContainsString(prefix + 
Opcodes.UAKP));
-                       
-//                     boolean usedOOCMult = 
Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
-//                     Assert.assertTrue("OOC wasn't used for MULT", 
usedOOCMult);
                }
                catch(Exception ex) {
                        Assert.fail(ex.getMessage());
                }
                finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
oldRewrite;
                        resetExecMode(platformOld);
                }
        }

Reply via email to