This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a6faf44254 [SYSTEMDS-3895] New out-of-core unary aggregate operations
a6faf44254 is described below
commit a6faf442547c29042bf86388da64472422e02dcc
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Jul 13 14:18:30 2025 +0200
[SYSTEMDS-3895] New out-of-core unary aggregate operations
This patch introduces the out-of-core unary aggregate operations as an
example of how to implement operations against the input stream of
blocks.
---
.../java/org/apache/sysds/hops/AggUnaryOp.java | 3 -
.../runtime/instructions/OOCInstructionParser.java | 4 +-
.../ooc/AggregateUnaryOOCInstruction.java | 94 ++++++++++++++++++++++
.../functions/ooc/SumScalarMultiplicationTest.java | 4 +-
4 files changed, 99 insertions(+), 6 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 2f5cb53acf..b71b57aa18 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -409,9 +409,6 @@ public class AggUnaryOp extends MultiThreadedHop
else
setRequiresRecompileIfNecessary();
- if( _etype == ExecType.OOC ) //TODO
- setExecType(ExecType.CP);
-
return _etype;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index e0f84c5bd2..c437684d3b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -23,6 +23,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.InstructionType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
@@ -47,9 +48,10 @@ public class OOCInstructionParser extends InstructionParser {
switch(ooctype) {
case Reblock:
return
ReblockOOCInstruction.parseInstruction(str);
+ case AggregateUnary:
+ return
AggregateUnaryOOCInstruction.parseInstruction(str);
// TODO:
- case AggregateUnary:
case Binary:
default:
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
new file mode 100644
index 0000000000..c333088239
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.common.Types.CorrectionLocationType;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+
+
+public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
+ private AggregateOperator _aop = null;
+
+ protected AggregateUnaryOOCInstruction(OOCType type,
AggregateUnaryOperator auop, AggregateOperator aop,
+ CPOperand in, CPOperand out, String opcode, String
istr) {
+ super(type, auop, in, out, opcode, istr);
+ _aop = aop;
+ }
+
+ public static AggregateUnaryOOCInstruction parseInstruction(String str)
{
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ InstructionUtils.checkNumFields(parts, 2);
+ String opcode = parts[0];
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand out = new CPOperand(parts[2]);
+
+ String aopcode =
InstructionUtils.deriveAggregateOperatorOpcode(opcode);
+ CorrectionLocationType corrLoc =
InstructionUtils.deriveAggregateOperatorCorrectionLocation(opcode);
+ AggregateUnaryOperator aggun =
InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
+ AggregateOperator aop =
InstructionUtils.parseAggregateOperator(aopcode, corrLoc.toString());
+ return new AggregateUnaryOOCInstruction(
+ OOCType.AggregateUnary, aggun, aop, in1, out, opcode,
str);
+ }
+
+ @Override
+ public void processInstruction( ExecutionContext ec ) {
+ //TODO support all types of aggregations, currently only full
aggregation
+
+ //setup operators and input queue
+ AggregateUnaryOperator aggun = (AggregateUnaryOperator)
getOperator();
+ MatrixObject min = ec.getMatrixObject(input1);
+ LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
+ IndexedMatrixValue tmp = null;
+ int blen = ConfigurationManager.getBlocksize();
+
+ //read blocks and aggregate immediately into result
+ int extra = _aop.correction.getNumRemovedRowsColumns();
+ MatrixBlock ret = new MatrixBlock(1,1+extra,false);
+ MatrixBlock corr = new MatrixBlock(1,1+extra,false);
+ try {
+ while((tmp = q.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ //block aggregation
+ MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock)
tmp.getValue())
+ .aggregateUnaryOperations(aggun, new
MatrixBlock(), blen, tmp.getIndexes());
+ //accumulation into final result
+ OperationsOnMatrixValues.incrementalAggregation(
+ ret, _aop.existsCorrection() ? corr :
null, ltmp, _aop, true);
+ }
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+
+ //create scalar output
+ ec.setScalarOutput(output.getName(), new
DoubleObject(ret.get(0, 0)));
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
index dafc9c7bf6..2272588bab 100644
---
a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
@@ -92,11 +92,11 @@ public class SumScalarMultiplicationTest extends
AutomatedTestBase {
String prefix = Instruction.OOC_INST_PREFIX;
Assert.assertTrue("OOC wasn't used for RBLK",
heavyHittersContainsString(prefix +
Opcodes.RBLK));
+ Assert.assertTrue("OOC wasn't used for SUM",
+ heavyHittersContainsString(prefix +
Opcodes.UAKP));
// boolean usedOOCMult =
Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
// Assert.assertTrue("OOC wasn't used for MULT",
usedOOCMult);
-// boolean usedOOCSum =
Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP);
-// Assert.assertTrue("OOC wasn't used for SUM",
usedOOCSum);
}
catch(Exception ex) {
Assert.fail(ex.getMessage());