This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 811e3f474c [SYSTEMDS-3334] Codegen RowMaxs_VectMult rewrite
811e3f474c is described below

commit 811e3f474c7e4e1747e7b5e54ffa75e79afc1cd5
Author: Mark Dokter <[email protected]>
AuthorDate: Tue Apr 19 23:16:52 2022 +0200

    [SYSTEMDS-3334] Codegen RowMaxs_VectMult rewrite
    
    This rewrite fuses a vector multiplication with a row-max aggregation to 
avoid an intermediate vector in Spoof's row template. The pattern occurs when 
using codegen in components.dml.
    
    Closes #1566
---
 .../apache/sysds/hops/codegen/SpoofCompiler.java   | 15 ++++++++------
 .../sysds/hops/codegen/cplan/CNodeBinary.java      |  7 ++++++-
 .../sysds/hops/codegen/cplan/java/Binary.java      |  3 +++
 .../hops/codegen/template/CPlanOpRewriter.java     | 19 +++++++++++++++--
 .../sysds/hops/codegen/template/TemplateUtils.java | 24 ++++++++++++++++++++--
 .../sysds/runtime/codegen/LibSpoofPrimitives.java  | 18 ++++++++++++++--
 6 files changed, 73 insertions(+), 13 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index ade88775e1..55d75b092a 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -38,6 +38,7 @@ import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.codegen.cplan.CNode;
+import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysds.hops.codegen.cplan.CNodeCell;
 import org.apache.sysds.hops.codegen.cplan.CNodeData;
 import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
@@ -941,13 +942,15 @@ public class SpoofCompiler {
                        }
                        
                        //remove cplan w/ single op and w/o agg
-                       if( (tpl instanceof CNodeCell && 
((CNodeCell)tpl).getCellType()==CellType.NO_AGG
-                                       && 
TemplateUtils.hasSingleOperation(tpl) )
-                               || (tpl instanceof CNodeRow && 
(((CNodeRow)tpl).getRowType()==RowType.NO_AGG
-                                       || 
((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1
-                                       || 
((CNodeRow)tpl).getRowType()==RowType.ROW_AGG )
+                       if((tpl instanceof CNodeCell && 
((CNodeCell)tpl).getCellType()==CellType.NO_AGG
                                        && 
TemplateUtils.hasSingleOperation(tpl))
-                               || TemplateUtils.hasNoOperation(tpl) ) 
+                               || (tpl instanceof CNodeRow
+                                       && 
(((CNodeRow)tpl).getRowType()==RowType.NO_AGG
+                                               || 
((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1
+                                               || 
(((CNodeRow)tpl).getRowType()==RowType.ROW_AGG  && 
!TemplateUtils.isBinary(tpl.getOutput(),
+                                                       
CNodeBinary.BinType.ROWMAXS_VECTMULT)))
+                                       && 
TemplateUtils.hasSingleOperation(tpl))
+                               || TemplateUtils.hasNoOperation(tpl))
                        {
                                cplans2.remove(e.getKey());
                                if( LOG.isTraceEnabled() )
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
index 2e6bcd5d48..bebf0a221b 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
@@ -30,6 +30,8 @@ import 
org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
 public class CNodeBinary extends CNode {
 
        public enum BinType {
+               // Fused vect_op + aggregation
+               ROWMAXS_VECTMULT,
                //matrix multiplication operations
                DOT_PRODUCT, VECT_MATRIXMULT, VECT_OUTERMULT_ADD,
                //vector-scalar-add operations
@@ -373,7 +375,8 @@ public class CNodeBinary extends CNode {
                                _cols = _inputs.get(1)._cols;
                                _dataType = DataType.MATRIX;
                                break;
-                       
+
+                       case ROWMAXS_VECTMULT:
                        case DOT_PRODUCT:
                        
                        //SCALAR Arithmetic
@@ -407,6 +410,8 @@ public class CNodeBinary extends CNode {
                                _cols = 0;
                                _dataType= DataType.SCALAR;
                                break;
+                       default:
+                                       throw new RuntimeException("Unknown 
CNodeBinary type: " + _type);
                }
        }
        
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
index ecb7878f66..40496249e5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
@@ -28,6 +28,9 @@ public class Binary extends CodeTemplate {
                boolean scalarVector, boolean scalarInput, boolean vectorVector)
        {
                switch (type) {
+                       case ROWMAXS_VECTMULT:
+                               return sparseLhs ? "\tdouble %TMP% = 
LibSpoofPrimitives.rowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, 
alen);\n" :
+                                               "\tdouble %TMP% = 
LibSpoofPrimitives.rowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
                        case DOT_PRODUCT:
                                return sparseLhs ? "    double %TMP% = 
LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
                                                "    double %TMP% = 
LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java 
b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
index 2b981ee893..b81ddac401 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
@@ -21,11 +21,14 @@ package org.apache.sysds.hops.codegen.template;
 
 import java.util.ArrayList;
 
+import org.apache.spark.sql.types.BinaryType;
 import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.codegen.cplan.CNode;
+import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysds.hops.codegen.cplan.CNodeData;
 import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
 import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct;
+import org.apache.sysds.hops.codegen.cplan.CNodeRow;
 import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
@@ -56,6 +59,9 @@ public class CPlanOpRewriter
                }
                else {
                        tpl.setOutput(rSimplifyCNode(tpl.getOutput()));
+                       if(TemplateUtils.containsFusedRowVecAgg(tpl)) {
+                               ((CNodeRow) 
tpl).setNumVectorIntermediates(((CNodeRow) tpl).getNumVectorIntermediates()-2);
+                       }
                }
                
                return tpl;
@@ -73,10 +79,19 @@ public class CPlanOpRewriter
                node = rewriteBinaryPow2Vect(node);  //X^2 -> X*X
                node = rewriteBinaryMult2(node);     //x*2 -> x+x;
                node = rewriteBinaryMult2Vect(node); //X*2 -> X+X;
-               
+               node = rewriteRowMaxsVectMult(node); // rowMaxs(G * t(c)); see 
components.dml
                return node;
        }
-       
+
+       private static CNode rewriteRowMaxsVectMult(CNode node) {
+               if(TemplateUtils.isUnary(node, UnaryType.ROW_MAXS)) {
+                       CNode input = node.getInput().get(0);
+                       if(TemplateUtils.isBinary(input, BinType.VECT_MULT))
+                               return new CNodeBinary(input.getInput().get(0), 
input.getInput().get(1), BinType.ROWMAXS_VECTMULT);
+               }
+               return node;
+       }
+
        private static CNode rewriteRowCountNnz(CNode node) {
                return (TemplateUtils.isUnary(node, UnaryType.ROW_SUMS)
                        && TemplateUtils.isBinary(node.getInput().get(0), 
BinType.VECT_NOTEQUAL_SCALAR)
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
index f61305fa13..8a4e0f62c9 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
@@ -49,6 +49,7 @@ import org.apache.sysds.hops.codegen.cplan.CNode;
 import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysds.hops.codegen.cplan.CNodeData;
 import org.apache.sysds.hops.codegen.cplan.CNodeNary;
+import org.apache.sysds.hops.codegen.cplan.CNodeRow;
 import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
 import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
@@ -279,7 +280,11 @@ public class TemplateUtils
                return node instanceof CNodeUnary
                        && ArrayUtils.contains(types, 
((CNodeUnary)node).getType());
        }
-       
+
+       public static boolean isUnaryRowAgg(CNode node) {
+               return isUnary(node, UnaryType.ROW_MAXS, UnaryType.ROW_SUMS);
+       }
+
        public static boolean isBinary(CNode node, BinType...types) {
                return node instanceof CNodeBinary
                        && ArrayUtils.contains(types, 
((CNodeBinary)node).getType());
@@ -391,7 +396,8 @@ public class TemplateUtils
                                && !TemplateUtils.isUnary(output, 
                                        UnaryType.EXP, UnaryType.LOG, 
UnaryType.ROW_COUNTNNZS)) 
                        || (output instanceof CNodeBinary
-                               && !TemplateUtils.isBinary(output, 
BinType.VECT_OUTERMULT_ADD))
+                               && (!(TemplateUtils.isBinary(output, 
BinType.VECT_OUTERMULT_ADD) ||
+                                       !TemplateUtils.isBinary(output, 
BinType.ROWMAXS_VECTMULT))))
                        || output instanceof CNodeTernary 
                                && ((CNodeTernary)output).getType() == 
TernaryType.IFELSE)
                        && hasOnlyDataNodeOrLookupInputs(output);
@@ -687,4 +693,18 @@ public class TemplateUtils
                for( CNode input : current.getInput() )
                        rFlipVectorLookups(input);
        }
+
+       public static boolean containsFusedRowVecAgg(CNodeTpl tpl) {
+               if(!(tpl instanceof CNodeRow))
+                       return false;
+
+               if(TemplateUtils.isBinary(tpl.getOutput(), 
BinType.ROWMAXS_VECTMULT))
+                       return true;
+
+               for (CNode n : tpl.getOutput().getInput()) {
+                       if(TemplateUtils.isBinary(n, BinType.ROWMAXS_VECTMULT))
+                               return true;
+               }
+               return false;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java 
b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
index 905b39226d..c618e79607 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
@@ -50,9 +50,23 @@ public class LibSpoofPrimitives
        private static ThreadLocal<VectorBuffer> memPool = new 
ThreadLocal<VectorBuffer>() {
                @Override protected VectorBuffer initialValue() { return new 
VectorBuffer(0,0,0); }
        };
-       
+
+       public static double rowMaxsVectMult(double[] a, double[] b, int ai, 
int bi, int len) {
+               double val = Double.NEGATIVE_INFINITY;
+               int j=0;
+               for( int i = ai; i < ai+len; i++ )
+                       val = Math.max(a[i]*b[j++], val);
+               return val;
+       }
+
+       public static double rowMaxsVectMult(double[] a, double[] b, int[] aix, 
int ai, int bi, int len) {
+               double val = Double.NEGATIVE_INFINITY;
+               for( int i = ai; i < ai+len; i++ )
+                       val = Math.max(a[i]*b[aix[i]], val);
+               return val;
+       }
+
        // forwarded calls to LibMatrixMult
-       
        public static double dotProduct(double[] a, double[] b, int ai, int bi, 
int len) {
                if( a == null || b == null ) return 0;
                return LibMatrixMult.dotProduct(a, b, ai, bi, len);

Reply via email to