This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 79122eb1f3 [SYSTEMDS-3891] New out-of-core instructions and
improvements
79122eb1f3 is described below
commit 79122eb1f3d6ea9d9b4457a8d1550e7b032a0707
Author: Jannik Lindemann <[email protected]>
AuthorDate: Sun Nov 30 11:46:36 2025 +0100
[SYSTEMDS-3891] New out-of-core instructions and improvements
Closes #2362.
---
src/main/java/org/apache/sysds/hops/DataOp.java | 3 +-
src/main/java/org/apache/sysds/lops/DataGen.java | 2 +-
src/main/java/org/apache/sysds/lops/Transform.java | 2 +-
.../controlprogram/caching/MatrixObject.java | 31 +++-
.../runtime/instructions/OOCInstructionParser.java | 4 +-
.../instructions/ooc/BinaryOOCInstruction.java | 57 ++++++-
.../runtime/instructions/ooc/CachingStream.java | 22 ++-
.../instructions/ooc/DataGenOOCInstruction.java | 186 ++++++++++++++++++++-
.../runtime/instructions/ooc/OOCInstruction.java | 96 +++++++++++
.../instructions/ooc/ReorgOOCInstruction.java | 124 ++++++++++++++
.../instructions/ooc/TSMMOOCInstruction.java | 29 ++--
.../instructions/ooc/TransposeOOCInstruction.java | 67 --------
.../apache/sysds/test/functions/ooc/RandTest.java | 103 ++++++++++++
.../apache/sysds/test/functions/ooc/SortTest.java | 128 ++++++++++++++
.../test/functions/ooc/StreamCollectTest.java | 4 +-
src/test/scripts/functions/ooc/Rand1.dml | 24 +++
src/test/scripts/functions/ooc/Rand2.dml | 24 +++
src/test/scripts/functions/ooc/Sort.dml | 27 +++
18 files changed, 826 insertions(+), 107 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java
b/src/main/java/org/apache/sysds/hops/DataOp.java
index cfd6630f27..7b912bd39e 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -297,6 +297,7 @@ public class DataOp extends Hop {
case TEE:
l = new Tee(getInput(0).constructLops(),
getDataType(), getValueType());
+ setOutputDimensions(l);
break;
default:
@@ -488,7 +489,7 @@ public class DataOp extends Hop {
@Override
public void refreshSizeInformation() {
- if( _op == OpOpData.PERSISTENTWRITE || _op ==
OpOpData.TRANSIENTWRITE ) {
+ if( _op == OpOpData.PERSISTENTWRITE || _op ==
OpOpData.TRANSIENTWRITE || _op == OpOpData.TEE ) {
Hop input1 = getInput().get(0);
setDim1(input1.getDim1());
setDim2(input1.getDim2());
diff --git a/src/main/java/org/apache/sysds/lops/DataGen.java
b/src/main/java/org/apache/sysds/lops/DataGen.java
index d8ac1b8a4a..93237f9b7f 100644
--- a/src/main/java/org/apache/sysds/lops/DataGen.java
+++ b/src/main/java/org/apache/sysds/lops/DataGen.java
@@ -199,7 +199,7 @@ public class DataGen extends Lop
sb.append(iLop == null ? "" : iLop.prepScalarLabel());
sb.append(OPERAND_DELIMITOR);
- if( getExecType() == ExecType.CP ) {
+ if( getExecType() == ExecType.CP || getExecType() ==
ExecType.OOC ) {
//append degree of parallelism
sb.append( _numThreads );
sb.append( OPERAND_DELIMITOR );
diff --git a/src/main/java/org/apache/sysds/lops/Transform.java
b/src/main/java/org/apache/sysds/lops/Transform.java
index 0ac36a37e4..0d2e79f83a 100644
--- a/src/main/java/org/apache/sysds/lops/Transform.java
+++ b/src/main/java/org/apache/sysds/lops/Transform.java
@@ -179,7 +179,7 @@ public class Transform extends Lop
sb.append( OPERAND_DELIMITOR );
sb.append( this.prepOutputOperand(output));
- if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED)
+ if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED
|| getExecType()==ExecType.OOC)
&& (_operation == ReOrgOp.TRANS || _operation ==
ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 496bca8764..8191040eb1 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.controlprogram.caching;
import java.io.IOException;
import java.lang.ref.SoftReference;
+import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;
@@ -528,7 +529,12 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
@Override
protected MatrixBlock
readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws
IOException {
- MatrixBlock ret = new MatrixBlock((int)getNumRows(),
(int)getNumColumns(), false);
+ boolean dimsUnknown = getNumRows() < 0 || getNumColumns() < 0;
+ int nrows = (int)getNumRows();
+ int ncols = (int)getNumColumns();
+ MatrixBlock ret = dimsUnknown ? null : new
MatrixBlock((int)getNumRows(), (int)getNumColumns(), false);
+ // TODO if stream is CachingStream, block parts might be
evicted resulting in null pointer exceptions
+ List<IndexedMatrixValue> blockCache = dimsUnknown ? new
ArrayList<>() : null;
IndexedMatrixValue tmp = null;
try {
int blen = getBlocksize(), lnnz = 0;
@@ -537,12 +543,31 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
final int row_offset = (int)
(tmp.getIndexes().getRowIndex() - 1) * blen;
final int col_offset = (int)
(tmp.getIndexes().getColumnIndex() - 1) * blen;
- // Add the values of this block into the output
block.
- ((MatrixBlock)tmp.getValue()).putInto(ret,
row_offset, col_offset, true);
+ if (dimsUnknown) {
+ nrows = Math.max(nrows, row_offset +
tmp.getValue().getNumRows());
+ ncols = Math.max(ncols, col_offset +
tmp.getValue().getNumColumns());
+ blockCache.add(tmp);
+ } else {
+ // Add the values of this block into
the output block.
+ ((MatrixBlock)
tmp.getValue()).putInto(ret, row_offset, col_offset, true);
+ }
// incremental maintenance nnz
lnnz += tmp.getValue().getNonZeros();
}
+
+ if (dimsUnknown) {
+ ret = new MatrixBlock(nrows, ncols, false);
+
+ for (IndexedMatrixValue _tmp : blockCache) {
+ // compute row/column block offsets
+ final int row_offset = (int)
(_tmp.getIndexes().getRowIndex() - 1) * blen;
+ final int col_offset = (int)
(_tmp.getIndexes().getColumnIndex() - 1) * blen;
+
+ ((MatrixBlock)
_tmp.getValue()).putInto(ret, row_offset, col_offset, true);
+ }
+ }
+
ret.setNonZeros(lnnz);
}
catch(Exception ex) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index f23ad6d67a..feefe5f63d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -36,7 +36,7 @@ import
org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
import
org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
-import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
public class OOCInstructionParser extends InstructionParser {
@@ -74,7 +74,7 @@ public class OOCInstructionParser extends InstructionParser {
case MMTSJ:
return TSMMOOCInstruction.parseInstruction(str);
case Reorg:
- return
TransposeOOCInstruction.parseInstruction(str);
+ return
ReorgOOCInstruction.parseInstruction(str);
case Tee:
return TeeOOCInstruction.parseInstruction(str);
case CentralMoment:
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
index 148592b6a9..01c7a525bc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.instructions.ooc;
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -67,12 +69,55 @@ public class BinaryOOCInstruction extends
ComputationOOCInstruction {
OOCStream<IndexedMatrixValue> qOut = new
SubscribableTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);
- joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
- IndexedMatrixValue tmpOut = new IndexedMatrixValue();
- tmpOut.set(tmp1.getIndexes(),
-
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(),
tmpOut.getValue()));
- return tmpOut;
- }, IndexedMatrixValue::getIndexes);
+ if (m1.getNumRows() < 0 || m1.getNumColumns() < 0 ||
m2.getNumRows() < 0 || m2.getNumColumns() < 0)
+ throw new DMLRuntimeException("Cannot process (matrix,
matrix) BinaryOOCInstruction with unknown dimensions.");
+
+ boolean isColBroadcast = m1.getNumColumns() > 1 &&
m2.getNumColumns() == 1;
+ boolean isRowBroadcast = m1.getNumRows() > 1 && m2.getNumRows()
== 1;
+
+ if (isColBroadcast && !isRowBroadcast) {
+ final long maxProcessesPerBroadcast =
m1.getNumColumns() / m1.getBlocksize();
+
+ broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
+ IndexedMatrixValue tmpOut = new
IndexedMatrixValue();
+ tmpOut.set(tmp1.getIndexes(),
+
tmp1.getValue().binaryOperations((BinaryOperator)_optr,
b.getValue().getValue(), tmpOut.getValue()));
+
+ if (b.incrProcessCtrAndGet() >=
maxProcessesPerBroadcast)
+ b.release();
+
+ return tmpOut;
+ }, tmp -> tmp.getIndexes().getRowIndex());
+ }
+ else if (isRowBroadcast && !isColBroadcast) {
+ final long maxProcessesPerBroadcast = m1.getNumRows() /
m1.getBlocksize();
+
+ broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
+ IndexedMatrixValue tmpOut = new
IndexedMatrixValue();
+ tmpOut.set(tmp1.getIndexes(),
+
tmp1.getValue().binaryOperations((BinaryOperator)_optr,
b.getValue().getValue(), tmpOut.getValue()));
+
+ if (b.incrProcessCtrAndGet() >=
maxProcessesPerBroadcast)
+ b.release();
+
+ return tmpOut;
+ }, tmp -> tmp.getIndexes().getColumnIndex());
+ }
+ else {
+ if (m1.getNumColumns() != m2.getNumColumns() ||
m1.getNumRows() != m2.getNumRows())
+ throw new NotImplementedException("Invalid
dimensions for matrix-matrix binary op: "
+ + m1.getNumRows() + "x" +
m1.getNumColumns() + " <=> "
+ + m2.getNumRows() + "x" +
m2.getNumColumns());
+
+ joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
+ IndexedMatrixValue tmpOut = new
IndexedMatrixValue();
+ tmpOut.set(tmp1.getIndexes(),
+
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(),
tmpOut.getValue()));
+ return tmpOut;
+ }, IndexedMatrixValue::getIndexes);
+ }
+
+
}
protected void processScalarMatrixInstruction(ExecutionContext ec) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index b74f7ed5e1..d7c80e4de3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -52,6 +52,8 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
private boolean _cacheInProgress = true; // caching in progress, in the
first pass.
private Map<MatrixIndexes, Integer> _index;
+ private DMLRuntimeException _failure;
+
public CachingStream(OOCStream<IndexedMatrixValue> source) {
this(source, _streamSeq.getNextID());
}
@@ -76,6 +78,22 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
}
} catch (InterruptedException e) {
throw new DMLRuntimeException(e);
+ } catch (DMLRuntimeException e) {
+ // Propagate failure to subscribers
+ _failure = e;
+ synchronized (this) {
+ notifyAll();
+ }
+
+ Runnable[] mSubscribers = _subscribers;
+ if(mSubscribers != null) {
+ for(Runnable mSubscriber :
mSubscribers) {
+ try {
+ mSubscriber.run();
+ } catch (Exception ignored) {
+ }
+ }
+ }
}
});
}
@@ -103,7 +121,9 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
public synchronized IndexedMatrixValue get(int idx) throws
InterruptedException {
while (true) {
- if (idx < _numBlocks) {
+ if (_failure != null)
+ throw _failure;
+ else if (idx < _numBlocks) {
IndexedMatrixValue out =
OOCEvictionManager.get(_streamId, idx);
if (_index != null) // Ensure index is up to
date
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java
index 355c8ddea1..81b3bb7b38 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/DataGenOOCInstruction.java
@@ -20,8 +20,12 @@
package org.apache.sysds.runtime.instructions.ooc;
import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.DataGenOp;
+import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -30,25 +34,88 @@ import
org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.util.UtilFunctions;
public class DataGenOOCInstruction extends UnaryOOCInstruction {
+ private static final Log LOG =
LogFactory.getLog(DataGenOOCInstruction.class.getName());
+ private Types.OpOpDG method;
+ private final CPOperand rows, cols; //, dims;
private final int blen;
- private Types.OpOpDG method;
+ //private boolean minMaxAreDoubles;
+ private final String minValueStr, maxValueStr;
+ private final double minValue, maxValue, sparsity;
+ private final String pdf, pdfParams; //, frame_data, schema;
+ private final long seed;
+ private Long runtimeSeed;
// sequence specific attributes
private final CPOperand seq_from, seq_to, seq_incr;
- public DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd,
CPOperand in, CPOperand out, int blen, CPOperand seqFrom,
- CPOperand seqTo, CPOperand seqIncr, String opcode, String istr)
{
- super(OOCType.Rand, op, in, out, opcode, istr);
- this.blen = blen;
+ // sample specific attributes
+ //private final boolean replace;
+ //private final int numThreads;
+
+ // seed positions
+ private static final int SEED_POSITION_RAND = 8;
+ //private static final int SEED_POSITION_SAMPLE = 4;
+
+ private DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd,
CPOperand in, CPOperand out, CPOperand rows, CPOperand cols,
+ CPOperand dims, int blen, String minValue, String maxValue,
double sparsity, long seed,
+ String probabilityDensityFunction, String pdfParams, int k,
CPOperand seqFrom, CPOperand seqTo,
+ CPOperand seqIncr, boolean replace, String data, String schema,
String opcode, String istr) {
+ super(OOCInstruction.OOCType.Rand, op, in, out, opcode, istr);
this.method = mthd;
+ this.rows = rows;
+ this.cols = cols;
+ //this.dims = dims;
+ this.blen = blen;
+ this.minValueStr = minValue;
+ this.maxValueStr = maxValue;
+ double minDouble, maxDouble;
+ try {
+ minDouble =
!minValue.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? Double.valueOf(minValue) :
-1;
+ maxDouble =
!maxValue.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? Double.valueOf(maxValue) :
-1;
+ //minMaxAreDoubles = true;
+ }
+ catch(NumberFormatException e) {
+ // Non double values
+ if(!minValueStr.equals(maxValueStr)) {
+ throw new DMLRuntimeException(
+ "Rand instruction does not support " +
"non numeric Datatypes for range initializations.");
+ }
+ minDouble = -1;
+ maxDouble = -1;
+ //minMaxAreDoubles = false;
+ }
+ this.minValue = minDouble;
+ this.maxValue = maxDouble;
+ this.sparsity = sparsity;
+ this.seed = seed;
+ this.pdf = probabilityDensityFunction;
+ this.pdfParams = pdfParams;
+ //this.numThreads = k;
this.seq_from = seqFrom;
this.seq_to = seqTo;
this.seq_incr = seqIncr;
+ //this.replace = replace;
+ //this.frame_data = data;
+ //this.schema = schema;
+ }
+
+ private DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd,
CPOperand in, CPOperand out, CPOperand rows, CPOperand cols,
+ CPOperand dims, int blen, CPOperand seqFrom, CPOperand seqTo,
CPOperand seqIncr, String opcode, String istr) {
+ this(op, mthd, in, out, rows, cols, dims, blen, "0", "1", 1.0,
-1, null, null, 1, seqFrom, seqTo, seqIncr,
+ false, null, null, opcode, istr);
+ }
+
+ private DataGenOOCInstruction(UnaryOperator op, Types.OpOpDG mthd,
CPOperand in, CPOperand out, CPOperand rows, CPOperand cols,
+ CPOperand dims, int blen, String minValue, String maxValue,
double sparsity, long seed,
+ String probabilityDensityFunction, String pdfParams, int k,
String opcode, String istr) {
+ this(op, mthd, in, out, rows, cols, dims, blen, minValue,
maxValue, sparsity, seed, probabilityDensityFunction,
+ pdfParams, k, null, null, null, false, null, null,
opcode, istr);
}
public static DataGenOOCInstruction parseInstruction(String str) {
@@ -56,7 +123,11 @@ public class DataGenOOCInstruction extends
UnaryOOCInstruction {
String[] s =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = s[0];
- if(opcode.equalsIgnoreCase(Opcodes.SEQUENCE.toString())) {
+ if(opcode.equalsIgnoreCase(Opcodes.RANDOM.toString())) {
+ method = Types.OpOpDG.RAND;
+ InstructionUtils.checkNumFields(s, 10, 11);
+ }
+ else if(opcode.equalsIgnoreCase(Opcodes.SEQUENCE.toString())) {
method = Types.OpOpDG.SEQ;
// 8 operands: rows, cols, blen, from, to, incr, outvar
InstructionUtils.checkNumFields(s, 7);
@@ -67,13 +138,37 @@ public class DataGenOOCInstruction extends
UnaryOOCInstruction {
CPOperand out = new CPOperand(s[s.length - 1]);
UnaryOperator op = null;
- if(method == Types.OpOpDG.SEQ) {
+ if(method == Types.OpOpDG.RAND) {
+ int missing; // number of missing params (row & cols or
dims)
+ CPOperand rows = null, cols = null, dims = null;
+ if(s.length == 12) {
+ missing = 1;
+ rows = new CPOperand(s[1]);
+ cols = new CPOperand(s[2]);
+ }
+ else {
+ missing = 2;
+ dims = new CPOperand(s[1]);
+ }
+ int blen = Integer.parseInt(s[4 - missing]);
+ double sparsity = !s[7 -
missing].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? Double
+ .parseDouble(s[7 - missing]) : -1;
+ long seed = !s[SEED_POSITION_RAND -
missing].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? Long
+ .parseLong(s[SEED_POSITION_RAND - missing]) :
-1;
+ String pdf = s[9 - missing];
+ String pdfParams = !s[10 -
missing].contains(Lop.VARIABLE_NAME_PLACEHOLDER) ? s[10 - missing] : null;
+ int k = Integer.parseInt(s[11 - missing]);
+
+ return new DataGenOOCInstruction(op, method, null, out,
rows, cols, dims, blen, s[5 - missing],
+ s[6 - missing], sparsity, seed, pdf, pdfParams,
k, opcode, str);
+ }
+ else if(method == Types.OpOpDG.SEQ) {
int blen = Integer.parseInt(s[3]);
CPOperand from = new CPOperand(s[4]);
CPOperand to = new CPOperand(s[5]);
CPOperand incr = new CPOperand(s[6]);
- return new DataGenOOCInstruction(op, method, null, out,
blen, from, to, incr, opcode, str);
+ return new DataGenOOCInstruction(op, method, null, out,
null, null, null, blen, from, to, incr, opcode, str);
}
else
throw new NotImplementedException();
@@ -84,7 +179,45 @@ public class DataGenOOCInstruction extends
UnaryOOCInstruction {
final OOCStream<IndexedMatrixValue> qOut =
createWritableStream();
// process specific datagen operator
- if(method == Types.OpOpDG.SEQ) {
+ if (method == Types.OpOpDG.RAND) {
+ if (!output.isMatrix())
+ throw new NotImplementedException();
+
+ long lSeed = generateSeed();
+ long lrows = ec.getScalarInput(rows).getLongValue();
+ long lcols = ec.getScalarInput(cols).getLongValue();
+ checkValidDimensions(lrows, lcols);
+
+ if (!pdf.equalsIgnoreCase("uniform") || minValue !=
maxValue)
+ throw new NotImplementedException(); // TODO
modified version of rng as in LibMatrixDatagen to handle blocks independently
+
+ OOCStream<MatrixIndexes> qIn = createWritableStream();
+ int nrb = (int)((lrows-1) / blen)+1;
+ int ncb = (int)((lcols-1) / blen)+1;
+
+ for (int row = 0; row < nrb; row++)
+ for (int col = 0; col < ncb; col++)
+ qIn.enqueue(new MatrixIndexes(row+1,
col+1));
+
+ qIn.closeInput();
+
+ if(sparsity == 0.0 && lrows < Integer.MAX_VALUE &&
lcols < Integer.MAX_VALUE) {
+ mapOOC(qIn, qOut, idx -> {
+ long rlen = Math.min(blen, lrows -
(idx.getRowIndex()-1) * blen);
+ long clen = Math.min(blen, lcols -
(idx.getColumnIndex()-1) * blen);
+ return new IndexedMatrixValue(idx, new
MatrixBlock((int)rlen, (int)clen, 0.0));
+ });
+ return;
+ }
+
+ mapOOC(qIn, qOut, idx -> {
+ long rlen = Math.min(blen, lrows -
(idx.getRowIndex()-1) * blen);
+ long clen = Math.min(blen, lcols -
(idx.getColumnIndex()-1) * blen);
+ MatrixBlock mout =
MatrixBlock.randOperations(getGenerator(rlen, clen), lSeed);
+ return new IndexedMatrixValue(idx, mout);
+ });
+ }
+ else if(method == Types.OpOpDG.SEQ) {
double lfrom =
ec.getScalarInput(seq_from).getDoubleValue();
double lto = ec.getScalarInput(seq_to).getDoubleValue();
double lincr =
ec.getScalarInput(seq_incr).getDoubleValue();
@@ -133,4 +266,39 @@ public class DataGenOOCInstruction extends
UnaryOOCInstruction {
ec.getMatrixObject(output).setStreamHandle(qOut);
}
+
+
+
+ private long generateSeed() {
+ // generate pseudo-random seed (because not specified)
+ long lSeed = seed; // seed per invocation
+ if(lSeed == DataGenOp.UNSPECIFIED_SEED) {
+ if(runtimeSeed == null)
+ runtimeSeed = DataGenOp.generateRandomSeed();
+ lSeed = runtimeSeed;
+ }
+
+ if(LOG.isTraceEnabled())
+ LOG.trace("Process DataGenOOCInstruction rand with seed
= " + lSeed + ".");
+
+ return lSeed;
+ }
+
+ private static void checkValidDimensions(long rows, long cols) {
+ // check valid for integer dimensions (we cannot even represent
empty blocks with larger dimensions)
+ if(rows > Integer.MAX_VALUE || cols > Integer.MAX_VALUE)
+ throw new DMLRuntimeException("DataGenOOCInstruction
does not "
+ + "support dimensions larger than integer:
rows=" + rows + ", cols=" + cols + ".");
+ }
+
+ private RandomMatrixGenerator getGenerator(long lrows, long lcols) {
+ return LibMatrixDatagen.createRandomMatrixGenerator(pdf,
+ (int) lrows,
+ (int) lcols,
+ blen,
+ sparsity,
+ minValue,
+ maxValue,
+ pdfParams);
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index eb7cf55f51..ca13cfdb2c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -38,8 +38,10 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
@@ -150,6 +152,100 @@ public abstract class OOCInstruction extends Instruction {
}, qOut::closeInput);
}
+ protected <R, P> CompletableFuture<Void>
broadcastJoinOOC(OOCStream<IndexedMatrixValue> qIn,
OOCStream<IndexedMatrixValue> broadcast, OOCStream<R> qOut,
BiFunction<IndexedMatrixValue, BroadcastedElement, R> mapper,
Function<IndexedMatrixValue, P> on) {
+ addInStream(qIn, broadcast);
+ addOutStream(qOut);
+
+ boolean explicitLeftCaching = !qIn.hasStreamCache();
+ boolean explicitRightCaching = !broadcast.hasStreamCache();
+ CachingStream leftCache = explicitLeftCaching ? new
CachingStream(new SubscribableTaskQueue<>()) : qIn.getStreamCache();
+ CachingStream rightCache = explicitRightCaching ? new
CachingStream(new SubscribableTaskQueue<>()) : broadcast.getStreamCache();
+ leftCache.activateIndexing();
+ rightCache.activateIndexing();
+
+ Map<P, List<MatrixIndexes>> availableLeftInput = new
ConcurrentHashMap<>();
+ Map<P, BroadcastedElement> availableBroadcastInput = new
ConcurrentHashMap<>();
+
+ return submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> {
+ P key = on.apply(tmp);
+
+ if (i == 0) { // qIn stream
+ BroadcastedElement b =
availableBroadcastInput.get(key);
+
+ if (b == null) {
+ // Matching broadcast element is not
available -> cache element
+ if (explicitLeftCaching)
+
leftCache.getWriteStream().enqueue(tmp);
+
+ availableLeftInput.compute(key, (k, v)
-> {
+ if (v == null)
+ v = new ArrayList<>();
+ v.add(tmp.getIndexes());
+ return v;
+ });
+ } else {
+ // Directly emit
+ qOut.enqueue(mapper.apply(tmp, b));
+
+ if (b.canRelease())
+
availableBroadcastInput.remove(key);
+ }
+ } else { // broadcast stream
+ if (explicitRightCaching)
+
rightCache.getWriteStream().enqueue(tmp);
+
+ BroadcastedElement b = new
BroadcastedElement(tmp.getIndexes());
+ availableBroadcastInput.put(key, b);
+
+ List<MatrixIndexes> queued =
availableLeftInput.remove(key);
+
+ if (queued != null) {
+ for(MatrixIndexes idx : queued) {
+ b.value =
rightCache.findCached(b.idx);
+
qOut.enqueue(mapper.apply(leftCache.findCached(idx), b));
+ b.value = null;
+ }
+ }
+
+ if (b.canRelease())
+ availableBroadcastInput.remove(key);
+ }
+ }, qOut::closeInput);
+ }
+
+ protected static class BroadcastedElement {
+ private final MatrixIndexes idx;
+ private IndexedMatrixValue value;
+ private boolean release;
+ private int processCtr;
+
+ public BroadcastedElement(MatrixIndexes idx) {
+ this.idx = idx;
+ this.release = false;
+ }
+
+ public synchronized void release() {
+ release = true;
+ }
+
+ public synchronized boolean canRelease() {
+ return release;
+ }
+
+ public synchronized int incrProcessCtrAndGet() {
+ processCtr++;
+ return processCtr;
+ }
+
+ public MatrixIndexes getIndex() {
+ return idx;
+ }
+
+ public IndexedMatrixValue getValue() {
+ return value;
+ }
+ };
+
protected <T, R, P> CompletableFuture<Void> joinOOC(OOCStream<T> qIn1,
OOCStream<T> qIn2, OOCStream<R> qOut, BiFunction<T, T, R> mapper, Function<T,
P> on) {
return joinOOC(qIn1, qIn2, qOut, mapper, on, on);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java
new file mode 100644
index 0000000000..a87a349832
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.functionobjects.SortIndex;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.runtime.util.DataConverter;
+
+public class ReorgOOCInstruction extends ComputationOOCInstruction {
+ // sort-specific attributes (to enable variable attributes)
+ private final CPOperand _col;
+ private final CPOperand _desc;
+ private final CPOperand _ixret;
+
+ protected ReorgOOCInstruction(ReorgOperator op, CPOperand in1,
CPOperand out, String opcode, String istr) {
+ this(op, in1, out, null, null, null, opcode, istr);
+ }
+
+ private ReorgOOCInstruction(Operator op, CPOperand in, CPOperand out,
CPOperand col, CPOperand desc, CPOperand ixret,
+ String opcode, String istr) {
+ super(OOCType.Reorg, op, in, out, opcode, istr);
+ _col = col;
+ _desc = desc;
+ _ixret = ixret;
+ }
+
+ public static ReorgOOCInstruction parseInstruction(String str) {
+ CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
+ CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
+
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+
+ if (opcode.equalsIgnoreCase(Opcodes.TRANSPOSE.toString())) {
+ InstructionUtils.checkNumFields(str, 2, 3);
+ in.split(parts[1]);
+ out.split(parts[2]);
+
+ ReorgOperator reorg = new
ReorgOperator(SwapIndex.getSwapIndexFnObject());
+ return new ReorgOOCInstruction(reorg, in, out, opcode,
str);
+ }
+ else if (opcode.equalsIgnoreCase(Opcodes.SORT.toString())) {
+ InstructionUtils.checkNumFields(str, 5,6);
+ in.split(parts[1]);
+ out.split(parts[5]);
+ CPOperand col = new CPOperand(parts[2]);
+ CPOperand desc = new CPOperand(parts[3]);
+ CPOperand ixret = new CPOperand(parts[4]);
+ int k = Integer.parseInt(parts[6]);
+ return new ReorgOOCInstruction(new ReorgOperator(new
SortIndex(1,false,false), k),
+ in, out, col, desc, ixret, opcode, str);
+ }
+ else
+ throw new NotImplementedException();
+ }
+
+ public void processInstruction( ExecutionContext ec ) {
+ // Process the reorg operation (sort or transpose) on the input matrix
+ MatrixObject min = ec.getMatrixObject(input1);
+
+
+ ReorgOperator r_op = (ReorgOperator) _optr;
+
+ if(r_op.fn instanceof SortIndex) {
+ //additional attributes for sort
+ int[] cols = _col.getDataType().isMatrix() ?
DataConverter.convertToIntVector(ec.getMatrixInput(_col.getName())) :
+ new
int[]{(int)ec.getScalarInput(_col).getLongValue()};
+ boolean desc =
ec.getScalarInput(_desc).getBooleanValue();
+ boolean ixret =
ec.getScalarInput(_ixret).getBooleanValue();
+ r_op = r_op.setFn(new SortIndex(cols, desc, ixret));
+
+ // For now, we reuse the CP instruction
+ // In future, we could optimize by building the
permutation and streaming blocks column by column
+ MatrixBlock matBlock = min.acquireRead();
+ MatrixBlock soresBlock = matBlock.reorgOperations(r_op,
new MatrixBlock(), 0, 0, 0);
+ if (_col.getDataType().isMatrix())
+ ec.releaseMatrixInput(_col.getName());
+ ec.releaseMatrixInput(input1.getName());
+ ec.setMatrixOutput(output.getName(), soresBlock);
+ } else if (r_op.fn instanceof SwapIndex) {
+ OOCStream<IndexedMatrixValue> qIn =
min.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qOut =
createWritableStream();
+ ec.getMatrixObject(output).setStreamHandle(qOut);
+ // Transpose operation
+ mapOOC(qIn, qOut, tmp -> {
+ MatrixBlock inBlock = (MatrixBlock)
tmp.getValue();
+ long oldRowIdx = tmp.getIndexes().getRowIndex();
+ long oldColIdx =
tmp.getIndexes().getColumnIndex();
+
+ MatrixBlock outBlock =
inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1);
+ return new IndexedMatrixValue(new
MatrixIndexes(oldColIdx, oldRowIdx), outBlock);
+ });
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
index 9040c369a2..e020794040 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.instructions.ooc;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
@@ -76,19 +75,21 @@ public class TSMMOOCInstruction extends
ComputationOOCInstruction {
throw new UnsupportedOperationException();
}
- int dim = _type.isLeft() ? nCols : nRows;
- MatrixBlock resultBlock = new MatrixBlock(dim, dim, false);
- try {
- IndexedMatrixValue tmp = null;
- // aggregate partial tsmm outputs into result as inputs stream in
- while((tmp = qIn.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS) {
- MatrixBlock partialResult = ((MatrixBlock)
tmp.getValue())
- .transposeSelfMatrixMultOperations(new
MatrixBlock(), _type);
- resultBlock.binaryOperationsInPlace(plus,
partialResult);
- }
- }
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
+ //int dim = _type.isLeft() ? nCols : nRows;
+ MatrixBlock resultBlock = null;
+
+ OOCStream<MatrixBlock> tmpStream = createWritableStream();
+
+ mapOOC(qIn, tmpStream,
+ tmp -> ((MatrixBlock) tmp.getValue())
+ .transposeSelfMatrixMultOperations(new
MatrixBlock(), _type));
+
+ MatrixBlock tmp;
+ while ((tmp = tmpStream.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ if (resultBlock == null)
+ resultBlock = tmp;
+ else
+ resultBlock.binaryOperationsInPlace(plus, tmp);
}
ec.setMatrixOutput(output.getName(), resultBlock);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
deleted file mode 100644
index 6558145ec2..0000000000
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.instructions.ooc;
-
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.functionobjects.SwapIndex;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
-
-public class TransposeOOCInstruction extends ComputationOOCInstruction {
-
- protected TransposeOOCInstruction(OOCType type, ReorgOperator op,
CPOperand in1, CPOperand out, String opcode, String istr) {
- super(type, op, in1, out, opcode, istr);
-
- }
-
- public static TransposeOOCInstruction parseInstruction(String str) {
- String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
- InstructionUtils.checkNumFields(parts, 2);
- String opcode = parts[0];
- CPOperand in1 = new CPOperand(parts[1]);
- CPOperand out = new CPOperand(parts[2]);
-
- ReorgOperator reorg = new
ReorgOperator(SwapIndex.getSwapIndexFnObject());
- return new TransposeOOCInstruction(OOCType.Reorg, reorg, in1,
out, opcode, str);
- }
-
- public void processInstruction( ExecutionContext ec ) {
-
- // Create thread and process the transpose operation
- MatrixObject min = ec.getMatrixObject(input1);
- OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
- OOCStream<IndexedMatrixValue> qOut = createWritableStream();
- ec.getMatrixObject(output).setStreamHandle(qOut);
-
- mapOOC(qIn, qOut, tmp -> {
- MatrixBlock inBlock = (MatrixBlock) tmp.getValue();
- long oldRowIdx = tmp.getIndexes().getRowIndex();
- long oldColIdx = tmp.getIndexes().getColumnIndex();
-
- MatrixBlock outBlock =
inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1);
- return new IndexedMatrixValue(new
MatrixIndexes(oldColIdx, oldRowIdx), outBlock);
- });
- }
-}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java
new file mode 100644
index 0000000000..40430aa49f
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/RandTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class RandTest extends AutomatedTestBase {
+ private final static String TEST_NAME_1 = "Rand1";
+ private final static String TEST_NAME_2 = "Rand2";
+ private final static String TEST_DIR = "functions/ooc/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
RandTest.class.getSimpleName() + "/";
+ private final static double eps = 1e-8;
+ private static final String INPUT_NAME_1 = "X";
+ private static final String OUTPUT_NAME = "res";
+
+ private final static int rows = 1500;
+ private final static int cols = 1200;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_1);
+ addTestConfiguration(TEST_NAME_1, config);
+ TestConfiguration config2 = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2);
+ addTestConfiguration(TEST_NAME_2, config2);
+ }
+
+ // Actual rand operation not yet supported
+ /*@Test
+ public void testRand() {
+ runRandTest(TEST_NAME_1);
+ }*/
+
+ @Test
+ public void testConstInit() {
+ runRandTest(TEST_NAME_2);
+ }
+
+ private void runRandTest(String TEST_NAME) {
+ Types.ExecMode platformOld =
setExecMode(Types.ExecMode.SINGLE_NODE);
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-explain", "-stats",
"-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)};
+
+ runTest(true, false, null, -1);
+
+ //check rand OOC op
+ Assert.assertTrue("OOC wasn't used for rand",
+
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.RANDOM));
+
+ //compare results
+
+ // rerun without ooc flag
+ programArgs = new String[] {"-explain", "-stats",
"-args", input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")};
+ runTest(true, false, null, -1);
+
+ // compare matrices
+ MatrixBlock ret1 =
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+ Types.FileFormat.BINARY, rows, cols, 1000);
+ MatrixBlock ret2 =
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+ Types.FileFormat.BINARY, rows, cols, 1000);
+ TestUtils.compareMatrices(ret1, ret2, eps);
+ }
+ catch(IOException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SortTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/SortTest.java
new file mode 100644
index 0000000000..61fc35a0b5
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/SortTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class SortTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "Sort";
+ private final static String TEST_DIR = "functions/ooc/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
SortTest.class.getSimpleName() + "/";
+ private final static double eps = 1e-8;
+ private static final String INPUT_NAME_1 = "X";
+ private static final String OUTPUT_NAME = "res";
+
+ private final static int maxVal = 7;
+ private final static double sparsity1 = 1;
+ private final static double sparsity2 = 0.05;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+ addTestConfiguration(TEST_NAME1, config);
+ }
+
+ @Test
+ public void testSortDenseMatrix() {
+ runSortTest(1500, 800, false);
+ }
+
+ @Test
+ public void testSortSparseMatrix() {
+ runSortTest(1500, 800, true);
+ }
+
+ @Test
+ public void testSortDenseVector() {
+ runSortTest(1500, 1, false);
+ }
+
+ @Test
+ public void testSortSparseVector() {
+ runSortTest(1500, 1, true);
+ }
+
+ private void runSortTest(int rows, int cols, boolean sparse) {
+ Types.ExecMode platformOld =
setExecMode(Types.ExecMode.SINGLE_NODE);
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME1);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ programArgs = new String[] {"-explain", "-stats",
"-ooc", "-args", input(INPUT_NAME_1), output(OUTPUT_NAME)};
+
+ // 1. Generate the data in-memory as a double array
+ double[][] X_data = getRandomMatrix(rows, cols, 1,
maxVal, sparse ? sparsity2 : sparsity1, 7);
+
+ // 2. Convert the double array to a MatrixBlock object
+ MatrixBlock X_mb =
DataConverter.convertToMatrixBlock(X_data);
+
+ // 3. Create a binary matrix writer
+ MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+ // 4. Write matrix X to a binary SequenceFile, followed by its metadata file
+ writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1),
rows, cols, 1000, X_mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 +
".mtd"), Types.ValueType.FP64,
+ new MatrixCharacteristics(rows, cols, 1000,
X_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+ runTest(true, false, null, -1);
+
+ //check sort OOC
+ Assert.assertTrue("OOC wasn't used for sort",
+
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.SORT));
+
+ //compare results
+
+ // rerun without ooc flag
+ programArgs = new String[] {"-explain", "-stats",
"-args", input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")};
+ runTest(true, false, null, -1);
+
+ // compare matrices
+ MatrixBlock ret1 =
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+ Types.FileFormat.BINARY, rows, cols, 1000);
+ MatrixBlock ret2 =
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+ Types.FileFormat.BINARY, rows, cols, 1000);
+ TestUtils.compareMatrices(ret1, ret2, eps);
+ }
+ catch(IOException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/ooc/StreamCollectTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/StreamCollectTest.java
index 877a0dd0b1..b1f35e4a50 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/StreamCollectTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/StreamCollectTest.java
@@ -26,7 +26,7 @@ import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
-import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.io.MatrixWriterFactory;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -86,7 +86,7 @@ public class StreamCollectTest extends AutomatedTestBase {
VariableCPInstruction createOut =
VariableCPInstruction.parseInstruction(
"CP°createvar°_mVar1°" + input("tmp1") +
"°true°MATRIX°binary°" + rows + "°" + cols +
"°1000°700147°copy");
- TransposeOOCInstruction oocTranspose =
TransposeOOCInstruction.parseInstruction(
+ ReorgOOCInstruction oocTranspose =
ReorgOOCInstruction.parseInstruction(
"OOC°r'°_mVar0·MATRIX·FP64°_mVar1·MATRIX·FP64");
VariableCPInstruction createOut2 =
VariableCPInstruction.parseInstruction(
"CP°createvar°_mVar2°" + input("tmp2") +
"°true°MATRIX°binary°" + rows + "°" + cols +
diff --git a/src/test/scripts/functions/ooc/Rand1.dml
b/src/test/scripts/functions/ooc/Rand1.dml
new file mode 100644
index 0000000000..2861f29462
--- /dev/null
+++ b/src/test/scripts/functions/ooc/Rand1.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+res = rand(rows=1500, cols=1200, min=-1, max=1);
+
+write(res, $2, format="binary");
diff --git a/src/test/scripts/functions/ooc/Rand2.dml
b/src/test/scripts/functions/ooc/Rand2.dml
new file mode 100644
index 0000000000..033632edb0
--- /dev/null
+++ b/src/test/scripts/functions/ooc/Rand2.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+res = matrix(1, rows=1500, cols=1200);
+
+write(res, $2, format="binary");
diff --git a/src/test/scripts/functions/ooc/Sort.dml
b/src/test/scripts/functions/ooc/Sort.dml
new file mode 100644
index 0000000000..30ccfd03e9
--- /dev/null
+++ b/src/test/scripts/functions/ooc/Sort.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the input matrix as a stream
+X = read($1);
+
+res = order(target=X, by=1);
+
+write(res, $2, format="binary");