This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a6faf44254 [SYSTEMDS-3895] New out-of-core unary aggregate operations
a6faf44254 is described below

commit a6faf442547c29042bf86388da64472422e02dcc
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sun Jul 13 14:18:30 2025 +0200

    [SYSTEMDS-3895] New out-of-core unary aggregate operations
    
    This patch introduces the out-of-core unary aggregate operations as an
    example of how to implement operations against the input stream of
    blocks.
---
 .../java/org/apache/sysds/hops/AggUnaryOp.java     |  3 -
 .../runtime/instructions/OOCInstructionParser.java |  4 +-
 .../ooc/AggregateUnaryOOCInstruction.java          | 94 ++++++++++++++++++++++
 .../functions/ooc/SumScalarMultiplicationTest.java |  4 +-
 4 files changed, 99 insertions(+), 6 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 2f5cb53acf..b71b57aa18 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -409,9 +409,6 @@ public class AggUnaryOp extends MultiThreadedHop
                else
                        setRequiresRecompileIfNecessary();
                
-               if( _etype == ExecType.OOC ) //TODO
-                       setExecType(ExecType.CP);
-               
                return _etype;
        }
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index e0f84c5bd2..c437684d3b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -23,6 +23,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.InstructionType;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
 
@@ -47,9 +48,10 @@ public class OOCInstructionParser extends InstructionParser {
                switch(ooctype) {
                        case Reblock:
                                return ReblockOOCInstruction.parseInstruction(str);
+                       case AggregateUnary:
+                               return AggregateUnaryOOCInstruction.parseInstruction(str);
                        
                        // TODO:
-                       case AggregateUnary:
                        case Binary:
 
                        default:
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
new file mode 100644
index 0000000000..c333088239
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.common.Types.CorrectionLocationType;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+
+
+public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
+       private AggregateOperator _aop = null;
+
+       protected AggregateUnaryOOCInstruction(OOCType type, AggregateUnaryOperator auop, AggregateOperator aop,
+                       CPOperand in, CPOperand out, String opcode, String istr) {
+               super(type, auop, in, out, opcode, istr);
+               _aop = aop;
+       }
+
+       public static AggregateUnaryOOCInstruction parseInstruction(String str) {
+               String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+               InstructionUtils.checkNumFields(parts, 2);
+               String opcode = parts[0];
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand out = new CPOperand(parts[2]);
+               
+               String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(opcode);
+               CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(opcode);
+               AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
+               AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, corrLoc.toString());
+               return new AggregateUnaryOOCInstruction(
+                       OOCType.AggregateUnary, aggun, aop, in1, out, opcode, str);
+       }
+       
+       @Override
+       public void processInstruction( ExecutionContext ec ) {
+               //TODO support all types of aggregations, currently only full aggregation
+               
+               //setup operators and input queue
+               AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator();
+               MatrixObject min = ec.getMatrixObject(input1);
+               LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
+               IndexedMatrixValue tmp = null;
+               int blen = ConfigurationManager.getBlocksize();
+               
+               //read blocks and aggregate immediately into result
+               int extra = _aop.correction.getNumRemovedRowsColumns();
+               MatrixBlock ret = new MatrixBlock(1,1+extra,false);
+               MatrixBlock corr = new MatrixBlock(1,1+extra,false);
+               try {
+                       while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
+                               //block aggregation
+                               MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
+                                       .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
+                               //accumulation into final result
+                               OperationsOnMatrixValues.incrementalAggregation(
+                                       ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true);
+                       }
+               }
+               catch(Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               }
+               
+               //create scalar output
+               ec.setScalarOutput(output.getName(), new DoubleObject(ret.get(0, 0)));
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
index dafc9c7bf6..2272588bab 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
@@ -92,11 +92,11 @@ public class SumScalarMultiplicationTest extends AutomatedTestBase {
                        String prefix = Instruction.OOC_INST_PREFIX;
                        Assert.assertTrue("OOC wasn't used for RBLK",
                                heavyHittersContainsString(prefix + Opcodes.RBLK));
+                       Assert.assertTrue("OOC wasn't used for SUM",
+                               heavyHittersContainsString(prefix + Opcodes.UAKP));
                        
 //                     boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
 //                     Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult);
-//                     boolean usedOOCSum = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP);
-//                     Assert.assertTrue("OOC wasn't used for SUM", usedOOCSum);
                }
                catch(Exception ex) {
                        Assert.fail(ex.getMessage());

Reply via email to