[MINOR] Fix uaggouterchain compilation (output data types) Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6f2c885e Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6f2c885e Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6f2c885e
Branch: refs/heads/master Commit: 6f2c885e8aad480349e039fcd0390feb341b3639 Parents: f9020a1 Author: Matthias Boehm <[email protected]> Authored: Thu May 10 12:27:58 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu May 10 12:28:57 2018 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 8 +++--- .../cp/UaggOuterChainCPInstruction.java | 27 +++++--------------- .../binary/matrix/UaggOuterChainTest.java | 3 +-- 3 files changed, 12 insertions(+), 26 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index 136d2d6..d3e0570 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -191,14 +191,14 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), - getDataType(), getValueType()); + getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); setLineNumbers(unary1); - setLops(unary1); + agg1 = unary1; } - } - else { //general case + } + else { //general case int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); agg1 = new PartialAggregate(input.constructLops(), HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k); http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java index 908e5bd..e6dd403 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java @@ -74,7 +74,7 @@ public class UaggOuterChainCPInstruction extends UnaryCPInstruction { boolean rightCached = (_uaggOp.indexFn instanceof ReduceCol || _uaggOp.indexFn instanceof ReduceAll || !LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp)); - MatrixBlock mbLeft = null, mbRight = null, mbOut = null; + MatrixBlock mbLeft = null, mbRight = null, mbOut = null; //get the main data input if( rightCached ) { mbLeft = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); @@ -94,26 +94,13 @@ public class UaggOuterChainCPInstruction extends UnaryCPInstruction { if( _uaggOp.aggOp.correctionExists ) mbOut.dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation); - String output_name = output.getName(); - //final aggregation if required - if(_uaggOp.indexFn instanceof ReduceAll ) //RC AGG (output is scalar) - { - //create and set output scalar - ScalarObject ret = null; - switch( output.getValueType() ) { - case DOUBLE: ret = new DoubleObject(mbOut.quickGetValue(0, 0)); break; - - default: - throw new DMLRuntimeException("Invalid output value type: "+output.getValueType()); - } - ec.setScalarOutput(output_name, ret); + if(_uaggOp.indexFn instanceof ReduceAll ) { //RC AGG (output is scalar) + ec.setMatrixOutput(output.getName(), new MatrixBlock( + mbOut.quickGetValue(0, 0)), getExtendedOpcode()); } - else //R/C AGG (output is rdd) - { - //Additional memory requirement to convert from dense to sparse can be leveraged from released memory needed for input data above. + else { //R/C AGG (output is rdd) mbOut.examSparsity(); - ec.setMatrixOutput(output_name, mbOut, getExtendedOpcode()); + ec.setMatrixOutput(output.getName(), mbOut, getExtendedOpcode()); } - - } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java index 04a00c9..e031b53 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java @@ -44,7 +44,6 @@ import org.apache.sysml.utils.Statistics; */ public class UaggOuterChainTest extends AutomatedTestBase { - private final static String TEST_NAME1 = "UaggOuterChain"; private final static String TEST_DIR = "functions/binary/matrix/"; private final static String TEST_CLASS_DIR = TEST_DIR + UaggOuterChainTest.class.getSimpleName() + "/"; @@ -1318,7 +1317,7 @@ public class UaggOuterChainTest extends AutomatedTestBase loadTestConfiguration(config, TEST_CACHE_DIR); - String HOME = SCRIPT_DIR + TEST_DIR; + String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + suffix + strSumTypeSuffix + ".dml"; programArgs = new String[]{"-stats", "-explain","-args", input("A"), input("B"), output("C")};
