[MINOR] Fix uaggouterchain compilation (output data types)

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6f2c885e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6f2c885e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6f2c885e

Branch: refs/heads/master
Commit: 6f2c885e8aad480349e039fcd0390feb341b3639
Parents: f9020a1
Author: Matthias Boehm <[email protected]>
Authored: Thu May 10 12:27:58 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu May 10 12:28:57 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  8 +++---
 .../cp/UaggOuterChainCPInstruction.java         | 27 +++++---------------
 .../binary/matrix/UaggOuterChainTest.java       |  3 +--
 3 files changed, 12 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 136d2d6..d3e0570 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -191,14 +191,14 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
                                        
                                                if (getDataType() == 
DataType.SCALAR) {
                                                        UnaryCP unary1 = new 
UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
-                                                                               
            getDataType(), getValueType());
+                                                               getDataType(), 
getValueType());
                                                        
unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                                                        setLineNumbers(unary1);
-                                                       setLops(unary1);
+                                                       agg1 = unary1;
                                                }
                                        
-                                       }                               
-                                       else { //general case           
+                                       }
+                                       else { //general case
                                                int k = 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                                                agg1 = new 
PartialAggregate(input.constructLops(), 
                                                                
HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), 
getDataType(),getValueType(), et, k);

http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
index 908e5bd..e6dd403 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
@@ -74,7 +74,7 @@ public class UaggOuterChainCPInstruction extends UnaryCPInstruction {
                boolean rightCached = (_uaggOp.indexFn instanceof ReduceCol || 
_uaggOp.indexFn instanceof ReduceAll
                                || 
!LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp));
 
-               MatrixBlock mbLeft = null, mbRight = null, mbOut = null;        
        
+               MatrixBlock mbLeft = null, mbRight = null, mbOut = null;
                //get the main data input
                if( rightCached ) { 
                        mbLeft = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
@@ -94,26 +94,13 @@ public class UaggOuterChainCPInstruction extends UnaryCPInstruction {
                if( _uaggOp.aggOp.correctionExists )
                        
mbOut.dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation);
                
-               String output_name = output.getName();
-               //final aggregation if required
-               if(_uaggOp.indexFn instanceof ReduceAll ) //RC AGG (output is 
scalar)
-               {
-                       //create and set output scalar
-                       ScalarObject ret = null;
-                       switch( output.getValueType() ) {
-                               case DOUBLE:  ret = new 
DoubleObject(mbOut.quickGetValue(0, 0)); break;
-                               
-                               default: 
-                                       throw new DMLRuntimeException("Invalid 
output value type: "+output.getValueType());
-                       }
-                       ec.setScalarOutput(output_name, ret);
+               if(_uaggOp.indexFn instanceof ReduceAll ) { //RC AGG (output is 
scalar)
+                       ec.setMatrixOutput(output.getName(), new MatrixBlock(
+                               mbOut.quickGetValue(0, 0)), 
getExtendedOpcode());
                }
-               else //R/C AGG (output is rdd)
-               {       
-                       //Additional memory requirement to convert from dense 
to sparse can be leveraged from released memory needed for input data above.
+               else { //R/C AGG (output is rdd)
                        mbOut.examSparsity();
-                       ec.setMatrixOutput(output_name, mbOut, 
getExtendedOpcode());
+                       ec.setMatrixOutput(output.getName(), mbOut, 
getExtendedOpcode());
                }
-               
-       }               
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java
index 04a00c9..e031b53 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java
@@ -44,7 +44,6 @@ import org.apache.sysml.utils.Statistics;
  */
 public class UaggOuterChainTest extends AutomatedTestBase 
 {
-       
        private final static String TEST_NAME1 = "UaggOuterChain";
        private final static String TEST_DIR = "functions/binary/matrix/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
UaggOuterChainTest.class.getSimpleName() + "/";
@@ -1318,7 +1317,7 @@ public class UaggOuterChainTest extends AutomatedTestBase
                        
                        loadTestConfiguration(config, TEST_CACHE_DIR);
                        
-                       String HOME = SCRIPT_DIR + TEST_DIR;                    
+                       String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + suffix + 
strSumTypeSuffix + ".dml";
                        programArgs = new String[]{"-stats", 
"-explain","-args", 
                                input("A"), input("B"), output("C")};

Reply via email to