This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new bc3216ad3c [SYSTEMDS-3927] Out-of-core centralMoment operations
bc3216ad3c is described below

commit bc3216ad3cf675a33cf28ef4c49a87b80ffcc402
Author: Jannik Lindemann <[email protected]>
AuthorDate: Fri Oct 24 08:59:28 2025 +0200

    [SYSTEMDS-3927] Out-of-core centralMoment operations
    
    Closes #2339.
---
 .../java/org/apache/sysds/lops/CentralMoment.java  |   2 +-
 .../runtime/instructions/OOCInstructionParser.java |   3 +
 .../ooc/AggregateUnaryOOCInstruction.java          |   7 +
 .../ooc/CentralMomentOOCInstruction.java           | 166 +++++++++++++++++++++
 .../ooc/ComputationOOCInstruction.java             |   8 +
 .../runtime/instructions/ooc/OOCInstruction.java   |   2 +-
 .../test/functions/ooc/CentralMomentTest.java      | 141 +++++++++++++++++
 .../functions/ooc/CentralMomentWeightsTest.java    | 147 ++++++++++++++++++
 src/test/scripts/functions/ooc/CentralMoment.dml   |  25 ++++
 .../scripts/functions/ooc/CentralMomentWeights.dml |  26 ++++
 10 files changed, 525 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/lops/CentralMoment.java 
b/src/main/java/org/apache/sysds/lops/CentralMoment.java
index f2048f7e5c..b8907fd79f 100644
--- a/src/main/java/org/apache/sysds/lops/CentralMoment.java
+++ b/src/main/java/org/apache/sysds/lops/CentralMoment.java
@@ -97,7 +97,7 @@ public class CentralMoment extends Lop
                                
getInputs().get(2).prepScalarInputOperand(getExecType()),
                                prepOutputOperand(output));
                }
-               if( getExecType() == ExecType.CP || getExecType() == 
ExecType.FED ) {
+               if(getExecType() == ExecType.CP || getExecType() == 
ExecType.FED || getExecType() == ExecType.OOC) {
                        sb.append(OPERAND_DELIMITOR);
                        sb.append(_numThreads);
                        if ( getExecType() == ExecType.FED ){
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index c98301dcc6..9c0f0f2e0f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.InstructionType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
@@ -69,6 +70,8 @@ public class OOCInstructionParser extends InstructionParser {
                                return 
TransposeOOCInstruction.parseInstruction(str);
                        case Tee:
                                return TeeOOCInstruction.parseInstruction(str);
+            case CentralMoment:
+                return  CentralMomentOOCInstruction.parseInstruction(str);
                        
                        default:
                                throw new DMLRuntimeException("Invalid OOC 
Instruction Type: " + ooctype);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
index b71cdaaeb5..c01fb3fa37 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -34,6 +34,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
 import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 
@@ -49,6 +50,12 @@ public class AggregateUnaryOOCInstruction extends 
ComputationOOCInstruction {
                _aop = aop;
        }
 
+       protected AggregateUnaryOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2,
+               CPOperand in3, CPOperand out, String opcode, String istr) { // generic Operator (e.g. CMOperator), up to three inputs
+               super(type, op, in1, in2, in3, out, opcode, istr);
+               this._aop = null; // no AggregateUnaryOperator: subclasses built via this constructor must not rely on _aop
+       }
+
        public static AggregateUnaryOOCInstruction parseInstruction(String str) 
{
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                InstructionUtils.checkNumFields(parts, 2);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java
new file mode 100644
index 0000000000..9c122662c2
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.matrix.operators.CMOperator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Out-of-core central moment instruction for {@code moment(A, order)} and {@code moment(A, W, order)}.
+ * <p>
+ * The value matrix (and, in the weighted variant, the weight matrix) is consumed block by block from its OOC
+ * stream. Every block yields a partial {@link CM_COV_Object}; the partial results are merged with the
+ * operator's value function and the requested moment is written as a scalar output.
+ */
+public class CentralMomentOOCInstruction extends AggregateUnaryOOCInstruction {
+
+       private CentralMomentOOCInstruction(CMOperator cm, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
+               String opcode, String str) {
+               super(OOCType.CM, cm, in1, in2, in3, out, opcode, str);
+       }
+
+       /**
+        * Parses an OOC central moment instruction by delegating operand and operator parsing to
+        * {@link CentralMomentCPInstruction}.
+        *
+        * @param str instruction string
+        * @return the parsed OOC instruction
+        */
+       public static CentralMomentOOCInstruction parseInstruction(String str) {
+               CentralMomentCPInstruction cpInst = CentralMomentCPInstruction.parseInstruction(str);
+               return parseInstruction(cpInst);
+       }
+
+       /**
+        * Creates an OOC instruction from an already parsed CP instruction, reusing its operator, operands,
+        * opcode and instruction string.
+        *
+        * @param inst parsed CP central moment instruction
+        * @return the equivalent OOC instruction
+        */
+       public static CentralMomentOOCInstruction parseInstruction(CentralMomentCPInstruction inst) {
+               return new CentralMomentOOCInstruction((CMOperator) inst.getOperator(), inst.input1, inst.input2, inst.input3,
+                       inst.output, inst.getOpcode(), inst.getInstructionString());
+       }
+
+       /**
+        * Streams the input block(s), computes one partial moment per block and writes the merged result as a
+        * {@link DoubleObject} scalar output.
+        *
+        * @param ec execution context holding the input matrices, the order scalar and the output variable
+        * @throws DMLRuntimeException if reading a stream is interrupted, the input has no blocks, a block is not a
+        *                             {@link MatrixBlock}, or the weighted inputs have different block sizes or unmatched blocks
+        */
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject matObj = ec.getMatrixObject(input1.getName());
+               LocalTaskQueue<IndexedMatrixValue> qIn = matObj.getStreamHandle();
+
+               // The order can be INVALID in the compiled instruction when its exact value is unknown at
+               // compilation time; resolve it from the scalar operand and derive a matching CMOperator.
+               CPOperand scalarInput = (input3 == null ? input2 : input3);
+               ScalarObject order = ec.getScalarInput(scalarInput);
+               CMOperator compiledOp = (CMOperator) _optr;
+               CMOperator cmOp = (compiledOp.getAggOpType() == CMOperator.AggregateOperationTypes.INVALID) ?
+                       compiledOp.setCMAggOp((int) order.getLongValue()) : compiledOp;
+
+               // input3 is only set for moment(A, W, order); input2 then holds the weight matrix.
+               List<CM_COV_Object> cmObjs = (input3 == null) ?
+                       computeUnweighted(qIn, cmOp) : computeWeighted(ec, qIn, cmOp);
+
+               // Merge the partial moments; an empty input leaves nothing to merge and is reported explicitly
+               // instead of surfacing as a bare NoSuchElementException.
+               Optional<CM_COV_Object> res = cmObjs.stream()
+                       .reduce((arg0, arg1) -> (CM_COV_Object) cmOp.fn.execute(arg0, arg1));
+               CM_COV_Object merged = res.orElseThrow(() -> new DMLRuntimeException(
+                       "Central moment input '" + input1.getName() + "' produced no blocks"));
+
+               ec.setScalarOutput(output.getName(), new DoubleObject(merged.getRequiredResult(cmOp)));
+       }
+
+       /**
+        * Drains an unweighted value stream and computes one partial central moment per block.
+        *
+        * @param qIn  stream of value blocks
+        * @param cmOp operator with a resolved order
+        * @return per-block partial results in stream order
+        */
+       private static List<CM_COV_Object> computeUnweighted(LocalTaskQueue<IndexedMatrixValue> qIn, CMOperator cmOp) {
+               List<CM_COV_Object> cmObjs = new ArrayList<>();
+               try {
+                       IndexedMatrixValue tmp;
+                       while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS)
+                               cmObjs.add(castToMatrixBlock(tmp.getValue()).cmOperations(cmOp));
+               }
+               catch(InterruptedException ex) {
+                       // restore the interrupt flag before converting to the runtime exception type
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException(ex);
+               }
+               return cmObjs;
+       }
+
+       /**
+        * Joins the value stream ({@code input1}) with the weight stream ({@code input2}) on block indexes and
+        * computes one weighted partial central moment per matched block pair.
+        * <p>
+        * This is a symmetric hash join: a block is buffered until its counterpart arrives, so blocks may stay
+        * in memory for a while, depending on when the matching block appears in the other stream.
+        *
+        * @param ec   execution context holding both matrices
+        * @param qIn  stream of value blocks, already opened by the caller
+        * @param cmOp operator with a resolved order
+        * @return per-block partial results in match order
+        * @throws DMLRuntimeException if the block sizes differ or some blocks remain unmatched
+        */
+       private List<CM_COV_Object> computeWeighted(ExecutionContext ec, LocalTaskQueue<IndexedMatrixValue> qIn,
+               CMOperator cmOp) {
+               MatrixObject wtObj = ec.getMatrixObject(input2.getName());
+               DataCharacteristics dc = ec.getDataCharacteristics(input1.getName());
+               DataCharacteristics dcW = ec.getDataCharacteristics(input2.getName());
+               if(dc.getBlocksize() != dcW.getBlocksize())
+                       throw new DMLRuntimeException("Different block sizes are not yet supported");
+               LocalTaskQueue<IndexedMatrixValue> wIn = wtObj.getStreamHandle();
+
+               List<CM_COV_Object> cmObjs = new ArrayList<>();
+               Map<MatrixIndexes, MatrixValue> left = new HashMap<>();  // value blocks awaiting their weights
+               Map<MatrixIndexes, MatrixValue> right = new HashMap<>(); // weight blocks awaiting their values
+               try {
+                       IndexedMatrixValue tmp = qIn.dequeueTask();
+                       IndexedMatrixValue tmpW = wIn.dequeueTask();
+                       // take at most one block from each stream per iteration until both are exhausted
+                       while(tmp != LocalTaskQueue.NO_MORE_TASKS || tmpW != LocalTaskQueue.NO_MORE_TASKS) {
+                               if(tmp != LocalTaskQueue.NO_MORE_TASKS) {
+                                       MatrixValue weights = right.remove(tmp.getIndexes());
+                                       if(weights != null)
+                                               cmObjs.add(castToMatrixBlock(tmp.getValue()).cmOperations(cmOp, castToMatrixBlock(weights)));
+                                       else
+                                               left.put(tmp.getIndexes(), tmp.getValue());
+                                       tmp = qIn.dequeueTask();
+                               }
+                               if(tmpW != LocalTaskQueue.NO_MORE_TASKS) {
+                                       MatrixValue values = left.remove(tmpW.getIndexes());
+                                       if(values != null)
+                                               cmObjs.add(castToMatrixBlock(values).cmOperations(cmOp, castToMatrixBlock(tmpW.getValue())));
+                                       else
+                                               right.put(tmpW.getIndexes(), tmpW.getValue());
+                                       tmpW = wIn.dequeueTask();
+                               }
+                       }
+               }
+               catch(InterruptedException ex) {
+                       // restore the interrupt flag before converting to the runtime exception type
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException(ex);
+               }
+
+               if(!left.isEmpty() || !right.isEmpty())
+                       throw new DMLRuntimeException("Unmatched blocks: values=" + left.size() + ", weights=" + right.size());
+               return cmObjs;
+       }
+
+       /**
+        * Casts a streamed block value to {@link MatrixBlock}, the only block type supported here.
+        *
+        * @param mv streamed block value
+        * @return the value as a matrix block
+        * @throws DMLRuntimeException if the value is not a {@link MatrixBlock}
+        */
+       private static MatrixBlock castToMatrixBlock(MatrixValue mv) {
+               if(!(mv instanceof MatrixBlock))
+                       throw new DMLRuntimeException("Unsupported block type for OOC central moment: "
+                               + (mv == null ? "null" : mv.getClass().getSimpleName()));
+               return (MatrixBlock) mv;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
index 5552017493..4dcdffcb0d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
@@ -42,6 +42,14 @@ public abstract class ComputationOOCInstruction extends 
OOCInstruction {
                output = out;
        }
 
+       protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) { // up to three inputs
+               super(type, op, opcode, istr);
+               this.output = out;
+               this.input3 = in3; // may be null, e.g. for the unweighted central moment
+               this.input2 = in2;
+               this.input1 = in1;
+       }
+
        public String getOutputVariableName() {
                return output.getName();
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index ff9046d490..5b1c766661 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -33,7 +33,7 @@ public abstract class OOCInstruction extends Instruction {
        protected static final Log LOG = 
LogFactory.getLog(OOCInstruction.class.getName());
 
        public enum OOCType {
-               Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, 
MAPMM, MMTSJ, Reorg, 
+               Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, 
MAPMM, MMTSJ, Reorg, CM // CM: central moment, executed by CentralMomentOOCInstruction
        }
 
        protected final OOCInstruction.OOCType _ooctype;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java 
b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java
new file mode 100644
index 0000000000..79f05421ad
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+/**
+ * Verifies that {@code moment(X, order)} is compiled to the out-of-core central moment instruction and
+ * that its result matches the in-memory execution, for orders 2 to 4 on dense and sparse inputs.
+ */
+public class CentralMomentTest extends AutomatedTestBase {
+       private static final String TEST_NAME1 = "CentralMoment";
+       private static final String TEST_DIR = "functions/ooc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + CentralMomentTest.class.getSimpleName() + "/";
+       private static final String INPUT_NAME = "X";
+       private static final String OUTPUT_NAME = "res";
+
+       private static final double EPS = 1e-8;
+       private static final int ROWS = 1871;
+       private static final int MAX_VAL = 7;
+       private static final double SPARSITY_DENSE = 0.65;
+       private static final double SPARSITY_SPARSE = 0.05;
+       private static final int BLOCKSIZE = 1000;
+       private static final int SEED = 7;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+       }
+
+       @Test
+       public void testCentralMoment2Dense() {
+               runCentralMomentTest(2, false);
+       }
+
+       @Test
+       public void testCentralMoment3Dense() {
+               runCentralMomentTest(3, false);
+       }
+
+       @Test
+       public void testCentralMoment4Dense() {
+               runCentralMomentTest(4, false);
+       }
+
+       @Test
+       public void testCentralMoment2Sparse() {
+               runCentralMomentTest(2, true);
+       }
+
+       @Test
+       public void testCentralMoment3Sparse() {
+               runCentralMomentTest(3, true);
+       }
+
+       @Test
+       public void testCentralMoment4Sparse() {
+               runCentralMomentTest(4, true);
+       }
+
+       /**
+        * Runs the script once with {@code -ooc} and once without, checks that the OOC instruction was used and
+        * compares both results.
+        *
+        * @param order  order of the central moment passed to the script
+        * @param sparse {@code true} for the sparse input, {@code false} for the dense one
+        */
+       private void runCentralMomentTest(int order, boolean sparse) {
+               Types.ExecMode oldMode = setExecMode(Types.ExecMode.SINGLE_NODE);
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME1);
+                       fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml";
+
+                       String inputPath = input(INPUT_NAME);
+                       String orderArg = Integer.toString(order);
+                       programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", inputPath, orderArg, output(OUTPUT_NAME)};
+
+                       double sparsity = sparse ? SPARSITY_SPARSE : SPARSITY_DENSE;
+                       MatrixBlock mb = DataConverter.convertToMatrixBlock(getRandomMatrix(ROWS, 1, 1, MAX_VAL, sparsity, SEED));
+                       writeBinaryInput(mb, INPUT_NAME);
+
+                       // OOC run: the central moment must be executed by the OOC instruction
+                       runTest(true, false, null, -1);
+                       Assert.assertTrue("OOC wasn't used for CentralMoment",
+                               heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CM));
+
+                       // reference run without the -ooc flag
+                       programArgs = new String[] {"-explain", "-stats", "-args", inputPath, orderArg, output(OUTPUT_NAME + "_target")};
+                       runTest(true, false, null, -1);
+
+                       HashMap<MatrixValue.CellIndex, Double> oocRes = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+                       HashMap<MatrixValue.CellIndex, Double> refRes = readDMLMatrixFromOutputDir(OUTPUT_NAME + "_target");
+                       TestUtils.compareMatrices(oocRes, refRes, EPS, "Ret-1", "Ret-2");
+               }
+               catch(IOException e) {
+                       throw new RuntimeException(e);
+               }
+               finally {
+                       resetExecMode(oldMode);
+               }
+       }
+
+       /**
+        * Writes {@code mb} as binary input {@code name} plus its metadata file.
+        *
+        * @param mb   single-column matrix with ROWS rows to write
+        * @param name input name, resolved via {@code input(...)}
+        * @throws IOException propagated from the matrix or metadata writer
+        */
+       private void writeBinaryInput(MatrixBlock mb, String name) throws IOException {
+               MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+               writer.writeMatrixToHDFS(mb, input(name), ROWS, 1, BLOCKSIZE, mb.getNonZeros());
+               HDFSTool.writeMetaDataFile(input(name + ".mtd"), Types.ValueType.FP64,
+                       new MatrixCharacteristics(ROWS, 1, BLOCKSIZE, mb.getNonZeros()), Types.FileFormat.BINARY);
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java
new file mode 100644
index 0000000000..994f84526f
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+/**
+ * Verifies that the weighted central moment {@code moment(X, W, order)} is compiled to the out-of-core
+ * instruction and matches the in-memory result, for orders 2 to 4 on dense and sparse values with dense weights.
+ */
+public class CentralMomentWeightsTest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "CentralMomentWeights";
+       private final static String TEST_DIR = "functions/ooc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + CentralMomentWeightsTest.class.getSimpleName() + "/";
+       private final static double eps = 1e-8;
+       private static final String INPUT_NAME = "X";
+       private static final String INPUT_NAME_W = "W";
+       private static final String OUTPUT_NAME = "res";
+
+       private final static int rows = 1871;
+       private final static int maxVal = 7;
+       private final static double sparsity1 = 0.65;
+       private final static double sparsity2 = 0.05;
+       // block size of the binary test inputs
+       private final static int blocksize = 1000;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+               addTestConfiguration(TEST_NAME1, config);
+       }
+
+       @Test
+       public void testCentralMoment2Dense() {
+               runCentralMomentTest(2, false);
+       }
+
+       @Test
+       public void testCentralMoment3Dense() {
+               runCentralMomentTest(3, false);
+       }
+
+       @Test
+       public void testCentralMoment4Dense() {
+               runCentralMomentTest(4, false);
+       }
+
+       @Test
+       public void testCentralMoment2Sparse() {
+               runCentralMomentTest(2, true);
+       }
+
+       @Test
+       public void testCentralMoment3Sparse() {
+               runCentralMomentTest(3, true);
+       }
+
+       @Test
+       public void testCentralMoment4Sparse() {
+               runCentralMomentTest(4, true);
+       }
+
+       /**
+        * Writes random values and weights, runs the script with {@code -ooc}, checks that the OOC instruction
+        * was used, reruns it in memory and compares both results.
+        *
+        * @param order  order of the central moment passed to the script
+        * @param sparse {@code true} for sparse values, {@code false} for dense values (weights are always dense)
+        */
+       private void runCentralMomentTest(int order, boolean sparse) {
+               Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME1);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+                       programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), input(INPUT_NAME_W),
+                               Integer.toString(order), output(OUTPUT_NAME)};
+
+                       // 1. Generate the values (dense or sparse) and the dense weights in-memory
+                       double[][] A_data = getRandomMatrix(rows, 1, 1, maxVal, sparse ? sparsity2 : sparsity1, 7);
+                       double[][] W_data = getRandomMatrix(rows, 1, 0, 1, 1.0, 7);
+
+                       // 2. Convert the double arrays to MatrixBlock objects
+                       MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data);
+                       MatrixBlock W_mb = DataConverter.convertToMatrixBlock(W_data);
+
+                       // 3. Write both matrices as binary inputs; each metadata file is derived from its own block,
+                       //    so W.mtd carries the non-zero count of W (not of A)
+                       writeBinaryInput(A_mb, INPUT_NAME);
+                       writeBinaryInput(W_mb, INPUT_NAME_W);
+
+                       runTest(true, false, null, -1);
+
+                       //check CentralMoment OOC
+                       Assert.assertTrue("OOC wasn't used for CentralMoment",
+                               heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CM));
+
+                       // rerun without ooc flag to obtain the in-memory reference result
+                       programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME), input(INPUT_NAME_W),
+                               Integer.toString(order), output(OUTPUT_NAME + "_target")};
+                       runTest(true, false, null, -1);
+
+                       // compare OOC and in-memory results
+                       HashMap<MatrixValue.CellIndex, Double> ret1 = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+                       HashMap<MatrixValue.CellIndex, Double> ret2 = readDMLMatrixFromOutputDir(OUTPUT_NAME + "_target");
+                       TestUtils.compareMatrices(ret1, ret2, eps, "Ret-1", "Ret-2");
+               }
+               catch(IOException e) {
+                       throw new RuntimeException(e);
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+
+       /**
+        * Writes {@code mb} as binary input {@code name} plus its metadata file, taking the dimensions and the
+        * non-zero count from the block itself.
+        *
+        * @param mb   matrix to write
+        * @param name input name, resolved via {@code input(...)}
+        * @throws IOException propagated from the matrix or metadata writer
+        */
+       private void writeBinaryInput(MatrixBlock mb, String name) throws IOException {
+               MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+               writer.writeMatrixToHDFS(mb, input(name), mb.getNumRows(), mb.getNumColumns(), blocksize, mb.getNonZeros());
+               HDFSTool.writeMetaDataFile(input(name + ".mtd"), Types.ValueType.FP64,
+                       new MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), blocksize, mb.getNonZeros()),
+                       Types.FileFormat.BINARY);
+       }
+}
diff --git a/src/test/scripts/functions/ooc/CentralMoment.dml 
b/src/test/scripts/functions/ooc/CentralMoment.dml
new file mode 100644
index 0000000000..d5fda5a6a9
--- /dev/null
+++ b/src/test/scripts/functions/ooc/CentralMoment.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+A = read($1);          # $1: input matrix
+s = moment(A, $2);     # $2: order of the central moment
+m = as.matrix(s);      # wrap the scalar result as a 1x1 matrix for the matrix-based result comparison
+
+write(m, $3, format="text");
diff --git a/src/test/scripts/functions/ooc/CentralMomentWeights.dml 
b/src/test/scripts/functions/ooc/CentralMomentWeights.dml
new file mode 100644
index 0000000000..a8f24d4568
--- /dev/null
+++ b/src/test/scripts/functions/ooc/CentralMomentWeights.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+A = read($1);           # $1: values
+W = read($2);           # $2: weights, same dimensions as A
+s = moment(A, W, $3);   # $3: order of the central moment
+m = as.matrix(s);       # wrap the scalar result as a 1x1 matrix for the matrix-based result comparison
+
+write(m, $4, format="text");

Reply via email to