[SYSTEMML-656] Fix constant folding rewrite (casting of output types) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/5bde577a Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/5bde577a Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/5bde577a
Branch: refs/heads/master Commit: 5bde577a2414ef851bc8601b84aa5f6323ee6260 Parents: a7c3689 Author: Matthias Boehm <[email protected]> Authored: Sat May 7 20:43:47 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat May 7 20:44:25 2016 -0700 ---------------------------------------------------------------------- .../hops/rewrite/RewriteConstantFolding.java | 5 +++-- .../cp/ScalarScalarArithmeticCPInstruction.java | 18 +++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5bde577a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java index f4b57cc..c6fb6b0 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java @@ -215,10 +215,11 @@ public class RewriteConstantFolding extends HopRewriteRule pb.execute( ec ); - //get scalar result (check before invocation) + //get scalar result (check before invocation) and create literal according + //to observed scalar output type (not hop type) for runtime consistency ScalarObject so = (ScalarObject) ec.getVariable(TMP_VARNAME); LiteralOp literal = null; - switch( bop.getValueType() ){ + switch( so.getValueType() ){ case DOUBLE: literal = new LiteralOp(so.getDoubleValue()); break; case INT: literal = new LiteralOp(so.getLongValue()); break; case BOOLEAN: literal = new LiteralOp(so.getBooleanValue()); break; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5bde577a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java index 7e0c7bc..030bfbc 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java @@ -51,7 +51,7 @@ public class ScalarScalarArithmeticCPInstruction extends ArithmeticBinaryCPInstr // 2) Compute the result value & make an appropriate data object BinaryOperator dop = (BinaryOperator) _optr; - if ( input1.getValueType() == ValueType.STRING + if( input1.getValueType() == ValueType.STRING || input2.getValueType() == ValueType.STRING ) { //pre-check (for robustness regarding too long strings) @@ -60,29 +60,29 @@ public class ScalarScalarArithmeticCPInstruction extends ArithmeticBinaryCPInstr StringObject.checkMaxStringLength(val1.length() + val2.length()); String rval = dop.fn.execute(val1, val2); - sores = (ScalarObject) new StringObject(rval); + sores = new StringObject(rval); } else if ( so1 instanceof IntObject && so2 instanceof IntObject ) { if ( dop.fn instanceof Divide || dop.fn instanceof Power ) { // If both inputs are of type INT then output must be an INT if operation is not divide or power - double rval = dop.fn.execute ( so1.getLongValue(), so2.getLongValue() ); - sores = (ScalarObject) new DoubleObject(rval); + double rval = dop.fn.execute( so1.getLongValue(), so2.getLongValue() ); + sores = new DoubleObject(rval); } else { // If both inputs are of type INT then output must be an INT if operation is not divide or power - double tmpVal = dop.fn.execute ( so1.getLongValue(), so2.getLongValue() ); + double tmpVal = dop.fn.execute( so1.getLongValue(), so2.getLongValue() ); //cast to long if no overflow, otherwise controlled exception if( tmpVal > Long.MAX_VALUE ) throw new DMLRuntimeException("Integer operation created numerical result overflow ("+tmpVal+" > "+Long.MAX_VALUE+")."); - long rval = (long) tmpVal; - sores = (ScalarObject) new IntObject(rval); + sores = new IntObject((long) tmpVal); } } - + //NOTE: boolean-boolean arithmetic covered by general case below in order + //to maintain consistency with R else { // If either of the input is of type DOUBLE then output is a DOUBLE double rval = dop.fn.execute ( so1.getDoubleValue(), so2.getDoubleValue() ); - sores = (ScalarObject) new DoubleObject(rval); + sores = new DoubleObject(rval); } // 3) Put the result value into ProgramBlock
