Repository: systemml
Updated Branches:
  refs/heads/master 9d5918743 -> 14b4d5487


[SYSTEMML-1687] Worst-case size propagation for fused codegen operators

This patch extends the generic codegen hop with the ability to propagate
worst-case size information according to its output dimension type.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14b4d548
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14b4d548
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14b4d548

Branch: refs/heads/master
Commit: 14b4d548723a01eb58047993153fa4845ec1450b
Parents: 9d59187
Author: Matthias Boehm <[email protected]>
Authored: Tue Jun 13 13:00:58 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Jun 13 13:00:58 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/hops/codegen/SpoofFusedOp.java | 64 +++++++++++++++++---
 1 file changed, 57 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/14b4d548/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
index 9f426f6..06be99b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
@@ -32,6 +32,7 @@ import org.apache.sysml.lops.LopsException;
 import org.apache.sysml.lops.SpoofFused;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 
 public class SpoofFusedOp extends Hop implements MultiThreadedHop
 {
@@ -90,12 +91,7 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
        protected double computeIntermediateMemEstimate(long dim1, long dim2, 
long nnz) {
                return 0;
        }
-
-       @Override
-       protected long[] inferOutputCharacteristics(MemoTable memo) {
-               return null;
-       }
-
+       
        @Override
        public Lop constructLops() throws HopsException, LopsException {
                if( getLops() != null )
@@ -140,7 +136,61 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
        public String getOpString() {
                return "spoof("+_class.getSimpleName()+")";
        }
-
+       
+       @Override
+       protected long[] inferOutputCharacteristics( MemoTable memo )
+       {
+               long[] ret = null; //stays null if the required input dims are unknown
+       
+               //get statistics of main input (output dims derive from it per _dimsType)
+               MatrixCharacteristics mc = memo.getAllInputStats(getInput().get(0));
+               //ret layout presumably {dim1, dim2, nnz}; the nnz entry is always -1 (unknown)
+               if( mc.dimsKnown() ) {
+                       switch(_dimsType)
+                       {
+                               case ROW_DIMS: //one value per input row
+                                       ret = new long[]{mc.getRows(), 1, -1};
+                                       break;
+                               case ROW_DIMS2: //two values per input row
+                                       ret = new long[]{mc.getRows(), 2, -1};
+                                       break;
+                               case COLUMN_DIMS_ROWS: //one value per input column, as column vector
+                                       ret = new long[]{mc.getCols(), 1, -1};
+                                       break;
+                               case COLUMN_DIMS_COLS: //one value per input column, as row vector
+                                       ret = new long[]{1, mc.getCols(), -1};
+                                       break;
+                               case INPUT_DIMS: //same dims as the main input
+                                       ret = new long[]{mc.getRows(), mc.getCols(), -1};
+                                       break;
+                               case SCALAR:
+                                       ret = new long[]{0, 0, -1};
+                                       break;
+                               case MULTI_SCALAR:
+                                       //dim2 is not derived from inputs but set statically from outside
+                                       ret = new long[]{1, _dim2, -1};
+                                       break;
+                               case ROW_RANK_DIMS: { //rows of main input x cols of second input
+                                       MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1));
+                                       if( mc2.dimsKnown() )
+                                               ret = new long[]{mc.getRows(), mc2.getCols(), -1};
+                                       break;
+                               }
+                               case COLUMN_RANK_DIMS: { //cols of main input x cols of second input
+                                       MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1));
+                                       if( mc2.dimsKnown() )
+                                               ret = new long[]{mc.getCols(), mc2.getCols(), -1};
+                                       break;
+                               }
+                               default:
+                                       throw new RuntimeException("Failed to infer worst-case size information "
+                                                       + "for type: "+_dimsType.toString());
+                       }
+               }
+               
+               return ret;
+       }
+       
        @Override
        public void refreshSizeInformation() {
                switch(_dimsType)

Reply via email to