[SYSTEMML-656] Fix constant folding rewrite (casting of output types)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/5bde577a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/5bde577a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/5bde577a

Branch: refs/heads/master
Commit: 5bde577a2414ef851bc8601b84aa5f6323ee6260
Parents: a7c3689
Author: Matthias Boehm <[email protected]>
Authored: Sat May 7 20:43:47 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat May 7 20:44:25 2016 -0700

----------------------------------------------------------------------
 .../hops/rewrite/RewriteConstantFolding.java      |  5 +++--
 .../cp/ScalarScalarArithmeticCPInstruction.java   | 18 +++++++++---------
 2 files changed, 12 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5bde577a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
index f4b57cc..c6fb6b0 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
@@ -215,10 +215,11 @@ public class RewriteConstantFolding extends HopRewriteRule
                
                pb.execute( ec );
                
-               //get scalar result (check before invocation)
+               //get scalar result (check before invocation) and create literal according
+               //to observed scalar output type (not hop type) for runtime consistency
                ScalarObject so = (ScalarObject) ec.getVariable(TMP_VARNAME);
                LiteralOp literal = null;
-               switch( bop.getValueType() ){
+               switch( so.getValueType() ){
                        case DOUBLE:  literal = new LiteralOp(so.getDoubleValue()); break;
                        case INT:     literal = new LiteralOp(so.getLongValue()); break;
                        case BOOLEAN: literal = new LiteralOp(so.getBooleanValue()); break;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5bde577a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java
index 7e0c7bc..030bfbc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ScalarScalarArithmeticCPInstruction.java
@@ -51,7 +51,7 @@ public class ScalarScalarArithmeticCPInstruction extends ArithmeticBinaryCPInstr
                // 2) Compute the result value & make an appropriate data object 
                BinaryOperator dop = (BinaryOperator) _optr;
                
-               if ( input1.getValueType() == ValueType.STRING 
+               if( input1.getValueType() == ValueType.STRING 
                         || input2.getValueType() == ValueType.STRING ) 
                {
                        //pre-check (for robustness regarding too long strings)
@@ -60,29 +60,29 @@ public class ScalarScalarArithmeticCPInstruction extends ArithmeticBinaryCPInstr
                        StringObject.checkMaxStringLength(val1.length() + val2.length());
                        
                        String rval = dop.fn.execute(val1, val2);
-                       sores = (ScalarObject) new StringObject(rval);
+                       sores = new StringObject(rval);
                }
                else if ( so1 instanceof IntObject && so2 instanceof IntObject ) {
                        if ( dop.fn instanceof Divide || dop.fn instanceof Power ) {
                                // If both inputs are of type INT then output must be an INT if operation is not divide or power
-                               double rval = dop.fn.execute ( so1.getLongValue(), so2.getLongValue() );
-                               sores = (ScalarObject) new DoubleObject(rval);
+                               double rval = dop.fn.execute( so1.getLongValue(), so2.getLongValue() );
+                               sores = new DoubleObject(rval);
                        }
                        else {
                                // If both inputs are of type INT then output must be an INT if operation is not divide or power
-                               double tmpVal = dop.fn.execute ( so1.getLongValue(), so2.getLongValue() );
+                               double tmpVal = dop.fn.execute( so1.getLongValue(), so2.getLongValue() );
                                //cast to long if no overflow, otherwise controlled exception
                                if( tmpVal > Long.MAX_VALUE )
                                        throw new DMLRuntimeException("Integer operation created numerical result overflow ("+tmpVal+" > "+Long.MAX_VALUE+").");
-                               long rval = (long) tmpVal; 
-                               sores = (ScalarObject) new IntObject(rval);
+                               sores = new IntObject((long) tmpVal);
                        }
                }
-               
+               //NOTE: boolean-boolean arithmetic covered by general case below in order
+               //to maintain consistency with R
                else {
                        // If either of the input is of type DOUBLE then output is a DOUBLE
                        double rval = dop.fn.execute ( so1.getDoubleValue(), so2.getDoubleValue() );
-                       sores = (ScalarObject) new DoubleObject(rval); 
+                       sores = new DoubleObject(rval); 
                }
                
                // 3) Put the result value into ProgramBlock

Reply via email to