This is an automated email from the ASF dual-hosted git repository.
janniklinde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 3de7cbe7e9 Add OOC WDivMM
3de7cbe7e9 is described below
commit 3de7cbe7e9e1af46510d059629e55624cd32626b
Author: Jessica Priebe <[email protected]>
AuthorDate: Wed May 13 10:00:04 2026 +0200
Add OOC WDivMM
Closes #2464.
---
scripts/builtin/pnmf.dml | 6 +-
.../java/org/apache/sysds/hops/QuaternaryOp.java | 16 ++
.../runtime/instructions/OOCInstructionParser.java | 3 +
.../ooc/ComputationOOCInstruction.java | 11 +-
.../runtime/instructions/ooc/OOCInstruction.java | 2 +-
.../instructions/ooc/QuaternaryOOCInstruction.java | 54 +++++
.../instructions/ooc/WDivMMOOCInstruction.java | 218 +++++++++++++++++++++
.../apache/sysds/test/functions/ooc/PNMFTest.java | 18 +-
.../sysds/test/functions/ooc/WDivMMTest.java | 156 +++++++++++++++
src/test/scripts/functions/ooc/PNMF.dml | 6 +-
.../{PNMF.dml => WeightedDivMM4MultMinusLeft.dml} | 18 +-
.../{PNMF.dml => WeightedDivMM4MultMinusRight.dml} | 18 +-
.../ooc/{PNMF.dml => WeightedDivMMLeft.dml} | 16 +-
.../ooc/{PNMF.dml => WeightedDivMMLeftEps.dml} | 18 +-
.../ooc/{PNMF.dml => WeightedDivMMMultBasic.dml} | 16 +-
.../ooc/{PNMF.dml => WeightedDivMMMultLeft.dml} | 16 +-
.../{PNMF.dml => WeightedDivMMMultMinusLeft.dml} | 16 +-
.../{PNMF.dml => WeightedDivMMMultMinusRight.dml} | 16 +-
.../ooc/{PNMF.dml => WeightedDivMMMultRight.dml} | 16 +-
.../ooc/{PNMF.dml => WeightedDivMMRight.dml} | 16 +-
.../ooc/{PNMF.dml => WeightedDivMMRightEps.dml} | 18 +-
21 files changed, 595 insertions(+), 79 deletions(-)
diff --git a/scripts/builtin/pnmf.dml b/scripts/builtin/pnmf.dml
index 721ab7232b..bffc373592 100644
--- a/scripts/builtin/pnmf.dml
+++ b/scripts/builtin/pnmf.dml
@@ -42,12 +42,12 @@
# H List of amplitude matrices, one for each repetition.
#
------------------------------------------------------------------------------------
-m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer
maxi = 10, Boolean verbose=TRUE)
+m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer
maxi = 10, Boolean verbose=TRUE, Integer seed=-1)
return (Matrix[Double] W, Matrix[Double] H)
{
#initialize W and H
- W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025);
- H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025);
+ W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025, seed=seed);
+ H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025, seed=seed);
i = 0;
while(i < maxi) {
diff --git a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
index c2be949f37..8fede5f090 100644
--- a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
@@ -211,6 +211,8 @@ public class QuaternaryOp extends MultiThreadedHop
constructCPLopsWeightedDivMM(wtype);
else if( et == ExecType.SPARK )
constructSparkLopsWeightedDivMM(wtype);
+ else if( et == ExecType.OOC )
+
constructOOCLopsWeightedDivMM(wtype);
else
throw new
HopsException("Unsupported quaternaryop-wdivmm exec type: "+et);
break;
@@ -462,6 +464,20 @@ public class QuaternaryOp extends MultiThreadedHop
}
}
+	/**
+	 * Builds the WeightedDivMM lop for out-of-core execution from the four hop inputs
+	 * (taken in hop order) and the given wdivmm pattern type.
+	 */
+	private void constructOOCLopsWeightedDivMM(WDivMMType wtype)
+	{
+		Hop in0 = getInput().get(0);
+		Hop in1 = getInput().get(1);
+		Hop in2 = getInput().get(2);
+		Hop in3 = getInput().get(3);
+
+		WeightedDivMM lop = new WeightedDivMM(in0.constructLops(), in1.constructLops(),
+			in2.constructLops(), in3.constructLops(),
+			getDataType(), getValueType(), wtype, ExecType.OOC);
+
+		setOutputDimensions(lop);
+		setLineNumbers(lop);
+		setLops(lop);
+	}
+
private void constructCPLopsWeightedCeMM(WCeMMType wtype)
{
WeightedCrossEntropy wcemm = new WeightedCrossEntropy(
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index affda5910d..ae41639687 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -43,6 +43,7 @@ import
org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.QuaternaryOOCInstruction;
public class OOCInstructionParser extends InstructionParser {
protected static final Log LOG =
LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -111,6 +112,8 @@ public class OOCInstructionParser extends InstructionParser
{
return
DataGenOOCInstruction.parseInstruction(str);
case Append:
return
AppendOOCInstruction.parseInstruction(str);
+ case Quaternary:
+ return
QuaternaryOOCInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid OOC
Instruction Type: " + ooctype);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
index 4dcdffcb0d..d6686c1156 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
@@ -24,7 +24,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public abstract class ComputationOOCInstruction extends OOCInstruction {
public CPOperand output;
- public CPOperand input1, input2, input3;
+ public CPOperand input1, input2, input3, input4;
protected ComputationOOCInstruction(OOCType type, Operator op,
CPOperand in1, CPOperand out, String opcode, String istr) {
super(type, op, opcode, istr);
@@ -50,6 +50,15 @@ public abstract class ComputationOOCInstruction extends
OOCInstruction {
output = out;
}
+	/**
+	 * Creates a computation instruction with four inputs (e.g., quaternary operations).
+	 */
+	protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3,
+		CPOperand in4, CPOperand out, String opcode, String istr) {
+		super(type, op, opcode, istr);
+		this.output = out;
+		this.input4 = in4;
+		this.input3 = in3;
+		this.input2 = in2;
+		this.input1 = in1;
+	}
+
public String getOutputVariableName() {
return output.getName();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index be9728d87b..679e7187e5 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction {
public enum OOCType {
Reblock, Tee, Binary, Ternary, Unary, AggregateUnary,
AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
- MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing,
ParameterizedBuiltin, Rand, Append
+ MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing,
ParameterizedBuiltin, Rand, Append, Quaternary // Quaternary: four-input ops parsed by QuaternaryOOCInstruction (currently wdivmm)
}
protected final OOCInstruction.OOCType _ooctype;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java
new file mode 100644
index 0000000000..8df1e33c59
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+/**
+ * Base class for out-of-core instructions with four operands (plus one output).
+ */
+public abstract class QuaternaryOOCInstruction extends ComputationOOCInstruction {
+
+	protected QuaternaryOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
+		CPOperand out, String opcode, String istr) {
+		super(OOCType.Quaternary, op, in1, in2, in3, in4, out, opcode, istr);
+	}
+
+	/**
+	 * Parses a quaternary OOC instruction string. Only wdivmm is supported; its layout is
+	 * opcode, four input operands, output operand, and the wdivmm type name.
+	 *
+	 * @param str instruction string
+	 * @return the parsed instruction
+	 * @throws DMLRuntimeException if the opcode is not supported
+	 */
+	public static QuaternaryOOCInstruction parseInstruction(String str) {
+		final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
+		final String opcode = fields[0];
+		if(!opcode.contains(Opcodes.WEIGHTEDDIVMM.toString()))
+			throw new DMLRuntimeException("Not implemented yet opcode " + opcode);
+
+		InstructionUtils.checkNumFields(fields, 6);
+		final CPOperand first = new CPOperand(fields[1]);
+		final CPOperand second = new CPOperand(fields[2]);
+		final CPOperand third = new CPOperand(fields[3]);
+		final CPOperand fourth = new CPOperand(fields[4]);
+		final CPOperand result = new CPOperand(fields[5]);
+		// the last field names the concrete wdivmm pattern
+		final QuaternaryOperator operator = new QuaternaryOperator(WDivMMType.valueOf(fields[6]));
+		return new WDivMMOOCInstruction(operator, first, second, third, fourth, result, opcode, str);
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java
new file mode 100644
index 0000000000..ec9a7bcd4f
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.function.Function;
+
+/**
+ * Out-of-core weighted divide matrix multiplication (wdivmm). Operands are the weight matrix W
+ * (input1), the factors U (input2) and V (input3) and, for four-input variants, either a matrix X
+ * or a scalar epsilon (input4).
+ */
+public class WDivMMOOCInstruction extends QuaternaryOOCInstruction {
+
+	protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
+		CPOperand out, String opcode, String istr) {
+		super(op, in1, in2, in3, in4, out, opcode, istr);
+	}
+
+	/**
+	 * Creates a WDivMM instruction from an already parsed quaternary OOC instruction, reusing its
+	 * operator, operands, and instruction string.
+	 *
+	 * @param instr parsed quaternary instruction whose operator is a {@link QuaternaryOperator}
+	 * @return a WDivMM instruction over the same operands
+	 */
+	public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction instr) {
+		String instrStr = instr.getInstructionString();
+		String opcode = InstructionUtils.getInstructionPartsWithValueType(instrStr)[0];
+		return new WDivMMOOCInstruction((QuaternaryOperator) instr.getOperator(), instr.input1, instr.input2,
+			instr.input3, instr.input4, instr.output, opcode, instrStr);
+	}
+
+	/**
+	 * Evaluates the wdivmm pattern selected by the operator's {@link WDivMMType} as a pipeline of
+	 * block streams and binds the resulting stream to the output variable.
+	 * <ul>
+	 * <li>basic: {@code W * (U %*% t(V))}</li>
+	 * <li>left: {@code t(t(U) %*% inter)}, computed as {@code t(inter) %*% U}</li>
+	 * <li>right: {@code inter %*% V}</li>
+	 * </ul>
+	 * where {@code inter} is one of {@code W / (U %*% t(V) + eps)} (scalar 4th input),
+	 * {@code W * (U %*% t(V) - X)} (matrix 4th input), {@code (W != 0) * (U %*% t(V) - W)} (minus),
+	 * {@code W * (U %*% t(V))} (mult) or {@code W / (U %*% t(V))}.
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		final QuaternaryOperator qop = (QuaternaryOperator) _optr;
+		final WDivMMType wt = qop.wtype3;
+
+		// Names follow the DML patterns: input1 is the weight matrix W. The inputs are cached
+		// because W, U, and V may each be consumed more than once below.
+		CachingStream W = new CachingStream(ec.getMatrixObject(input1).getStreamHandle());
+		CachingStream U = new CachingStream(ec.getMatrixObject(input2).getStreamHandle());
+		CachingStream V = new CachingStream(ec.getMatrixObject(input3).getStreamHandle());
+
+		final boolean basic = wt.isBasic();
+		final boolean left = wt.isLeft();
+		final boolean mult = wt.isMult();
+		final boolean minus = wt.isMinus();
+		final boolean four = wt.hasFourInputs();
+		final boolean scalar = wt.hasScalar();
+
+		// U %*% t(V) is needed by every variant
+		OOCStream<IndexedMatrixValue> mmt = matMultOOC(U.getReadStream(), V.getReadStream(),
+			U.getDataCharacteristics(), V.getDataCharacteristics(), false, true);
+		OOCStream<IndexedMatrixValue> out;
+
+		if(basic) {
+			// W * (U %*% t(V)) is already the result; no outer matrix multiplication
+			out = elemMultOOC(W.getReadStream(), mmt);
+			ec.getMatrixObject(output).setStreamHandle(out);
+			return;
+		}
+
+		OOCStream<IndexedMatrixValue> inter;
+		if(four && scalar) {
+			// W / (U %*% t(V) + eps)
+			double eps = ec.getScalarInput(input4).getDoubleValue();
+			inter = elemDivOOC(W.getReadStream(), elemPlusOOC(mmt, eps));
+		}
+		else if(four) {
+			// W * (U %*% t(V) - X), with X given as the fourth input
+			CachingStream X = new CachingStream(ec.getMatrixObject(input4).getStreamHandle());
+			inter = elemMultOOC(W.getReadStream(), elemMinusOOC(mmt, X.getReadStream()));
+		}
+		else if(minus) {
+			// (W != 0) * (U %*% t(V) - W)
+			inter = maskOOC(W.getReadStream(), elemMinusOOC(mmt, W.getReadStream()));
+		}
+		else if(mult) {
+			inter = elemMultOOC(W.getReadStream(), mmt);
+		}
+		else {
+			inter = elemDivOOC(W.getReadStream(), mmt);
+		}
+
+		// inter has the dimensions of W
+		if(left)
+			out = matMultOOC(inter, U.getReadStream(), W.getDataCharacteristics(), U.getDataCharacteristics(),
+				true, false);
+		else
+			out = matMultOOC(inter, V.getReadStream(), W.getDataCharacteristics(), V.getDataCharacteristics(),
+				false, false);
+
+		ec.getMatrixObject(output).setStreamHandle(out);
+	}
+
+	/**
+	 * Blockwise out-of-core matrix multiplication {@code op(m1) %*% op(m2)}, where {@code op}
+	 * optionally transposes the operand. Blocks are joined on the shared inner block index,
+	 * multiplied, and the partial products of each output block are summed.
+	 *
+	 * @param m1             left input stream
+	 * @param m2             right input stream
+	 * @param dc1            characteristics of the untransposed left input
+	 * @param dc2            characteristics of the untransposed right input
+	 * @param leftTranspose  whether {@code t(m1)} is used
+	 * @param rightTranspose whether {@code t(m2)} is used
+	 * @return stream of output blocks
+	 */
+	private OOCStream<IndexedMatrixValue> matMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2,
+		DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose) {
+
+		// number of join partners per block: each left block meets one right block per output
+		// column block, each right block one left block per output row block
+		int emitLeftThreshold = rightTranspose ? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks();
+		int emitRightThreshold = leftTranspose ? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks();
+
+		OOCStream<IndexedMatrixValue> intermediateStream = createWritableStream();
+		OOCStream<IndexedMatrixValue> out = createWritableStream();
+
+		AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+		AggregateBinaryOperator op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+
+		joinManyOOC(m1, m2, intermediateStream, (left, right) -> {
+			MatrixBlock leftBlock = (MatrixBlock) left.getValue();
+			MatrixBlock rightBlock = (MatrixBlock) right.getValue();
+			if(leftTranspose)
+				leftBlock = leftBlock.transpose();
+			if(rightTranspose)
+				rightBlock = rightBlock.transpose();
+
+			MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, new MatrixBlock(), op);
+			// output block index = (outer index of op(m1), outer index of op(m2))
+			int lidx = (int) (leftTranspose ? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex());
+			int ridx = (int) (rightTranspose ? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex());
+			return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult);
+		}, tmp -> leftTranspose ? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(),
+			tmp -> rightTranspose ? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(),
+			emitLeftThreshold, emitRightThreshold);
+
+		BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
+		// number of inner blocks, i.e., partial products per output block
+		int emitAggThreshold = leftTranspose ? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks();
+
+		groupedReduceOOC(intermediateStream, out, (left, right) -> {
+			MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue());
+			left.setValue(mb);
+			return left;
+		}, emitAggThreshold);
+
+		return out;
+	}
+
+	/**
+	 * Joins two streams on equal block indexes and combines each block pair with {@code bop}.
+	 */
+	private OOCStream<IndexedMatrixValue> elemOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2,
+		BinaryOperator bop) {
+		SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
+		Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
+			new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());
+
+		joinOOC(m1, m2, out, (left, right) -> {
+			MatrixBlock lb = (MatrixBlock) left.getValue();
+			MatrixBlock rb = (MatrixBlock) right.getValue();
+			MatrixBlock combined = lb.binaryOperations(bop, rb);
+			return new IndexedMatrixValue(
+				new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
+		}, key);
+
+		return out;
+	}
+
+	/** Cell-wise {@code m1 / m2}. */
+	private OOCStream<IndexedMatrixValue> elemDivOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
+		BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString());
+		return elemOOC(m1, m2, div);
+	}
+
+	/** Cell-wise {@code m1 * m2}. */
+	private OOCStream<IndexedMatrixValue> elemMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
+		BinaryOperator mult = InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString());
+		return elemOOC(m1, m2, mult);
+	}
+
+	/** Cell-wise {@code m1 - m2}. */
+	private OOCStream<IndexedMatrixValue> elemMinusOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
+		BinaryOperator minus = InstructionUtils.parseBinaryOperator(Opcodes.MINUS.toString());
+		return elemOOC(m1, m2, minus);
+	}
+
+	/** Adds the scalar {@code eps} to every cell of every block of {@code m1}. */
+	private OOCStream<IndexedMatrixValue> elemPlusOOC(OOCStream<IndexedMatrixValue> m1, double eps) {
+		SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
+		mapOOC(m1, out, blk -> {
+			MatrixBlock res = ((MatrixBlock) blk.getValue())
+				.scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), eps), null);
+			return new IndexedMatrixValue(
+				new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), res);
+		});
+		return out;
+	}
+
+	/**
+	 * Joins {@code mask} and {@code m1} on equal block indexes and zeroes every cell of the
+	 * {@code m1} block whose corresponding {@code mask} cell is zero.
+	 */
+	private OOCStream<IndexedMatrixValue> maskOOC(OOCStream<IndexedMatrixValue> mask, OOCStream<IndexedMatrixValue> m1) {
+		SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
+		Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
+			new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());
+
+		joinOOC(mask, m1, out, (left, right) -> {
+			MatrixBlock lb = (MatrixBlock) left.getValue();
+			MatrixBlock rb = (MatrixBlock) right.getValue();
+			MatrixBlock combined = mask(lb, rb);
+			return new IndexedMatrixValue(
+				new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
+		}, key);
+
+		return out;
+	}
+
+	/**
+	 * Sets {@code blk[i,j] = 0} wherever {@code mask[i,j] == 0}. Modifies {@code blk} in place.
+	 *
+	 * @return {@code blk}
+	 */
+	private MatrixBlock mask(MatrixBlock mask, MatrixBlock blk) {
+		for(int i = 0; i < blk.getNumRows(); i++) {
+			for(int j = 0; j < blk.getNumColumns(); j++) {
+				if(mask.get(i, j) == 0)
+					blk.set(i, j, 0);
+			}
+		}
+		return blk;
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java
index a25249985d..d7186f2bbe 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java
@@ -21,12 +21,16 @@ package org.apache.sysds.test.functions.ooc;
import java.io.IOException;
+import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
public class PNMFTest extends AutomatedTestBase {
private static final String TEST_NAME = "PNMF";
@@ -44,6 +48,7 @@ public class PNMFTest extends AutomatedTestBase {
private static final int RANK = 20;
private static final int MAX_ITER = 10;
private static final int BLOCK_SIZE = 1000;
+ private static final int SEED = 7;
private static final double SPARSITY = 0.7;
private static final double EPS = 1e-6;
@@ -54,7 +59,7 @@ public class PNMFTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
}
- //@Test
+ @Test // re-enabled: the OOC run is now also checked to use the OOC wdivmm instruction
public void testPNMFOOCVsCP() {
runPNMFTest();
}
@@ -71,13 +76,16 @@ public class PNMFTest extends AutomatedTestBase {
double[][] xData = getRandomMatrix(ROWS, COLS, 1, 10,
SPARSITY, 7);
writeBinaryWithMTD(INPUT_X,
DataConverter.convertToMatrixBlock(xData));
- programArgs = new String[] {"-explain", "-stats",
"-seed", "7", "-ooc", "-args",
- input(INPUT_X), String.valueOf(RANK),
String.valueOf(MAX_ITER),
+ programArgs = new String[] {"-explain", "-stats",
"-ooc", "-args",
+ input(INPUT_X), String.valueOf(RANK),
String.valueOf(MAX_ITER), String.valueOf(SEED),
output(OUTPUT_W_OOC), output(OUTPUT_H_OOC)};
runTest(true, false, null, -1);
- programArgs = new String[] {"-explain", "-stats",
"-seed", "7", "-args",
- input(INPUT_X), String.valueOf(RANK),
String.valueOf(MAX_ITER),
+ Assert.assertTrue("OOC wasn't used for pnmf",
+
heavyHittersContainsString(Instruction.OOC_INST_PREFIX +
Opcodes.WEIGHTEDDIVMM));
+
+ programArgs = new String[] {"-explain", "-stats",
"-args",
+ input(INPUT_X), String.valueOf(RANK),
String.valueOf(MAX_ITER), String.valueOf(SEED),
output(OUTPUT_W_CP), output(OUTPUT_H_CP)};
runTest(true, false, null, -1);
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java
new file mode 100644
index 0000000000..549fdc764d
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class WDivMMTest extends AutomatedTestBase {
+	private final static String INPUT_NAME_1 = "W";
+	private final static String INPUT_NAME_2 = "U";
+	private final static String INPUT_NAME_3 = "V";
+	private final static String OUTPUT_NAME = "R";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + WDivMMTest.class.getSimpleName() + "/";
+
+	private static final int rows = 2201;
+	private static final int cols = 1103;
+	private static final int rank = 20;
+	private static final int blen = 1000;
+	private static final double eps = 1e-6;
+	private static final double div_eps = 0.1;
+
+	private final static String TEST_NAME_1 = "WeightedDivMMLeft";
+	private final static String TEST_NAME_2 = "WeightedDivMMRight";
+	private final static String TEST_NAME_3 = "WeightedDivMMMultBasic";
+	private final static String TEST_NAME_4 = "WeightedDivMMMultLeft";
+	private final static String TEST_NAME_5 = "WeightedDivMMMultRight";
+	private final static String TEST_NAME_6 = "WeightedDivMMMultMinusLeft";
+	private final static String TEST_NAME_7 = "WeightedDivMMMultMinusRight";
+	private final static String TEST_NAME_8 = "WeightedDivMM4MultMinusLeft";
+	private final static String TEST_NAME_9 = "WeightedDivMM4MultMinusRight";
+	private final static String TEST_NAME_10 = "WeightedDivMMLeftEps";
+	private final static String TEST_NAME_11 = "WeightedDivMMRightEps";
+
+	/** DML script name (without extension) exercised by this parameterization. */
+	private final String testName;
+
+	public WDivMMTest(String testName) {
+		this.testName = testName;
+	}
+
+	@Parameterized.Parameters(name = "{0}")
+	public static Collection<Object[]> data() {
+		return Arrays.asList(new Object[][] {{TEST_NAME_1}, {TEST_NAME_2}, {TEST_NAME_3}, {TEST_NAME_4}, {TEST_NAME_5},
+			{TEST_NAME_6}, {TEST_NAME_7}, {TEST_NAME_8}, {TEST_NAME_9}, {TEST_NAME_10}, {TEST_NAME_11}});
+	}
+
+	@Before
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {OUTPUT_NAME}));
+	}
+
+	@Test
+	public void testWeightedDivMM() {
+		runWeightedDivMMTest(testName);
+	}
+
+	/**
+	 * Runs the script once with -ooc and once without, asserts that the OOC run executed the OOC
+	 * wdivmm instruction, and compares both results.
+	 *
+	 * @param scriptName DML script name without extension
+	 */
+	private void runWeightedDivMMTest(String scriptName) {
+		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try {
+			boolean basic = scriptName.equals(TEST_NAME_3);
+			boolean left = scriptName.equals(TEST_NAME_1) || scriptName.equals(TEST_NAME_4) ||
+				scriptName.equals(TEST_NAME_6) || scriptName.equals(TEST_NAME_8) || scriptName.equals(TEST_NAME_10);
+
+			getAndLoadTestConfiguration(scriptName);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + scriptName + ".dml";
+
+			double[][] W = getRandomMatrix(rows, cols, 0, 1, 0.7, 7);
+			double[][] U = getRandomMatrix(rows, rank, 0, 1, 1.0, 713);
+			double[][] V = getRandomMatrix(cols, rank, 0, 1, 1.0, 812);
+
+			// W is generated with sparsity 0.7, so record the actual nnz rather than rows*cols
+			MatrixBlock mbW = DataConverter.convertToMatrixBlock(W);
+			MatrixBlock mbU = DataConverter.convertToMatrixBlock(U);
+			MatrixBlock mbV = DataConverter.convertToMatrixBlock(V);
+
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+			writer.writeMatrixToHDFS(mbW, input(INPUT_NAME_1), rows, cols, blen, mbW.getNonZeros());
+			writer.writeMatrixToHDFS(mbU, input(INPUT_NAME_2), rows, rank, blen, mbU.getNonZeros());
+			writer.writeMatrixToHDFS(mbV, input(INPUT_NAME_3), cols, rank, blen, mbV.getNonZeros());
+
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(rows, cols, blen, mbW.getNonZeros()), Types.FileFormat.BINARY);
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(rows, rank, blen, mbU.getNonZeros()), Types.FileFormat.BINARY);
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME_3 + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(cols, rank, blen, mbV.getNonZeros()), Types.FileFormat.BINARY);
+
+			// out-of-core run
+			programArgs = new String[] {"-ooc", "-stats", "-explain", "runtime", "-args", input(INPUT_NAME_1),
+				input(INPUT_NAME_2), input(INPUT_NAME_3), output(OUTPUT_NAME), Double.toString(div_eps)};
+
+			runTest(true, false, null, -1);
+
+			Assert.assertTrue("OOC wasn't used for wdivmm",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.WEIGHTEDDIVMM));
+
+			// in-memory reference run
+			programArgs = new String[] {"-stats", "-explain", "runtime", "-args", input(INPUT_NAME_1),
+				input(INPUT_NAME_2), input(INPUT_NAME_3), output(OUTPUT_NAME + "_target"), Double.toString(div_eps)};
+
+			runTest(true, false, null, -1);
+
+			// basic: W-shaped result; left: t(t(U) %*% ...) is cols x rank; right: rows x rank
+			int rows2 = left ? cols : rows;
+			int cols2 = basic ? cols : rank;
+			checkDMLMetaDataFile(OUTPUT_NAME, new MatrixCharacteristics(rows2, cols2));
+
+			MatrixBlock actual = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, rows2, cols2, blen);
+			MatrixBlock expected = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+				Types.FileFormat.BINARY, rows2, cols2, blen);
+			TestUtils.compareMatrices(expected, actual, eps);
+		}
+		catch(IOException e) {
+			throw new RuntimeException(e);
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/PNMF.dml
index 60aecb8963..bc0fd5b100 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/PNMF.dml
@@ -20,7 +20,7 @@
#-------------------------------------------------------------
X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
+[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE, seed=$4); # $4 seeds the rand() initialization of W and H so OOC and CP runs start identically
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+write(W, $5, format="binary");
+write(H, $6, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml
similarity index 86%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml
index 60aecb8963..42bd4c96a0 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,14 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# Four-input minus-mult wdivmm, left variant: R = t(t(U) %*% (W * (U %*% t(V) - X)))
+W = read($1);
+U = read($2);
+V = read($3);
+
+X = W/0.7; # X is derived from W so no extra input file is needed
+while(FALSE){} # NOTE(review): presumably forces a statement-block boundary so X is materialized before the pattern — confirm
+R = t(t(U) %*% (W*(U%*%t(V)-X)));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml
index 60aecb8963..7b393f1231 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,14 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+X = W/0.3
+while(FALSE){}
+R = (W*(U%*%t(V)-X)) %*% V;
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMLeft.dml
index 60aecb8963..48639a176a 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = t(t(U) %*% (W/(U%*%t(V))));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml
index 60aecb8963..dc07670fea 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,14 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+x = $5;
+
+R = t(t(U) %*% (W/(U%*%t(V) + x)));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml
index 60aecb8963..144e59a773 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = W*(U%*%t(V));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml
index 60aecb8963..93bc765617 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = t(t(U) %*% (W*(U%*%t(V))));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml
index 60aecb8963..84ac35ad89 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = t(t(U) %*% ((W != 0)*(U%*%t(V)-W)));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml
index 60aecb8963..59caa4d17b 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = ((W != 0)*(U%*%t(V)-W)) %*% V;
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml
index 60aecb8963..fbb1224d17 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = (W*(U%*%t(V))) %*% V;
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMRight.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMRight.dml
index 60aecb8963..e878a81d14 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMRight.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = (W/(U%*%t(V))) %*% V;
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml
b/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml
index 60aecb8963..9ecbaf5663 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,14 @@
#
#-------------------------------------------------------------
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+
+W = read($1);
+U = read($2);
+V = read($3);
+
+x = $5;
+
+R = (W/(U%*%t(V) + x)) %*% V;
+
+write(R, $4, format="binary");