This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new ca8d20916c [SYSTEMDS-3894] New out-of-core binary scalar-matrix operations ca8d20916c is described below commit ca8d20916c2f6a5073f0e2026511908f53bb0904 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Tue Jul 15 17:58:05 2025 +0200 [SYSTEMDS-3894] New out-of-core binary scalar-matrix operations This patch completes the selected example operations for the new out-of-core backend and related test. --- src/main/java/org/apache/sysds/hops/BinaryOp.java | 3 - .../runtime/instructions/OOCInstructionParser.java | 6 +- .../instructions/ooc/BinaryOOCInstruction.java | 95 ++++++++++++++++++++++ .../functions/ooc/SumScalarMultiplicationTest.java | 29 +++++-- 4 files changed, 121 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index f433931a52..a3ddb45ea6 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -854,9 +854,6 @@ public class BinaryOp extends MultiThreadedHop { _etype = ExecType.CP; } - if( _etype == ExecType.OOC ) //TODO - setExecType(ExecType.CP); - //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index c437684d3b..0e5b3f1f51 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -24,6 +24,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.InstructionType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; @@ -50,10 +51,9 @@ public class OOCInstructionParser extends InstructionParser { return ReblockOOCInstruction.parseInstruction(str); case AggregateUnary: return AggregateUnaryOOCInstruction.parseInstruction(str); - - // TODO: case Binary: - + return BinaryOOCInstruction.parseInstruction(str); + default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java new file mode 100644 index 0000000000..fe76e60b9e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import java.util.concurrent.ExecutorService; + +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.ScalarOperator; +import org.apache.sysds.runtime.util.CommonThreadPool; + +public class BinaryOOCInstruction extends ComputationOOCInstruction { + + protected BinaryOOCInstruction(OOCType type, Operator bop, + CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { + super(type, bop, in1, in2, out, opcode, istr); + } + + public static BinaryOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 3); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[3]); + Operator bop = InstructionUtils.parseExtendedBinaryOrBuiltinOperator(opcode, in1, in2); + + return new BinaryOOCInstruction( + OOCType.Binary, bop, in1, in2, out, opcode, str); + } + + @Override + public void processInstruction( ExecutionContext ec ) { + //TODO support all types, currently only binary matrix-scalar + + //get operator and scalar + CPOperand scalar = ( input1.getDataType() == DataType.MATRIX ) ? input2 : input1; + ScalarObject constant = ec.getScalarInput(scalar); + ScalarOperator sc_op = ((ScalarOperator)_optr).setConstant(constant.getDoubleValue()); + + //create thread and process binary operation + MatrixObject min = ec.getMatrixObject(input1); + LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle(); + LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + ExecutorService pool = CommonThreadPool.get(); + try { + pool.submit(() -> { + IndexedMatrixValue tmp = null; + try { + while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), + tmp.getValue().scalarOperations(sc_op, new MatrixBlock())); + qOut.enqueueTask(tmpOut); + } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }); + } + finally { + pool.shutdown(); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java index 2272588bab..f0d9228a53 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java @@ -23,6 +23,7 @@ import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.io.MatrixWriter; import org.apache.sysds.runtime.io.MatrixWriterFactory; @@ -57,11 +58,26 @@ public class SumScalarMultiplicationTest extends AutomatedTestBase { * Test the sum of scalar multiplication, "sum(X*7)", with OOC backend. */ @Test - public void testSumScalarMult() { - + public void testSumScalarMultNoRewrite() { + testSumScalarMult(false); + } + + /** + * Test the sum of scalar multiplication, "sum(X)*7", with OOC backend. + */ + @Test + public void testSumScalarMultRewrite() { + testSumScalarMult(true); + } + + + public void testSumScalarMult(boolean rewrite) + { Types.ExecMode platformOld = rtplatform; rtplatform = Types.ExecMode.SINGLE_NODE; - + boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite; + try { getAndLoadTestConfiguration(TEST_NAME); String HOME = SCRIPT_DIR + TEST_DIR; @@ -92,16 +108,17 @@ public class SumScalarMultiplicationTest extends AutomatedTestBase { String prefix = Instruction.OOC_INST_PREFIX; Assert.assertTrue("OOC wasn't used for RBLK", heavyHittersContainsString(prefix + Opcodes.RBLK)); + if(!rewrite) + Assert.assertTrue("OOC wasn't used for SUM", + heavyHittersContainsString(prefix + Opcodes.MULT)); Assert.assertTrue("OOC wasn't used for SUM", heavyHittersContainsString(prefix + Opcodes.UAKP)); - -// boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT); -// Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult); } catch(Exception ex) { Assert.fail(ex.getMessage()); } finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite; resetExecMode(platformOld); } }