This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new a6faf44254 [SYSTEMDS-3895] New out-of-core unary aggregate operations a6faf44254 is described below commit a6faf442547c29042bf86388da64472422e02dcc Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sun Jul 13 14:18:30 2025 +0200 [SYSTEMDS-3895] New out-of-core unary aggregate operations This patch introduces the out-of-core unary aggregate operations as an example of how to implement operations against the input stream of blocks. --- .../java/org/apache/sysds/hops/AggUnaryOp.java | 3 - .../runtime/instructions/OOCInstructionParser.java | 4 +- .../ooc/AggregateUnaryOOCInstruction.java | 94 ++++++++++++++++++++++ .../functions/ooc/SumScalarMultiplicationTest.java | 4 +- 4 files changed, 99 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java index 2f5cb53acf..b71b57aa18 100644 --- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java @@ -409,9 +409,6 @@ public class AggUnaryOp extends MultiThreadedHop else setRequiresRecompileIfNecessary(); - if( _etype == ExecType.OOC ) //TODO - setExecType(ExecType.CP); - return _etype; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index e0f84c5bd2..c437684d3b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -23,6 +23,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.InstructionType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; @@ -47,9 +48,10 @@ public class OOCInstructionParser extends InstructionParser { switch(ooctype) { case Reblock: return ReblockOOCInstruction.parseInstruction(str); + case AggregateUnary: + return AggregateUnaryOOCInstruction.parseInstruction(str); // TODO: - case AggregateUnary: case Binary: default: diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java new file mode 100644 index 0000000000..c333088239 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.common.Types.CorrectionLocationType; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; + + +public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction { + private AggregateOperator _aop = null; + + protected AggregateUnaryOOCInstruction(OOCType type, AggregateUnaryOperator auop, AggregateOperator aop, + CPOperand in, CPOperand out, String opcode, String istr) { + super(type, auop, in, out, opcode, istr); + _aop = aop; + } + + public static AggregateUnaryOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 2); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + + String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(opcode); + CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(opcode); + AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode); + AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, corrLoc.toString()); + return new AggregateUnaryOOCInstruction( + OOCType.AggregateUnary, aggun, aop, in1, out, opcode, str); + } + + @Override + public void processInstruction( ExecutionContext ec ) { + //TODO support all types of aggregations, currently only full aggregation + + //setup operators and input queue + AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator(); + MatrixObject min = ec.getMatrixObject(input1); + LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle(); + IndexedMatrixValue tmp = null; + int blen = ConfigurationManager.getBlocksize(); + + //read blocks and aggregate immediately into result + int extra = _aop.correction.getNumRemovedRowsColumns(); + MatrixBlock ret = new MatrixBlock(1,1+extra,false); + MatrixBlock corr = new MatrixBlock(1,1+extra,false); + try { + while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + //block aggregation + MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) + .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); + //accumulation into final result + OperationsOnMatrixValues.incrementalAggregation( + ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true); + } + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + + //create scalar output + ec.setScalarOutput(output.getName(), new DoubleObject(ret.get(0, 0))); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java index dafc9c7bf6..2272588bab 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java @@ -92,11 +92,11 @@ public class SumScalarMultiplicationTest extends AutomatedTestBase { String prefix = Instruction.OOC_INST_PREFIX; Assert.assertTrue("OOC wasn't used for RBLK", heavyHittersContainsString(prefix + Opcodes.RBLK)); + Assert.assertTrue("OOC wasn't used for SUM", + heavyHittersContainsString(prefix + Opcodes.UAKP)); // boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT); // Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult); -// boolean usedOOCSum = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP); -// Assert.assertTrue("OOC wasn't used for SUM", usedOOCSum); } catch(Exception ex) { Assert.fail(ex.getMessage());