Repository: systemml
Updated Branches:
  refs/heads/master 51db735eb -> c61b94c97


[SYSTEMML-2375] Improved size inference for nary ops (cbind/rbind/min/max)

This patch improves the size propagation for n-ary ops (cbind, rbind, min,
max) by (1) adding the missing worst-case size inference, and (2) computing
the number of non-zeros (nnz) for exact size propagation.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c61b94c9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c61b94c9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c61b94c9

Branch: refs/heads/master
Commit: c61b94c975ac9019c9e9f0187dda5b23dbac61e7
Parents: 51db735
Author: Matthias Boehm <[email protected]>
Authored: Mon Jun 18 18:50:50 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jun 18 18:50:50 2018 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/NaryOp.java | 22 ++++++++++
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 43 +++++++++++++++++++-
 2 files changed, 63 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c61b94c9/src/main/java/org/apache/sysml/hops/NaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/NaryOp.java b/src/main/java/org/apache/sysml/hops/NaryOp.java
index 0847659..db03a23 100644
--- a/src/main/java/org/apache/sysml/hops/NaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/NaryOp.java
@@ -25,6 +25,7 @@ import org.apache.sysml.lops.LopProperties.ExecType;
 import org.apache.sysml.lops.Nary;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 
 /**
  * The NaryOp Hop allows for a variable number of operands. Functionality
@@ -180,7 +181,26 @@ public class NaryOp extends Hop {
        }
 
        @Override
+       @SuppressWarnings("incomplete-switch")
        protected long[] inferOutputCharacteristics(MemoTable memo) {
+               if( !getDataType().isScalar() ) { //matrix outputs only; scalars fall through to null
+                       MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
+                       
+                       switch( _op ) {
+                               case CBIND: return new long[]{ //rows: max, cols: sum, nnz: worst-case sum
+                                       HopRewriteUtils.getMaxInputDim(mc, true),
+                                       HopRewriteUtils.getSumValidInputDims(mc, false),
+                                       HopRewriteUtils.getSumValidInputNnz(mc, true)};
+                               case RBIND: return new long[]{ //rows: sum, cols: max, nnz: worst-case sum
+                                       HopRewriteUtils.getSumValidInputDims(mc, true),
+                                       HopRewriteUtils.getMaxInputDim(mc, false),
+                                       HopRewriteUtils.getSumValidInputNnz(mc, true)};
+                               case MIN:
+                               case MAX: return new long[]{ //NOTE(review): reads input hop dims, not memo stats mc; nnz unknown
+                                       HopRewriteUtils.getMaxInputDim(this, true),
+                                       HopRewriteUtils.getMaxInputDim(this, false), -1};
+                       }
+               }
                return null; //do nothing
        }
        
@@ -190,10 +210,12 @@ public class NaryOp extends Hop {
                        case CBIND:
                                setDim1(HopRewriteUtils.getMaxInputDim(this, true));
                                setDim2(HopRewriteUtils.getSumValidInputDims(this, false));
+                               setNnz(HopRewriteUtils.getSumValidInputNnz(this)); //exact sum of input nnz, -1 if any unknown
                                break;
                        case RBIND:
                                setDim1(HopRewriteUtils.getSumValidInputDims(this, true));
                                setDim2(HopRewriteUtils.getMaxInputDim(this, false));
+                               setNnz(HopRewriteUtils.getSumValidInputNnz(this)); //exact sum of input nnz, -1 if any unknown
                                break;
                        case MIN:
                        case MAX:

http://git-wip-us.apache.org/repos/asf/systemml/blob/c61b94c9/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 7872c91..48b95cc 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -73,6 +73,7 @@ import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
 import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
@@ -1357,14 +1358,14 @@ public class HopRewriteUtils
        
        public static long getMaxInputDim(Hop hop, boolean dim1) {
                return hop.getInput().stream().mapToLong(
-                       h -> (dim1?h.getDim1():h.getDim2())).max().orElse(-1);
+                       h -> (dim1 ? h.getDim1() : h.getDim2())).max().orElse(-1); //-1 if the hop has no inputs
        }
        
        public static long getSumValidInputDims(Hop hop, boolean dim1) {
                if( !hasValidInputDims(hop, dim1) )
                        return -1;
                return hop.getInput().stream().mapToLong(
-                       h -> (dim1?h.getDim1():h.getDim2())).sum();
+                       h -> (dim1 ? h.getDim1() : h.getDim2())).sum(); //all dims known, checked above
        }
        
        public static boolean hasValidInputDims(Hop hop, boolean dim1) {
@@ -1372,6 +1373,44 @@ public class HopRewriteUtils
                        h -> dim1 ? h.rowsKnown() : h.colsKnown());
        }
        
+       public static long getSumValidInputNnz(Hop hop) { //sum of input nnz, -1 if any input nnz is unknown
+               if( !hasValidInputNnz(hop) )
+                       return -1;
+               return hop.getInput().stream().mapToLong(h -> h.getNnz()).sum();
+       }
+       
+       public static boolean hasValidInputNnz(Hop hop) { //true iff all input nnz are known (>= 0)
+               return hop.getInput().stream().allMatch(h -> h.getNnz() >= 0);
+       }
+       
+       public static long getMaxInputDim(MatrixCharacteristics[] mc, boolean dim1) { //max rows (dim1) or cols, -1 if no inputs
+               return Arrays.stream(mc).mapToLong(
+                       h -> (dim1 ? h.getRows() : h.getCols())).max().orElse(-1);
+       }
+       
+       public static long getSumValidInputDims(MatrixCharacteristics[] mc, boolean dim1) { //sum of rows/cols, -1 if any unknown
+               if( !hasValidInputDims(mc, dim1) )
+                       return -1;
+               return Arrays.stream(mc).mapToLong(
+                       h -> (dim1 ? h.getRows() : h.getCols())).sum();
+       }
+       
+       public static boolean hasValidInputDims(MatrixCharacteristics[] mc, boolean dim1) { //true iff all rows (dim1) or cols known
+               return Arrays.stream(mc).allMatch(
+                       h -> dim1 ? h.rowsKnown() : h.colsKnown());
+       }
+       
+       public static long getSumValidInputNnz(MatrixCharacteristics[] mc, boolean worstcase) { //-1 if not computable
+               if( !hasValidInputNnz(mc, worstcase) )
+                       return -1;
+               return Arrays.stream(mc).mapToLong(h -> h.nnzKnown() ? //unknown nnz (worstcase only): use length as upper bound
+                       h.getNonZeros() : h.getLength()).sum();
+       }
+       
+       public static boolean hasValidInputNnz(MatrixCharacteristics[] mc, boolean worstcase) { //worstcase also accepts known dims
+               return Arrays.stream(mc).allMatch(h -> h.nnzKnown() || (worstcase && h.dimsKnown()));
+       }
+       
        public static boolean containsSecondOrderBuiltin(ArrayList<Hop> roots) {
                Hop.resetVisitStatus(roots);
                return roots.stream().anyMatch(r -> containsSecondOrderBuiltin(r));

Reply via email to