Repository: systemml Updated Branches: refs/heads/master 827d73bd5 -> 430c04d59
[SYSTEMML-2475] Fix matrix/frame left indexing into list data types This patch fixes the missing left indexing support for frames and matrices into lists. Furthermore, this also includes a robustness fix for inferring the output data type of builtin functions when the target is a list left indexing (which propagated incorrectly to the source). Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f2c0d13e Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f2c0d13e Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f2c0d13e Branch: refs/heads/master Commit: f2c0d13e2d785ae9143f38cc59b06f14b7dd4fc8 Parents: 827d73b Author: Matthias Boehm <mboe...@gmail.com> Authored: Mon Jul 30 21:58:37 2018 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Mon Jul 30 21:58:37 2018 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/lops/PartialAggregate.java | 8 +++ .../org/apache/sysml/parser/DMLTranslator.java | 55 +++++++++----------- .../context/ExecutionContext.java | 4 ++ .../cp/ListIndexingCPInstruction.java | 9 ++++ 4 files changed, 47 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/f2c0d13e/src/main/java/org/apache/sysml/lops/PartialAggregate.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/PartialAggregate.java b/src/main/java/org/apache/sysml/lops/PartialAggregate.java index 8358a1d..fa7b966 100644 --- a/src/main/java/org/apache/sysml/lops/PartialAggregate.java +++ b/src/main/java/org/apache/sysml/lops/PartialAggregate.java @@ -377,6 +377,14 @@ public class PartialAggregate extends Lop } } + case SumProduct: { + switch( dir ) { + case RowCol: return "ua+*"; + case Row: return "uar+*"; + case Col: return "uac+*"; + } + } + case Max: { if( dir == DirectionTypes.RowCol ) return "uamax"; http://git-wip-us.apache.org/repos/asf/systemml/blob/f2c0d13e/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index bdfdf8f..e9b643e 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2325,10 +2325,7 @@ public class DMLTranslator } Hop currBuiltinOp = null; - - if (target == null) { - target = createTarget(source); - } + target = (target == null) ? createTarget(source) : target; // Construct the hop based on the type of Builtin function switch (source.getOpCode()) { @@ -2344,15 +2341,15 @@ public class DMLTranslator case COLMEAN: case COLPROD: case COLVAR: - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.valueOf(source.getOpCode().name().substring(3)), Direction.Col, expr); break; case COLSD: // colStdDevs = sqrt(colVariances) - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.VAR, Direction.Col, expr); - currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), + currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp); break; @@ -2362,25 +2359,25 @@ public class DMLTranslator case ROWMEAN: case ROWPROD: case ROWVAR: - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.valueOf(source.getOpCode().name().substring(3)), Direction.Row, expr); break; case ROWINDEXMAX: - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAXINDEX, + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.MAXINDEX, Direction.Row, expr); break; case ROWINDEXMIN: - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MININDEX, + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.MININDEX, Direction.Row, expr); break; case ROWSD: // rowStdDevs = sqrt(rowVariances) - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.VAR, Direction.Row, expr); - currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), + currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp); break; @@ -2409,38 +2406,38 @@ public class DMLTranslator break; case EXISTS: - currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), + currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.EXISTS, expr); break; case SUM: case PROD: case VAR: - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr); break; case MEAN: if ( expr2 == null ) { // example: x = mean(Y); - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.MEAN, Direction.RowCol, expr); } else { // example: x = mean(Y,W); // stable weighted mean is implemented by using centralMoment with order = 0 Hop orderHop = new LiteralOp(0); - currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), + currBuiltinOp=new TernaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp3.MOMENT, expr, expr2, orderHop); } break; case SD: // stdDev = sqrt(variance) - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.VAR, Direction.RowCol, expr); HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp); - currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), + currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp); break; @@ -2448,7 +2445,7 @@ public class DMLTranslator case MAX: //construct AggUnary for min(X) but BinaryOp for min(X,Y) and NaryOp for min(X,Y,Z) currBuiltinOp = (expr2 == null) ? - new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), + new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr) : (source.getAllExpr().length == 2) ? new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), @@ -2480,14 +2477,14 @@ public class DMLTranslator break; case TRACE: - currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.TRACE, + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.TRACE, Direction.RowCol, expr); break; case TRANS: case DIAG: case REV: - currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), + currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX, target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), expr); break; @@ -2546,16 +2543,16 @@ public class DMLTranslator } break; - //data type casts + //data type casts case CAST_AS_SCALAR: currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr); break; case CAST_AS_MATRIX: currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr); - break; + break; case CAST_AS_FRAME: currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr); - break; + break; //value type casts case CAST_AS_DOUBLE: @@ -2725,7 +2722,7 @@ public class DMLTranslator if( op == null ) throw new HopsException("Unsupported outer vector binary operation: "+((LiteralOp)expr3).getStringValue()); - currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, expr, expr2); + currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, target.getValueType(), op, expr, expr2); ((BinaryOp)currBuiltinOp).setOuterVectorOperation(true); //flag op as specific outer vector operation currBuiltinOp.refreshSizeInformation(); //force size reevaluation according to 'outer' flag otherwise danger of incorrect dims break; @@ -2735,21 +2732,21 @@ public class DMLTranslator ArrayList<Hop> inHops1 = new ArrayList<>(); inHops1.add(expr); inHops1.add(expr2); - currBuiltinOp = new DnnOp(target.getName(), target.getDataType(), target.getValueType(), + currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOpDnn.valueOf(source.getOpCode().name()), inHops1); setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp); break; } case AVG_POOL: case MAX_POOL: { - currBuiltinOp = new DnnOp(target.getName(), target.getDataType(), target.getValueType(), + currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForPoolingForwardIM2COL(expr, source, 1, hops)); setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp); break; } case AVG_POOL_BACKWARD: case MAX_POOL_BACKWARD: { - currBuiltinOp = new DnnOp(target.getName(), target.getDataType(), target.getValueType(), + currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForConvOpPoolingCOL2IM(expr, source, 1, hops)); setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp); break; @@ -2757,7 +2754,7 @@ public class DMLTranslator case CONV2D: case CONV2D_BACKWARD_FILTER: case CONV2D_BACKWARD_DATA: { - currBuiltinOp = new DnnOp(target.getName(), target.getDataType(), target.getValueType(), + currBuiltinOp = new DnnOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOpDnn.valueOf(source.getOpCode().name()), getALHopsForConvOp(expr, source, 1, hops)); setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp); break; http://git-wip-us.apache.org/repos/asf/systemml/blob/f2c0d13e/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index 2e0addf..88ec092 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -225,6 +225,10 @@ public class ExecutionContext { return (FrameObject) dat; } + public CacheableData<?> getCacheableData(CPOperand input) { + return getCacheableData(input.getName()); + } + public CacheableData<?> getCacheableData(String varname) { Data dat = getVariable(varname); //error handling if non existing or no matrix http://git-wip-us.apache.org/repos/asf/systemml/blob/f2c0d13e/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java index f41601f..d22d9a8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListIndexingCPInstruction.java @@ -23,6 +23,7 @@ import org.apache.sysml.lops.LeftIndex; import org.apache.sysml.lops.RightIndex; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; public final class ListIndexingCPInstruction extends IndexingCPInstruction { @@ -76,6 +77,14 @@ public final class ListIndexingCPInstruction extends IndexingCPInstruction { else ec.setVariable(output.getName(), lin.copy().set((int)rl.getLongValue()-1, scalar)); } + else if( input2.getDataType().isMatrix() ) { //LIST <- MATRIX/FRAME + CacheableData<?> dat = ec.getCacheableData(input2); + dat.enableCleanup(false); + if( rl.getValueType()==ValueType.STRING ) + ec.setVariable(output.getName(), lin.copy().set(rl.getStringValue(), dat)); + else + ec.setVariable(output.getName(), lin.copy().set((int)rl.getLongValue()-1, dat)); + } else { throw new DMLRuntimeException("Unsupported list " + "left indexing rhs type: "+input2.getDataType().name());