This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new f29ae42 [SYSTEMDS-339] Fix robustness lineage tracing/parsing, part II
f29ae42 is described below
commit f29ae426be1722fba9468609976709068e6e5d7d
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri May 22 22:38:54 2020 +0200
[SYSTEMDS-339] Fix robustness lineage tracing/parsing, part II
This patch fixes many additional issues in lineage tracing and parsing
in order to support the round-trip for steplm and kmeans.
1) Lineage tracing with default arguments of function call parameters
(so far, missing arguments were traced as literal variable names)
2) Lineage Tracing: rshape with parameters, ctable w/ dimensions,
rand/seq w/ variable rows/cols, from/to/incr inputs
3) Lineage Parsing: rshape, rdiag, nrow, ncol, all cast ops, ifelse
with scalar/matrix inputs (so far the block size was wrong), ctable w/
dimensions, gappend spark ops
4) New lineage parfor algorithm tests: steplm, kmeans
---
scripts/builtin/kmeans.dml | 6 +--
src/main/java/org/apache/sysds/common/Types.java | 35 ++++++++++---
.../apache/sysds/hops/recompile/Recompiler.java | 8 +--
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 37 ++++++++------
.../RewriteAlgebraicSimplificationDynamic.java | 6 +--
.../runtime/instructions/InstructionUtils.java | 41 ++-------------
.../instructions/cp/CtableCPInstruction.java | 27 ++++++----
.../instructions/cp/DataGenCPInstruction.java | 56 ++++++++++++++------
.../instructions/cp/FunctionCallCPInstruction.java | 7 ++-
.../instructions/cp/ReshapeCPInstruction.java | 9 ++++
.../instructions/spark/RandSPInstruction.java | 10 +++-
.../sysds/runtime/lineage/LineageItemUtils.java | 59 +++++++++++++++++-----
.../functions/lineage/LineageTraceParforTest.java | 34 ++++++++-----
...aceParfor4.dml => LineageTraceParforKmeans.dml} | 3 +-
...aceParfor4.dml => LineageTraceParforSteplm.dml} | 0
15 files changed, 211 insertions(+), 127 deletions(-)
diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index 23482da..96591c6 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -60,8 +60,8 @@ m_kmeans = function(Matrix[Double] X, Integer k = 0, Integer
runs = 10, Integer
print ("Taking data samples for initialization...");
- [sample_maps, samples_vs_runs_map, sample_block_size] =
- get_sample_maps (num_records, num_runs, num_centroids *
avg_sample_size_per_centroid);
+ [sample_maps, samples_vs_runs_map, sample_block_size] = get_sample_maps(
+ num_records, num_runs, num_centroids * avg_sample_size_per_centroid);
is_row_in_samples = rowSums (sample_maps);
X_samples = sample_maps %*% X;
@@ -230,7 +230,7 @@ get_sample_maps = function (int num_records, int
num_samples, int approx_sample_
# Replace all sample record ids over "num_records" (i.e. out of range) by
"num_records + 1":
is_sample_rec_id_within_range = (sample_rec_ids <= num_records);
sample_rec_ids = sample_rec_ids * is_sample_rec_id_within_range
- + (num_records + 1) * (1 - is_sample_rec_id_within_range);
+ + (num_records + 1) * (1 - is_sample_rec_id_within_range);
# Rearrange all samples (and their out-of-range indicators) into one
column-vector:
sample_rec_ids = matrix (sample_rec_ids, rows = num_rows, cols = 1, byrow
= FALSE);
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index d693b7f..2d66e81 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -206,6 +206,15 @@ public class Types
MULT2, MINUS1_MULT, MINUS_RIGHT,
POW2, SUBTRACT_NZ;
+
+ public boolean isScalarOutput() {
+ return this == CAST_AS_SCALAR
+ || this == NROW || this == NCOL
+ || this == LENGTH || this == EXISTS
+ || this == IQM || this == LINEAGE
+ || this == MEDIAN;
+ }
+
@Override
public String toString() {
switch(this) {
@@ -244,7 +253,7 @@ public class Types
case "ucumk+": return CUMSUM;
case "ucumk+*": return CUMSUMPROD;
case "*2": return MULT2;
- case "!": return OpOp1.NOT;
+ case "!": return NOT;
case "^2": return POW2;
default: return
valueOf(opcode.toUpperCase());
}
@@ -354,12 +363,12 @@ public class Types
}
}
- public static OpOp3 valueOfCode(String code) {
- switch(code) {
- case "cm": return OpOp3.MOMENT;
- case "+*": return OpOp3.PLUS_MULT;
- case "-*": return OpOp3.MINUS_MULT;
- default: return
OpOp3.valueOf(code.toUpperCase());
+ public static OpOp3 valueOfByOpcode(String opcode) {
+ switch(opcode) {
+ case "cm": return MOMENT;
+ case "+*": return PLUS_MULT;
+ case "-*": return MINUS_MULT;
+ default: return valueOf(opcode.toUpperCase());
}
}
}
@@ -394,11 +403,21 @@ public class Types
@Override
public String toString() {
switch(this) {
- case TRANS: return "t";
+ case DIAG: return "rdiag";
+ case TRANS: return "r'";
case RESHAPE: return "rshape";
default: return name().toLowerCase();
}
}
+
+ public static ReOrgOp valueOfByOpcode(String opcode) {
+ switch(opcode) {
+ case "rdiag": return DIAG;
+ case "r'": return TRANS;
+ case "rshape": return RESHAPE;
+ default: return
valueOf(opcode.toUpperCase());
+ }
+ }
}
public enum ParamBuiltinOp {
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index 29e0c0e..4334384 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -1448,15 +1448,11 @@ public class Recompiler
* @param scalarsOnly if true, replace only scalar variables but no
matrix operations;
* if false, apply full literal replacement
*/
- public static void rReplaceLiterals( Hop hop, ExecutionContext ec,
boolean scalarsOnly )
- {
- //public interface
+ public static void rReplaceLiterals( Hop hop, ExecutionContext ec,
boolean scalarsOnly ) {
LiteralReplacement.rReplaceLiterals(hop, ec, scalarsOnly);
}
- public static void rReplaceLiterals( Hop hop, LocalVariableMap vars,
boolean scalarsOnly )
- {
- //public interface
+ public static void rReplaceLiterals( Hop hop, LocalVariableMap vars,
boolean scalarsOnly ) {
LiteralReplacement.rReplaceLiterals(hop, new
ExecutionContext(vars), scalarsOnly);
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 9e73fcc..30f66f4 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -583,6 +583,10 @@ public class HopRewriteUtils
return createReorg(input, ReOrgOp.TRANS);
}
+ public static ReorgOp createReorg(Hop input, String rop) {
+ return createReorg(input, ReOrgOp.valueOfByOpcode(rop));
+ }
+
public static ReorgOp createReorg(Hop input, ReOrgOp rop) {
ReorgOp reorg = new ReorgOp(input.getName(),
input.getDataType(), input.getValueType(), rop, input);
reorg.setBlocksize(input.getBlocksize());
@@ -604,22 +608,19 @@ public class HopRewriteUtils
return createUnary(input, OpOp1.valueOfByOpcode(type));
}
- public static UnaryOp createUnary(Hop input, OpOp1 type)
- {
- DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR :
+ public static UnaryOp createUnary(Hop input, OpOp1 type) {
+ DataType dt = type.isScalarOutput() ? DataType.SCALAR :
(type==OpOp1.CAST_AS_MATRIX) ? DataType.MATRIX :
input.getDataType();
ValueType vt = (type==OpOp1.CAST_AS_MATRIX) ? ValueType.FP64 :
input.getValueType();
UnaryOp unary = new UnaryOp(input.getName(), dt, vt, type,
input);
unary.setBlocksize(input.getBlocksize());
- if( type == OpOp1.CAST_AS_SCALAR || type ==
OpOp1.CAST_AS_MATRIX ) {
- int dim = (type==OpOp1.CAST_AS_SCALAR) ? 0 : 1;
+ if( type.isScalarOutput() || type == OpOp1.CAST_AS_MATRIX ) {
+ int dim = type.isScalarOutput() ? 0 : 1;
int blksz = (type==OpOp1.CAST_AS_SCALAR) ? 0 :
ConfigurationManager.getBlocksize();
setOutputParameters(unary, dim, dim, blksz, -1);
}
-
copyLineNumbers(input, unary);
- unary.refreshSizeInformation();
-
+ unary.refreshSizeInformation();
return unary;
}
@@ -681,7 +682,6 @@ public class HopRewriteUtils
mmult.setBlocksize(left.getBlocksize());
copyLineNumbers(left, mmult);
mmult.refreshSizeInformation();
-
return mmult;
}
@@ -690,7 +690,6 @@ public class HopRewriteUtils
pbop.setBlocksize(input.getBlocksize());
copyLineNumbers(input, pbop);
pbop.refreshSizeInformation();
-
return pbop;
}
@@ -774,23 +773,29 @@ public class HopRewriteUtils
return datagen;
}
- public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop
mright, String opcode) {
- return createTernaryOp(mleft, smid, mright,
OpOp3.valueOfCode(opcode));
+ public static TernaryOp createTernary(Hop mleft, Hop smid, Hop mright,
String opcode) {
+ return createTernary(mleft, smid, mright,
OpOp3.valueOfByOpcode(opcode));
}
- public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop
mright, OpOp3 op) {
+ public static TernaryOp createTernary(Hop mleft, Hop smid, Hop mright,
OpOp3 op) {
//NOTe: for ifelse it's sufficient to check mright as
smid==mright
- System.out.println(mleft.getDataType()+" "+smid.getDataType()+"
"+mright.getDataType());
DataType dt = (op == OpOp3.IFELSE) ? mright.getDataType() :
DataType.MATRIX;
ValueType vt = (op == OpOp3.IFELSE) ? mright.getValueType() :
ValueType.FP64;
TernaryOp ternOp = new TernaryOp("tmp", dt, vt, op, mleft,
smid, mright);
- if( dt == DataType.MATRIX )
- ternOp.setBlocksize(mleft.getBlocksize());
+ ternOp.setBlocksize(Math.max(mleft.getBlocksize(),
mright.getBlocksize()));
copyLineNumbers(mleft, ternOp);
ternOp.refreshSizeInformation();
return ternOp;
}
+ public static TernaryOp createTernary(Hop in1, Hop in2, Hop in3, Hop
in4, Hop in5, OpOp3 op) {
+ TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX,
ValueType.FP64, op, in1, in2, in3, in4, in5);
+ ternOp.setBlocksize(Math.max(in1.getBlocksize(),
in2.getBlocksize()));
+ copyLineNumbers(in1, ternOp);
+ ternOp.refreshSizeInformation();
+ return ternOp;
+ }
+
public static Hop createComputeNnz(Hop input) {
//nnz = sum(A != 0) -> later rewritten to meta-data operation
return createSum(createBinary(input, new LiteralOp(0),
OpOp2.NOTEQUAL));
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index e929e0a..1929315 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2285,7 +2285,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
Hop smid = right.getInput().get(
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
Hop mright = right.getInput().get(
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
ternop = (smid instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ?
- left :
HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.PLUS_MULT);
+ left :
HopRewriteUtils.createTernary(left, smid, mright, OpOp3.PLUS_MULT);
LOG.debug("Applied
fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")");
}
//pattern (b) s*Y + X -> X +* sY
@@ -2297,7 +2297,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
Hop smid = left.getInput().get(
(left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
Hop mright = left.getInput().get(
(left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
ternop = (smid instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ?
- right :
HopRewriteUtils.createTernaryOp(right, smid, mright, OpOp3.PLUS_MULT);
+ right :
HopRewriteUtils.createTernary(right, smid, mright, OpOp3.PLUS_MULT);
LOG.debug("Applied
fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")");
}
//pattern (c) X - s*Y -> X -* sY
@@ -2309,7 +2309,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
Hop smid = right.getInput().get(
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
Hop mright = right.getInput().get(
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
ternop = (smid instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ?
- left :
HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.MINUS_MULT);
+ left :
HopRewriteUtils.createTernary(left, smid, mright, OpOp3.MINUS_MULT);
LOG.debug("Applied
fuseAxpyBinaryOperationChain3. (line " +hi.getBeginLine()+")");
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index f1c8dc6..740d821 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -24,14 +24,7 @@ import java.util.StringTokenizer;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.common.Types.Direction;
-import org.apache.sysds.lops.AppendM;
-import org.apache.sysds.lops.BinaryM;
-import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
-import org.apache.sysds.lops.MapMult;
-import org.apache.sysds.lops.MapMultChain;
-import org.apache.sysds.lops.PMMJ;
-import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropyR;
import org.apache.sysds.lops.WeightedDivMM;
@@ -239,36 +232,12 @@ public class InstructionUtils
Builtin.BuiltinCode bfc =
Builtin.String2BuiltinCode.get(opcode);
return (bfc != null);
}
-
- /**
- * Evaluates if at least one instruction of the given instruction set
- * used the distributed cache; this call can also be used for individual
- * instructions.
- *
- * @param str instruction set
- * @return true if at least one instruction uses distributed cache
- */
- public static boolean isDistributedCacheUsed(String str)
- {
- String[] parts = str.split(Instruction.INSTRUCTION_DELIM);
- for(String inst : parts)
- {
- String opcode = getOpCode(inst);
- if( opcode.equalsIgnoreCase(AppendM.OPCODE)
- || opcode.equalsIgnoreCase(MapMult.OPCODE)
- || opcode.equalsIgnoreCase(MapMultChain.OPCODE)
- || opcode.equalsIgnoreCase(PMMJ.OPCODE)
- || opcode.equalsIgnoreCase(UAggOuterChain.OPCODE)
- || opcode.equalsIgnoreCase(GroupedAggregateM.OPCODE)
- || isDistQuaternaryOpcode( opcode ) //multiple
quaternary opcodes
- || BinaryM.isOpcode( opcode ) ) //multiple binary
opcodes
- {
- return true;
- }
- }
- return false;
+
+ public static boolean isUnaryMetadata(String opcode) {
+ return opcode != null
+ && (opcode.equals("nrow") || opcode.equals("ncol"));
}
-
+
public static AggregateUnaryOperator
parseBasicAggregateUnaryOperator(String opcode) {
return parseBasicAggregateUnaryOperator(opcode, 1);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
index 77625f4..4869e3e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
@@ -20,22 +20,23 @@
package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.lops.Ctable;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.CTableMap;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.LongLongDoubleHashMap.EntryType;
public class CtableCPInstruction extends ComputationCPInstruction {
- private final String _outDim1;
- private final String _outDim2;
- private final boolean _dim1Literal;
- private final boolean _dim2Literal;
+ private final CPOperand _outDim1;
+ private final CPOperand _outDim2;
private final boolean _isExpand;
private final boolean _ignoreZeros;
@@ -43,10 +44,8 @@ public class CtableCPInstruction extends
ComputationCPInstruction {
String outputDim1, boolean dim1Literal, String
outputDim2, boolean dim2Literal, boolean isExpand,
boolean ignoreZeros, String opcode, String istr) {
super(CPType.Ctable, null, in1, in2, in3, out, opcode, istr);
- _outDim1 = outputDim1;
- _dim1Literal = dim1Literal;
- _outDim2 = outputDim2;
- _dim2Literal = dim2Literal;
+ _outDim1 = new CPOperand(outputDim1, ValueType.FP64,
DataType.SCALAR, dim1Literal);
+ _outDim2 = new CPOperand(outputDim2, ValueType.FP64,
DataType.SCALAR, dim2Literal);
_isExpand = isExpand;
_ignoreZeros = ignoreZeros;
}
@@ -98,8 +97,8 @@ public class CtableCPInstruction extends
ComputationCPInstruction {
Ctable.OperationTypes ctableOp = findCtableOperation();
ctableOp = _isExpand ?
Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp;
- long outputDim1 = (_dim1Literal ? (long)
Double.parseDouble(_outDim1) : (ec.getScalarInput(_outDim1, ValueType.FP64,
false)).getLongValue());
- long outputDim2 = (_dim2Literal ? (long)
Double.parseDouble(_outDim2) : (ec.getScalarInput(_outDim2, ValueType.FP64,
false)).getLongValue());
+ long outputDim1 = ec.getScalarInput(_outDim1).getLongValue();
+ long outputDim2 = ec.getScalarInput(_outDim2).getLongValue();
boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 !=
-1);
if ( outputDimsKnown ) {
@@ -178,4 +177,12 @@ public class CtableCPInstruction extends
ComputationCPInstruction {
ec.setMatrixOutput(output.getName(), resultBlock);
}
+
+ @Override
+ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+ LineageItem[] linputs = !(_outDim1.getName().equals("-1") &&
_outDim2.getName().equals("-1")) ?
+ LineageItemUtils.getLineage(ec, input1, input2, input3,
_outDim1, _outDim2) :
+ LineageItemUtils.getLineage(ec, input1, input2, input3);
+ return Pair.of(output.getName(), new LineageItem(getOpcode(),
linputs));
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index 11f4e8e..4aa9660 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -130,14 +130,16 @@ public class DataGenCPInstruction extends
UnaryCPInstruction {
}
public long getRows() {
- return rows.isLiteral() ? Long.parseLong(rows.getName()) : -1;
+ return rows.isLiteral() ?
UtilFunctions.parseToLong(rows.getName()) : -1;
}
public long getCols() {
- return cols.isLiteral() ? Long.parseLong(cols.getName()) : -1;
+ return cols.isLiteral() ?
UtilFunctions.parseToLong(cols.getName()) : -1;
}
- public String getDims() { return dims.getName(); }
+ public String getDims() {
+ return dims.getName();
+ }
public int getBlocksize() {
return blocksize;
@@ -172,15 +174,15 @@ public class DataGenCPInstruction extends
UnaryCPInstruction {
}
public long getFrom() {
- return seq_from.isLiteral() ?
Long.parseLong(seq_from.getName()) : -1;
+ return seq_from.isLiteral() ?
UtilFunctions.parseToLong(seq_from.getName()) : -1;
}
public long getTo() {
- return seq_to.isLiteral() ? Long.parseLong(seq_to.getName()) :
-1;
+ return seq_to.isLiteral() ?
UtilFunctions.parseToLong(seq_to.getName()) : -1;
}
public long getIncr() {
- return seq_incr.isLiteral() ?
Long.parseLong(seq_incr.getName()) : -1;
+ return seq_incr.isLiteral() ?
UtilFunctions.parseToLong(seq_incr.getName()) : -1;
}
public static DataGenCPInstruction parseInstruction(String str)
@@ -385,16 +387,40 @@ public class DataGenCPInstruction extends
UnaryCPInstruction {
@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
String tmpInstStr = instString;
- if (getSeed() == DataGenOp.UNSPECIFIED_SEED) {
- //generate pseudo-random seed (because not specified)
- if (runtimeSeed == null)
- runtimeSeed = (minValue == maxValue && sparsity
== 1) ?
- DataGenOp.UNSPECIFIED_SEED :
DataGenOp.generateRandomSeed();
- int position = (method == OpOpDG.RAND) ?
SEED_POSITION_RAND :
- (method == OpOpDG.SAMPLE) ?
SEED_POSITION_SAMPLE : 0;
- tmpInstStr = position != 0 ?
InstructionUtils.replaceOperand(
- tmpInstStr, position,
String.valueOf(runtimeSeed)) : tmpInstStr;
+
+ switch(method) {
+ case RAND:
+ case SAMPLE: {
+ if (getSeed() == DataGenOp.UNSPECIFIED_SEED) {
+ //generate pseudo-random seed (because
not specified)
+ if (runtimeSeed == null)
+ runtimeSeed = (minValue ==
maxValue && sparsity == 1) ?
+
DataGenOp.UNSPECIFIED_SEED : DataGenOp.generateRandomSeed();
+ int position = (method == OpOpDG.RAND)
? SEED_POSITION_RAND :
+ (method == OpOpDG.SAMPLE) ?
SEED_POSITION_SAMPLE : 0;
+ tmpInstStr = position != 0 ?
InstructionUtils.replaceOperand(
+ tmpInstStr, position,
String.valueOf(runtimeSeed)) : tmpInstStr;
+ }
+ tmpInstStr = replaceNonLiteral(tmpInstStr,
rows, 2, ec);
+ tmpInstStr = replaceNonLiteral(tmpInstStr,
cols, 3, ec);
+ break;
+ }
+ case SEQ: {
+ tmpInstStr = replaceNonLiteral(tmpInstStr,
seq_from, 5, ec);
+ tmpInstStr = replaceNonLiteral(tmpInstStr,
seq_to, 6, ec);
+ tmpInstStr = replaceNonLiteral(tmpInstStr,
seq_incr, 7, ec);
+ break;
+ }
+ default:
+ throw new DMLRuntimeException("Unsupported
datagen op: "+method);
}
return Pair.of(output.getName(), new LineageItem(tmpInstStr,
getOpcode()));
}
+
+ private static String replaceNonLiteral(String inst, CPOperand op, int
pos, ExecutionContext ec) {
+ if( !op.isLiteral() )
+ inst = InstructionUtils.replaceOperand(inst, pos,
+ new
CPOperand(ec.getScalarInput(op)).getLineageLiteral());
+ return inst;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 0f0951e..f00f42d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -155,8 +155,11 @@ public class FunctionCallCPInstruction extends
CPInstruction {
functionVariables.put(currFormalParam.getName(), value);
//map lineage to function arguments
- if( lineage != null )
- lineage.set(currFormalParam.getName(),
ec.getLineageItem(input));
+ if( lineage != null ) {
+ LineageItem litem = ec.getLineageItem(input);
+ lineage.set(currFormalParam.getName(),
(litem!=null) ?
+ litem :
ec.getLineage().getOrCreate(input));
+ }
}
// Pin the input variables so that they do not get deleted
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java
index 6262c50..8a3c001 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.instructions.cp;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -26,6 +27,8 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.data.LibTensorReorg;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -104,4 +107,10 @@ public class ReshapeCPInstruction extends
UnaryCPInstruction {
ec.releaseMatrixInput(input1.getName());
}
}
+
+ @Override
+ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+ return Pair.of(output.getName(), new LineageItem(getOpcode(),
+ LineageItemUtils.getLineage(ec, input1, _opRows,
_opCols, _opDims, _opByRow)));
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
index a2058b7..ef40773 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
@@ -164,11 +164,11 @@ public class RandSPInstruction extends UnarySPInstruction
{
}
public long getRows() {
- return rows.isLiteral() ? Long.parseLong(rows.getName()) : -1;
+ return rows.isLiteral() ?
UtilFunctions.parseToLong(rows.getName()) : -1;
}
public long getCols() {
- return cols.isLiteral() ? Long.parseLong(cols.getName()) : -1;
+ return cols.isLiteral() ?
UtilFunctions.parseToLong(cols.getName()) : -1;
}
public int getBlocksize() {
@@ -1011,6 +1011,12 @@ public class RandSPInstruction extends
UnarySPInstruction {
(_method == OpOpDG.SAMPLE) ?
SEED_POSITION_SAMPLE : 0;
tmpInstStr = InstructionUtils.replaceOperand(
tmpInstStr, position,
String.valueOf(runtimeSeed));
+ if( !rows.isLiteral() )
+ tmpInstStr =
InstructionUtils.replaceOperand(tmpInstStr, 2,
+ new
CPOperand(ec.getScalarInput(rows)).getLineageLiteral());
+ if( !cols.isLiteral() )
+ tmpInstStr =
InstructionUtils.replaceOperand(tmpInstStr, 3,
+ new
CPOperand(ec.getScalarInput(cols)).getLineageLiteral());
}
return Pair.of(output.getName(), new LineageItem(tmpInstStr,
getOpcode()));
}
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 39a8c2a..b75baee 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -31,6 +31,9 @@ import
org.apache.sysds.runtime.lineage.LineageItem.LineageItemType;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
@@ -267,7 +270,9 @@ public class LineageItemUtils {
switch (ctype) {
case AggregateUnary: {
Hop input =
operands.get(item.getInputs()[0].getId());
- Hop aggunary =
HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
+ Hop aggunary =
InstructionUtils.isUnaryMetadata(item.getOpcode()) ?
+
HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) :
+
HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
operands.put(item.getId(), aggunary);
break;
}
@@ -279,9 +284,15 @@ public class LineageItemUtils {
break;
}
case Reorg: {
- Hop input =
operands.get(item.getInputs()[0].getId());
- Hop reorg =
HopRewriteUtils.createReorg(input, ReOrgOp.TRANS);
-
operands.put(item.getId(), reorg);
+
operands.put(item.getId(), HopRewriteUtils.createReorg(
+
operands.get(item.getInputs()[0].getId()), item.getOpcode()));
+ break;
+ }
+ case Reshape: {
+ ArrayList<Hop> inputs =
new ArrayList<>();
+ for(int i=0; i<5; i++)
+
inputs.add(operands.get(item.getInputs()[i].getId()));
+
operands.put(item.getId(), HopRewriteUtils.createReorg(inputs,
ReOrgOp.RESHAPE));
break;
}
case Binary: {
@@ -303,12 +314,27 @@ public class LineageItemUtils {
break;
}
case Ternary: {
-
operands.put(item.getId(), HopRewriteUtils.createTernaryOp(
+
operands.put(item.getId(), HopRewriteUtils.createTernary(
operands.get(item.getInputs()[0].getId()),
operands.get(item.getInputs()[1].getId()),
operands.get(item.getInputs()[2].getId()), item.getOpcode()));
break;
}
+ case Ctable: { //e.g., ctable
+ if(
item.getInputs().length==3 )
+
operands.put(item.getId(), HopRewriteUtils.createTernary(
+
operands.get(item.getInputs()[0].getId()),
+
operands.get(item.getInputs()[1].getId()),
+
operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE));
+ else if(
item.getInputs().length==5 )
+
operands.put(item.getId(), HopRewriteUtils.createTernary(
+
operands.get(item.getInputs()[0].getId()),
+
operands.get(item.getInputs()[1].getId()),
+
operands.get(item.getInputs()[2].getId()),
+
operands.get(item.getInputs()[3].getId()),
+
operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE));
+ break;
+ }
case BuiltinNary: {
String opcode =
item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
operands.put(item.getId(), HopRewriteUtils.createNary(
@@ -331,8 +357,13 @@ public class LineageItemUtils {
operands.put(item.getId(), aggunary);
break;
}
- case Variable: { //cpvar, write
-
operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
+ case Variable: {
+ if(
item.getOpcode().startsWith("cast") )
+
operands.put(item.getId(), HopRewriteUtils.createUnary(
+
operands.get(item.getInputs()[0].getId()),
+
OpOp1.valueOfByOpcode(item.getOpcode())));
+ else //cpvar, write
+
operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
break;
}
default:
@@ -358,6 +389,12 @@ public class LineageItemUtils {
operands.put(item.getId(), constructIndexingOp(item, operands));
break;
}
+ case GAppend: {
+
operands.put(item.getId(), HopRewriteUtils.createBinary(
+
operands.get(item.getInputs()[0].getId()),
+
operands.get(item.getInputs()[1].getId()), OpOp2.CBIND));
+ break;
+ }
default:
throw new
DMLRuntimeException("Unsupported instruction "
+ "type: " +
stype.name() + " (" + item.getOpcode() + ").");
@@ -482,18 +519,16 @@ public class LineageItemUtils {
}
private static Hop constructIndexingOp(LineageItem item, Map<Long, Hop>
operands) {
- //TODO fix
+ Hop input = operands.get(item.getInputs()[0].getId());
if( "rightIndex".equals(item.getOpcode()) )
- return HopRewriteUtils.createIndexingOp(
- operands.get(item.getInputs()[0].getId()),
//input
+ return HopRewriteUtils.createIndexingOp(input,
operands.get(item.getInputs()[1].getId()), //rl
operands.get(item.getInputs()[2].getId()), //ru
operands.get(item.getInputs()[3].getId()), //cl
operands.get(item.getInputs()[4].getId())); //cu
else if( "leftIndex".equals(item.getOpcode())
|| "mapLeftIndex".equals(item.getOpcode()) )
- return HopRewriteUtils.createLeftIndexingOp(
- operands.get(item.getInputs()[0].getId()),
//input
+ return HopRewriteUtils.createLeftIndexingOp(input,
operands.get(item.getInputs()[1].getId()), //rhs
operands.get(item.getInputs()[2].getId()), //rl
operands.get(item.getInputs()[3].getId()), //ru
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
index 50443c1..d100a4d 100644
---
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
@@ -44,7 +44,8 @@ public class LineageTraceParforTest extends AutomatedTestBase
{
protected static final String TEST_NAME1 = "LineageTraceParfor1";
//rand - matrix result - local parfor
protected static final String TEST_NAME2 = "LineageTraceParfor2";
//rand - matrix result - remote spark parfor
protected static final String TEST_NAME3 = "LineageTraceParfor3";
//rand - matrix result - remote spark parfor
- protected static final String TEST_NAME4 = "LineageTraceParfor4";
//rand - steplm (stackoverflow error)
+ protected static final String TEST_NAME4 = "LineageTraceParforSteplm";
//rand - steplm
+ protected static final String TEST_NAME5 = "LineageTraceParforKmeans";
//rand - kmeans
protected String TEST_CLASS_DIR = TEST_DIR +
LineageTraceParforTest.class.getSimpleName() + "/";
@@ -61,6 +62,7 @@ public class LineageTraceParforTest extends AutomatedTestBase
{
addTestConfiguration( TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) );
addTestConfiguration( TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}) );
addTestConfiguration( TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"R"}) );
+ addTestConfiguration( TEST_NAME5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"R"}) );
}
@Test
@@ -113,16 +115,25 @@ public class LineageTraceParforTest extends
AutomatedTestBase {
testLineageTraceParFor(32, TEST_NAME3);
}
-// TODO additional fixes needed for steplm
-// @Test
-// public void testLineageTraceParFor4_8() {
-// testLineageTraceParFor(8, TEST_NAME4);
-// }
-//
-// @Test
-// public void testLineageTraceParFor4_32() {
-// testLineageTraceParFor(32, TEST_NAME4);
-// }
+ @Test
+ public void testLineageTraceSteplm_8() {
+ testLineageTraceParFor(8, TEST_NAME4);
+ }
+
+ @Test
+ public void testLineageTraceSteplm_32() {
+ testLineageTraceParFor(32, TEST_NAME4);
+ }
+
+ @Test
+ public void testLineageTraceKMeans_8() {
+ testLineageTraceParFor(8, TEST_NAME5);
+ }
+
+ @Test
+ public void testLineageTraceKmeans_32() {
+ testLineageTraceParFor(32, TEST_NAME5);
+ }
private void testLineageTraceParFor(int ncol, String testname) {
try {
@@ -146,7 +157,6 @@ public class LineageTraceParforTest extends
AutomatedTestBase {
//get lineage and generate program
String Rtrace = readDMLLineageFromHDFS("R");
- System.out.println(Rtrace);
LineageItem R = LineageParser.parseLineageTrace(Rtrace);
Data ret = LineageItemUtils.computeByLineage(R);
diff --git a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml
b/src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml
similarity index 94%
copy from src/test/scripts/functions/lineage/LineageTraceParfor4.dml
copy to src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml
index 576182b..215cb5b 100644
--- a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml
+++ b/src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml
@@ -20,7 +20,6 @@
#-------------------------------------------------------------
X = rand(rows=$2, cols=$3, seed=7);
-Y = rand(rows=nrow(X), cols=1, seed=2)
-X = steplm(X=X, y=Y)
+X = kmeans(X=X, k=4)
write(X, $1);
diff --git a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml
b/src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml
similarity index 100%
rename from src/test/scripts/functions/lineage/LineageTraceParfor4.dml
rename to src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml