Repository: incubator-systemml Updated Branches: refs/heads/master d69fdfe45 -> 117ea480d
[MINOR] Graceful value type casting of scalar function args, cleanup Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/53fe1ae6 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/53fe1ae6 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/53fe1ae6 Branch: refs/heads/master Commit: 53fe1ae68ab3b5024ead0d258a213f3e4f392616 Parents: d69fdfe Author: Matthias Boehm <[email protected]> Authored: Tue Apr 18 01:37:55 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Apr 19 18:46:28 2017 -0700 ---------------------------------------------------------------------- .../sysml/debug/DMLDebuggerFunctions.java | 31 ++--------------- .../sysml/hops/ipa/InterProceduralAnalysis.java | 17 ++------- .../sysml/hops/rewrite/HopRewriteUtils.java | 27 +++------------ .../context/ExecutionContext.java | 29 ++-------------- .../cp/FunctionCallCPInstruction.java | 11 +++++- .../instructions/cp/ScalarObjectFactory.java | 36 ++++++++++++++++++-- 6 files changed, 58 insertions(+), 93 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/53fe1ae6/src/main/java/org/apache/sysml/debug/DMLDebuggerFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/debug/DMLDebuggerFunctions.java b/src/main/java/org/apache/sysml/debug/DMLDebuggerFunctions.java index d852747..09bb0d6 100644 --- a/src/main/java/org/apache/sysml/debug/DMLDebuggerFunctions.java +++ b/src/main/java/org/apache/sysml/debug/DMLDebuggerFunctions.java @@ -34,14 +34,9 @@ import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.MRJobInstruction; -import org.apache.sysml.runtime.instructions.cp.BooleanObject; import org.apache.sysml.runtime.instructions.cp.BreakPointInstruction; import org.apache.sysml.runtime.instructions.cp.CPInstruction; -import org.apache.sysml.runtime.instructions.cp.Data; -import org.apache.sysml.runtime.instructions.cp.DoubleObject; -import org.apache.sysml.runtime.instructions.cp.IntObject; -import org.apache.sysml.runtime.instructions.cp.ScalarObject; -import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysml.runtime.instructions.cp.BreakPointInstruction.BPINSTRUCTION_STATUS; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -269,28 +264,8 @@ public class DMLDebuggerFunctions { if (variables != null && !variables.keySet().isEmpty()) { if (variables.get(varname) != null) { if (variables.get(varname).getDataType() == DataType.SCALAR) { - Data value; - switch(variables.get(varname).getValueType()) { - case DOUBLE: - double d = Double.parseDouble(args[1]); - value = (ScalarObject) new DoubleObject(d); - break; - case INT: - long i = Long.parseLong(args[1]); - value = (ScalarObject) new IntObject(i); - break; - case BOOLEAN: - boolean b = Boolean.parseBoolean(args[1]); - value = (ScalarObject) new BooleanObject(b); - break; - case STRING: - value = (ScalarObject) new StringObject(args[1]); - break; - default: - System.err.println("Invalid scalar value type."); - return; - } - variables.put(varname, value); + variables.put(varname, ScalarObjectFactory + .createScalarObject(variables.get(varname).getValueType(), args[1])); System.out.println(varname + " = " + variables.get(varname).toString()); } else http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/53fe1ae6/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java index 4cea2e2..b0ebfaa 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java @@ -70,12 +70,9 @@ import org.apache.sysml.parser.WhileStatement; import org.apache.sysml.parser.WhileStatementBlock; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.instructions.cp.BooleanObject; import org.apache.sysml.runtime.instructions.cp.Data; -import org.apache.sysml.runtime.instructions.cp.DoubleObject; -import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; -import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; import org.apache.sysml.udf.lib.DeNaNWrapper; @@ -695,16 +692,8 @@ public class InterProceduralAnalysis //always propagate scalar literals into functions //(for multiple calls, literal equivalence already checked) if( input instanceof LiteralOp ) { - LiteralOp lit = (LiteralOp)input; - ScalarObject scalar = null; - switch(input.getValueType()) { - case DOUBLE: scalar = new DoubleObject(lit.getDoubleValue()); break; - case INT: scalar = new IntObject(lit.getLongValue()); break; - case BOOLEAN: scalar = new BooleanObject(lit.getBooleanValue()); break; - case STRING: scalar = new StringObject(lit.getStringValue()); break; - default: //do nothing - } - vars.put(dat.getName(), scalar); + vars.put(dat.getName(), ScalarObjectFactory + .createScalarObject(input.getValueType(), (LiteralOp)input)); } //propagate scalar variables into functions if called once //and input scalar is existing variable in symbol table http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/53fe1ae6/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index ba5d1ab..f3baec1 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -56,12 +56,8 @@ import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.Statement; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.instructions.cp.BooleanObject; -import org.apache.sysml.runtime.instructions.cp.DoubleObject; -import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; -import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.util.UtilFunctions; @@ -196,26 +192,13 @@ public class HopRewriteUtils public static ScalarObject getScalarObject( LiteralOp op ) { - ScalarObject ret = null; - - try - { - switch( op.getValueType() ) - { - case DOUBLE: ret = new DoubleObject(op.getDoubleValue()); break; - case INT: ret = new IntObject(op.getLongValue()); break; - case BOOLEAN: ret = new BooleanObject(op.getBooleanValue()); break; - case STRING: ret = new StringObject(op.getStringValue()); break; - default: - throw new DMLRuntimeException("Invalid scalar object value type: "+op.getValueType()); - } + try { + return ScalarObjectFactory + .createScalarObject(op.getValueType(), op); } - catch(Exception ex) - { + catch(Exception ex) { throw new RuntimeException("Failed to create scalar object for constant. Continue.", ex); } - - return ret; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/53fe1ae6/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index 6455add..b6e1830 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -36,13 +36,10 @@ import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.instructions.Instruction; -import org.apache.sysml.runtime.instructions.cp.BooleanObject; import org.apache.sysml.runtime.instructions.cp.Data; -import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; -import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; -import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; import org.apache.sysml.runtime.instructions.gpu.context.GPUObject; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -53,7 +50,6 @@ import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.Pair; import org.apache.sysml.runtime.util.MapReduceTool; -import org.apache.sysml.runtime.util.UtilFunctions; public class ExecutionContext @@ -347,31 +343,12 @@ public class ExecutionContext throws DMLRuntimeException { if ( isLiteral ) { - switch (vt) { - case INT: - long intVal = UtilFunctions.parseToLong(name); - IntObject intObj = new IntObject(intVal); - return intObj; - case DOUBLE: - double doubleVal = Double.parseDouble(name); - DoubleObject doubleObj = new DoubleObject(doubleVal); - return doubleObj; - case BOOLEAN: - Boolean boolVal = Boolean.parseBoolean(name); - BooleanObject boolObj = new BooleanObject(boolVal); - return boolObj; - case STRING: - StringObject stringObj = new StringObject(name); - return stringObj; - default: - throw new DMLRuntimeException("Unknown value type: " + vt + " for variable: " + name); - } + return ScalarObjectFactory.createScalarObject(vt, name); } else { Data obj = getVariable(name); - if (obj == null) { + if (obj == null) throw new DMLRuntimeException("Unknown variable: " + name); - } return (ScalarObject) obj; } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/53fe1ae6/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java index 67afdc8..c65553b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java @@ -151,9 +151,18 @@ public class FunctionCallCPInstruction extends CPInstruction //get input matrix/frame/scalar currFormalParamValue = (operand.getDataType()!=DataType.SCALAR) ? ec.getVariable(varname) : ec.getScalarInput(varname, operand.getValueType(), operand.isLiteral()); + + //graceful value type conversion for scalar inputs with wrong type + if( currFormalParamValue.getDataType() == DataType.SCALAR + && currFormalParamValue.getValueType() != operand.getValueType() ) + { + ScalarObject so = (ScalarObject) currFormalParamValue; + currFormalParamValue = ScalarObjectFactory + .createScalarObject(operand.getValueType(), so); + } } - functionVariables.put(currFormalParamName,currFormalParamValue); + functionVariables.put(currFormalParamName, currFormalParamValue); } // Pin the input variables so that they do not get deleted http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/53fe1ae6/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java index dfb0726..ea2c169 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarObjectFactory.java @@ -19,18 +19,50 @@ package org.apache.sysml.runtime.instructions.cp; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.util.UtilFunctions; public abstract class ScalarObjectFactory { - + public static ScalarObject createScalarObject(ValueType vt, String value) { + switch( vt ) { + case INT: return new IntObject(UtilFunctions.parseToLong(value)); + case DOUBLE: return new DoubleObject(Double.parseDouble(value)); + case BOOLEAN: return new BooleanObject(Boolean.parseBoolean(value)); + case STRING: return new StringObject(value); + default: throw new RuntimeException("Unsupported scalar value type: "+vt.name()); + } + } + public static ScalarObject createScalarObject(ValueType vt, Object obj) { switch( vt ) { case BOOLEAN: return new BooleanObject((Boolean)obj); case INT: return new IntObject((Long)obj); case DOUBLE: return new DoubleObject((Double)obj); case STRING: return new StringObject((String)obj); - default: throw new RuntimeException("Unsupported scalar object type: "+vt.toString()); + default: throw new RuntimeException("Unsupported scalar value type: "+vt.name()); + } + } + + public static ScalarObject createScalarObject(ValueType vt, ScalarObject so) { + switch( vt ) { + case DOUBLE: return new DoubleObject(so.getDoubleValue()); + case INT: return new IntObject(so.getLongValue()); + case BOOLEAN: return new BooleanObject(so.getBooleanValue()); + case STRING: return new StringObject(so.getStringValue()); + default: throw new RuntimeException("Unsupported scalar value type: "+vt.name()); + } + } + + public static ScalarObject createScalarObject(ValueType vt, LiteralOp lit) throws HopsException { + switch( vt ) { + case DOUBLE: return new DoubleObject(lit.getDoubleValue()); + case INT: return new IntObject(lit.getLongValue()); + case BOOLEAN: return new BooleanObject(lit.getBooleanValue()); + case STRING: return new StringObject(lit.getStringValue()); + default: throw new RuntimeException("Unsupported scalar value type: "+vt.name()); } } }
