Repository: systemml
Updated Branches:
  refs/heads/master c170374e7 -> ca5581fcc


[SYSTEMML-1443] Codegen constraint handling for distributed row ops

For distributed rowwise fused operators, the cost-based codegen plan
selector has to explicitly handle the conditional blocksize constraint
ncol(X) <= blocksize to guarantee that entire rows are available within
a single block. This constraint is conditional on the selected Spark
execution type, which in turn depends on the total size of the operator
inputs and output (and thus on fusion decisions). The cost-based plan
selector now applies best-effort pre-filtering of invalid partial row
plans. Additionally, any remaining invalid plans are pruned during cplan
cleanup, which guarantees valid runtime plans for all selection policies.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ca5581fc
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ca5581fc
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ca5581fc

Branch: refs/heads/master
Commit: ca5581fccd70a6ae974e29a9d11e6d4aafe971e4
Parents: c170374
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Aug 10 00:54:46 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Aug 10 00:54:46 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/OptimizerUtils.java   |   7 ++
 .../sysml/hops/codegen/SpoofCompiler.java       |  87 ++++++++------
 .../opt/PlanSelectionFuseCostBasedV2.java       | 116 ++++++++++++-------
 .../hops/codegen/template/CPlanMemoTable.java   |  17 ++-
 .../instructions/spark/SpoofSPInstruction.java  |   7 +-
 .../functions/codegen/AlgorithmGLM.java         |  60 ++++++++++
 .../functions/codegen/AlgorithmLinregCG.java    |  16 ++-
 7 files changed, 227 insertions(+), 83 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index 7f07cfc..a0a36d5 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysml.hops;
 
+import java.util.Arrays;
 import java.util.HashMap;
 
 import org.apache.commons.logging.Log;
@@ -769,6 +770,12 @@ public class OptimizerUtils
                return bsize;
        }
        
+       public static double getTotalMemEstimate(Hop[] in, Hop out) {
+               return Arrays.stream(in)
+                       .mapToDouble(h -> h.getOutputMemEstimate()).sum()
+                       + out.getOutputMemEstimate();
+       }
+       
        /**
         * Indicates if the given indexing range is block aligned, i.e., it 
does not require
         * global aggregation of blocks.

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index d5c9618..49a1686 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -35,6 +35,7 @@ import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeCell;
 import org.apache.sysml.hops.codegen.cplan.CNodeData;
@@ -354,13 +355,26 @@ public class SpoofCompiler
                        //context-sensitive literal replacement (only integers 
during recompile)
                        boolean compileLiterals = 
(PLAN_CACHE_POLICY==PlanCachePolicy.CONSTANT) || !recompile;
                        
-                       //construct codegen plans
-                       HashMap<Long, Pair<Hop[],CNodeTpl>>  cplans = 
constructCPlans(roots, compileLiterals);
+                       //candidate exploration of valid partial fusion plans
+                       CPlanMemoTable memo = new CPlanMemoTable();
+                       for( Hop hop : roots )
+                               rExploreCPlans(hop, memo, compileLiterals);
+                       
+                       //candidate selection of optimal fusion plan
+                       memo.pruneSuboptimal(roots);
+                       
+                       //construct actual cplan representations
+                       //note: we do not use the hop visit status due to jumps 
over fused operators which would
+                       //corrupt subsequent resets, leaving partial hops dags 
in visited status
+                       HashMap<Long, Pair<Hop[],CNodeTpl>> cplans = new 
LinkedHashMap<>();
+                       HashSet<Long> visited = new HashSet<Long>();
+                       for( Hop hop : roots )
+                               rConstructCPlans(hop, memo, cplans, 
compileLiterals, visited);
                        
                        //cleanup codegen plans (remove unnecessary inputs, fix 
hop-cnodedata mapping,
                        //remove empty templates with single cnodedata input, 
remove spurious lookups,
                        //perform common subexpression elimination)
-                       cplans = cleanupCPlans(cplans);
+                       cplans = cleanupCPlans(memo, cplans);
                        
                        //explain before modification
                        if( LOG.isTraceEnabled() && !cplans.isEmpty() ) { 
//existing cplans
@@ -476,27 +490,6 @@ public class SpoofCompiler
        
        ////////////////////
        // Codegen plan construction
-
-       private static HashMap<Long, Pair<Hop[],CNodeTpl>> 
constructCPlans(ArrayList<Hop> roots, boolean compileLiterals) throws 
DMLException
-       {
-               //explore cplan candidates
-               CPlanMemoTable memo = new CPlanMemoTable();
-               for( Hop hop : roots )
-                       rExploreCPlans(hop, memo, compileLiterals);
-               
-               //select optimal cplan candidates
-               memo.pruneSuboptimal(roots);
-               
-               //construct actual cplan representations
-               //note: we do not use the hop visit status due to jumps over 
fused operators which would
-               //corrupt subsequent resets, leaving partial hops dags in 
visited status
-               LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> ret = new 
LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>();
-               HashSet<Long> visited = new HashSet<Long>();
-               for( Hop hop : roots )
-                       rConstructCPlans(hop, memo, ret, compileLiterals, 
visited);
-               
-               return ret;
-       }
        
        private static void rExploreCPlans(Hop hop, CPlanMemoTable memo, 
boolean compileLiterals) 
                throws DMLException
@@ -664,9 +657,10 @@ public class SpoofCompiler
         * during incremental construction. This is important as it avoids 
unnecessary 
         * redundant computation. 
         * 
+        * @param memo memoization table
         * @param cplans set of cplans
         */
-       private static HashMap<Long, Pair<Hop[],CNodeTpl>> 
cleanupCPlans(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) 
+       private static HashMap<Long, Pair<Hop[],CNodeTpl>> 
cleanupCPlans(CPlanMemoTable memo, HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) 
        {
                HashMap<Long, Pair<Hop[],CNodeTpl>> cplans2 = new HashMap<Long, 
Pair<Hop[],CNodeTpl>>();
                CPlanCSERewriter cse = new CPlanCSERewriter();
@@ -711,24 +705,51 @@ public class SpoofCompiler
                        else
                                rFindAndRemoveLookup(tpl.getOutput(), in1);
                        
-                       //remove invalid row templates (e.g., due to partial 
unknowns)
-                       if( tpl instanceof CNodeRow && (in1.getNumCols() == 1
-                               || (((CNodeRow)tpl).getRowType()==RowType.NO_AGG
-                                       && 
tpl.getOutput().getDataType().isScalar())) )
-                               cplans2.remove(e.getKey());
+                       //remove invalid row templates (e.g., unsatisfied 
blocksize constraint)
+                       if( tpl instanceof CNodeRow ) {
+                               //check for invalid row cplan over column vector
+                               if(in1.getNumCols() == 1 || 
(((CNodeRow)tpl).getRowType()==RowType.NO_AGG
+                                       && 
tpl.getOutput().getDataType().isScalar()) ) {
+                                       cplans2.remove(e.getKey());
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Removed invalid row 
cplan w/o agg on column vector.");
+                               }
+                               else if( OptimizerUtils.isSparkExecutionMode() 
) {
+                                       boolean isSpark = DMLScript.rtplatform 
== RUNTIME_PLATFORM.SPARK
+                                               || 
OptimizerUtils.getTotalMemEstimate(inHops, memo.getHopRefs().get(e.getKey()))
+                                                       > 
OptimizerUtils.getLocalMemBudget();
+                                       boolean invalidNcol = false;
+                                       for( Hop in : inHops )
+                                               invalidNcol |= 
(in.getDataType().isMatrix() 
+                                                       && in.getDim2() > 
in.getColsInBlock());
+                                       if( isSpark && invalidNcol ) {
+                                               cplans2.remove(e.getKey());
+                                               if( LOG.isTraceEnabled() )
+                                                       LOG.trace("Removed 
invalid row cplan w/ ncol>ncolpb.");         
+                                       }
+                               }
+                       }
                        
                        //remove cplan w/ single op and w/o agg
                        if( (tpl instanceof CNodeCell && 
((CNodeCell)tpl).getCellType()==CellType.NO_AGG
                                        && 
TemplateUtils.hasSingleOperation(tpl) )
                                || (tpl instanceof CNodeRow && 
(((CNodeRow)tpl).getRowType()==RowType.NO_AGG
-                                       || 
((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1)
+                                       || 
((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1
+                                       || 
((CNodeRow)tpl).getRowType()==RowType.ROW_AGG )
                                        && 
TemplateUtils.hasSingleOperation(tpl))
                                || TemplateUtils.hasNoOperation(tpl) ) 
+                       {       
                                cplans2.remove(e.getKey());
-                               
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Removed cplan with single 
operation.");
+                       }
+                       
                        //remove cplan if empty
-                       if( tpl.getOutput() instanceof CNodeData )
+                       if( tpl.getOutput() instanceof CNodeData ) {
                                cplans2.remove(e.getKey());
+                               if( LOG.isTraceEnabled() )
+                                       LOG.trace("Removed empty cplan.");
+                       }
                        
                        //rename inputs (for codegen and plan caching)
                        tpl.renameInputs();

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 717a059..e66c9c3 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -36,6 +36,7 @@ import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
@@ -44,6 +45,7 @@ import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.ReorgOp;
 import org.apache.sysml.hops.TernaryOp;
@@ -122,37 +124,8 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
        
        private void selectPlans(CPlanMemoTable memo, PlanPartition part) 
        {
-               //prune row aggregates with pure cellwise operations
-               for( Long hopID : part.getRoots() ) {
-                       MemoTableEntry me = memo.getBest(hopID, 
TemplateType.ROW);
-                       if( me.type == TemplateType.ROW && memo.contains(hopID, 
TemplateType.CELL)
-                               && isRowTemplateWithoutAgg(memo, 
memo.getHopRefs().get(hopID), new HashSet<Long>())) {
-                               List<MemoTableEntry> blacklist = 
memo.get(hopID, TemplateType.ROW); 
-                               memo.remove(memo.getHopRefs().get(hopID), new 
HashSet<MemoTableEntry>(blacklist));
-                               if( LOG.isTraceEnabled() ) {
-                                       LOG.trace("Removed row memo table 
entries w/o aggregation: "
-                                               + 
Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
-                               }
-                       }
-               }
-               
-               //prune suboptimal outer product plans that are dominated by 
outer product plans w/ same number of 
-               //references but better fusion properties (e.g., for the 
patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), 
-               //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this 
would unnecessarily destroy a fusion pattern.
-               for( Long hopID : part.getPartition() ) {
-                       if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) 
{
-                               List<MemoTableEntry> entries = memo.get(hopID, 
TemplateType.OUTER);
-                               MemoTableEntry me1 = entries.get(0);
-                               MemoTableEntry me2 = entries.get(1);
-                               MemoTableEntry rmEntry = 
TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
-                               if( rmEntry != null ) {
-                                       
memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
-                                       
memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
-                                       if( LOG.isTraceEnabled() )
-                                               LOG.trace("Removed dominated 
outer product memo table entry: " + rmEntry);
-                               }
-                       }
-               }
+               //prune special case patterns and invalid plans (e.g., 
blocksize)
+               pruneInvalidAndSpecialCasePlans(memo, part);
                
                //if no materialization points, use basic fuse-all w/ partition 
awareness
                if( part.getMatPointsExt() == null || 
part.getMatPointsExt().length==0 ) {
@@ -163,8 +136,8 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                else {
                        //obtain hop compute costs per cell once
                        HashMap<Long, Double> computeCosts = new HashMap<Long, 
Double>();
-                       for( Long hopID : part.getRoots() )
-                               rGetComputeCosts(memo.getHopRefs().get(hopID), 
part.getPartition(), computeCosts);
+                       for( Long hopID : part.getPartition() )
+                               getComputeCosts(memo.getHopRefs().get(hopID), 
computeCosts);
                        
                        //prepare pruning helpers and prune memo table w/ 
determined mat points
                        StaticCosts costs = new StaticCosts(computeCosts, 
getComputeCost(computeCosts, memo), 
@@ -595,7 +568,7 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                boolean ret = true;
                MemoTableEntry me = memo.getBest(current.getHopID(), 
TemplateType.ROW);
                for(int i=0; i<3; i++)
-                       if( me.isPlanRef(i) )
+                       if( me!=null && me.isPlanRef(i) )
                                ret &= rIsRowTemplateWithoutAgg(memo, 
current.getInput().get(i), visited);
                ret &= !(current instanceof AggUnaryOp || current instanceof 
AggBinaryOp);
                
@@ -603,6 +576,69 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                return ret;
        }
        
+       private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable 
memo, PlanPartition part) 
+       {       
+               //prune invalid row entries w/ violated blocksize constraint
+               if( OptimizerUtils.isSparkExecutionMode() ) {
+                       for( Long hopID : part.getPartition() ) {
+                               if( !memo.contains(hopID, TemplateType.ROW) )
+                                       continue;
+                               Hop hop = memo.getHopRefs().get(hopID);
+                               boolean isSpark = DMLScript.rtplatform == 
RUNTIME_PLATFORM.SPARK
+                                       || 
OptimizerUtils.getTotalMemEstimate(hop.getInput().toArray(new Hop[0]), hop)
+                                               > 
OptimizerUtils.getLocalMemBudget();
+                               boolean validNcol = true;
+                               for( Hop in : hop.getInput() )
+                                       validNcol &= in.getDataType().isScalar()
+                                               || (in.getDim2() <= 
in.getColsInBlock())
+                                               || (hop instanceof AggBinaryOp 
&& in.getDim1() <= in.getRowsInBlock()
+                                               && 
HopRewriteUtils.isTransposeOperation(in));
+                               if( isSpark && !validNcol ) {
+                                       List<MemoTableEntry> blacklist = 
memo.get(hopID, TemplateType.ROW);
+                                       
memo.remove(memo.getHopRefs().get(hopID), new 
HashSet<MemoTableEntry>(blacklist));
+                                       if( !memo.contains(hopID) )
+                                               memo.removeAllRefTo(hopID);
+                                       if( LOG.isTraceEnabled() ) {
+                                               LOG.trace("Removed row memo 
table entries w/ violated blocksize constraint ("+hopID+"): "
+                                                       + 
Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+                                       }
+                               }
+                       }
+               }
+               
+               //prune row aggregates with pure cellwise operations
+               for( Long hopID : part.getPartition() ) {
+                       MemoTableEntry me = memo.getBest(hopID, 
TemplateType.ROW);
+                       if( me != null && me.type == TemplateType.ROW && 
memo.contains(hopID, TemplateType.CELL)
+                               && isRowTemplateWithoutAgg(memo, 
memo.getHopRefs().get(hopID), new HashSet<Long>())) {
+                               List<MemoTableEntry> blacklist = 
memo.get(hopID, TemplateType.ROW); 
+                               memo.remove(memo.getHopRefs().get(hopID), new 
HashSet<MemoTableEntry>(blacklist));
+                               if( LOG.isTraceEnabled() ) {
+                                       LOG.trace("Removed row memo table 
entries w/o aggregation: "
+                                               + 
Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+                               }
+                       }
+               }
+               
+               //prune suboptimal outer product plans that are dominated by 
outer product plans w/ same number of 
+               //references but better fusion properties (e.g., for the 
patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), 
+               //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this 
would unnecessarily destroy a fusion pattern.
+               for( Long hopID : part.getPartition() ) {
+                       if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) 
{
+                               List<MemoTableEntry> entries = memo.get(hopID, 
TemplateType.OUTER);
+                               MemoTableEntry me1 = entries.get(0);
+                               MemoTableEntry me2 = entries.get(1);
+                               MemoTableEntry rmEntry = 
TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
+                               if( rmEntry != null ) {
+                                       
memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
+                                       
memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
+                                       if( LOG.isTraceEnabled() )
+                                               LOG.trace("Removed dominated 
outer product memo table entry: " + rmEntry);
+                               }
+                       }
+               }
+       }
+       
        private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop 
current, HashSet<Long> visited, 
                PlanPartition part, InterestingPoint[] matPoints, boolean[] 
plan) 
        {
@@ -751,7 +787,7 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                //open template if necessary, including memoization
                //under awareness of current plan choice
                MemoTableEntry best = null;
-               boolean opened = false;
+               boolean opened = (currentType == null);
                if( memo.contains(current.getHopID()) ) {
                        //note: this is the inner loop of plan enumeration and 
hence, we do not 
                        //use streams, lambda expressions, etc to avoid 
unnecessary overhead
@@ -836,16 +872,8 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                return costs;
        }
        
-       private static void rGetComputeCosts(Hop current, HashSet<Long> 
partition, HashMap<Long, Double> computeCosts) 
+       private static void getComputeCosts(Hop current, HashMap<Long, Double> 
computeCosts) 
        {
-               if( computeCosts.containsKey(current.getHopID()) 
-                       || !partition.contains(current.getHopID()) )
-                       return;
-               
-               //recursively process children
-               for( Hop c : current.getInput() )
-                       rGetComputeCosts(c, partition, computeCosts);
-               
                //get costs for given hop
                double costs = 1;
                if( current instanceof UnaryOp ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index 4078060..4adec25 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -77,7 +77,8 @@ public class CPlanMemoTable
        }
        
        public boolean contains(long hopID) {
-               return _plans.containsKey(hopID);
+               return _plans.containsKey(hopID)
+                       && !_plans.get(hopID).isEmpty();
        }
        
        public boolean contains(long hopID, TemplateType type) {
@@ -151,6 +152,17 @@ public class CPlanMemoTable
                        .removeIf(p -> blackList.contains(p));
        }
        
+       public void removeAllRefTo(long hopID) {
+               //recursive removal of references
+               for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() ) {
+                       if( !e.getValue().isEmpty() ) {
+                               e.getValue().removeIf(p -> 
p.hasPlanRefTo(hopID));
+                               if( e.getValue().isEmpty() )
+                                       removeAllRefTo(e.getKey());
+                       }
+               }
+       }
+       
        public void setDistinct(long hopID, List<MemoTableEntry> plans) {
                _plans.put(hopID, plans.stream()
                        .distinct().collect(Collectors.toList()));
@@ -354,6 +366,9 @@ public class CPlanMemoTable
                public boolean hasPlanRef() {
                        return isPlanRef(0) || isPlanRef(1) || isPlanRef(2);
                }
+               public boolean hasPlanRefTo(long hopID) {
+                       return (input1==hopID || input2==hopID || 
input3==hopID); 
+               }
                public int countPlanRefs() {
                        return ((input1 >= 0) ? 1 : 0)
                                +  ((input2 >= 0) ? 1 : 0)

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
index eae5560..90e2184 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -203,8 +203,13 @@ public class SpoofSPInstruction extends SPInstruction
                        }
                }
                else if( _class.getSuperclass() == SpoofRowwise.class ) { //row 
aggregate operator
+                       if( mcIn.getCols() > mcIn.getColsPerBlock() ) {
+                               throw new DMLRuntimeException("Invalid spark 
rowwise operator w/ ncol=" + 
+                                       mcIn.getCols()+", 
ncolpb="+mcIn.getColsPerBlock()+".");
+                       }
                        SpoofRowwise op = (SpoofRowwise) 
CodegenUtils.createInstance(_class);   
-                       RowwiseFunction fmmc = new 
RowwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars, 
(int)mcIn.getCols());
+                       RowwiseFunction fmmc = new 
RowwiseFunction(_class.getName(),
+                               _classBytes, bcMatrices, scalars, 
(int)mcIn.getCols());
                        out = in.mapPartitionsToPair(fmmc, 
op.getRowType()==RowType.ROW_AGG
                                        || op.getRowType() == RowType.NO_AGG);
                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java
index 803ec93..a48c84c 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmGLM.java
@@ -123,6 +123,66 @@ public class AlgorithmGLM extends AutomatedTestBase
                runGLMTest(GLMType.BINOMIAL_PROBIT, false, true, ExecType.CP);
        }
        
+       @Test
+       public void testGLMPoissonDenseRewritesSP() {
+               runGLMTest(GLMType.POISSON_LOG, true, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMPoissonSparseRewritesSP() {
+               runGLMTest(GLMType.POISSON_LOG, true, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMPoissonDenseSP() {
+               runGLMTest(GLMType.POISSON_LOG, false, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMPoissonSparseSP() {
+               runGLMTest(GLMType.POISSON_LOG, false, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMGammaDenseRewritesSP() {
+               runGLMTest(GLMType.GAMMA_LOG, true, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMGammaSparseRewritesSP() {
+               runGLMTest(GLMType.GAMMA_LOG, true, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMGammaDenseSP() {
+               runGLMTest(GLMType.GAMMA_LOG, false, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMGammaSparseSP() {
+               runGLMTest(GLMType.GAMMA_LOG, false, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMBinomialDenseRewritesSP() {
+               runGLMTest(GLMType.BINOMIAL_PROBIT, true, false, 
ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMBinomialSparseRewritesSP() {
+               runGLMTest(GLMType.BINOMIAL_PROBIT, true, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMBinomialDenseSP() {
+               runGLMTest(GLMType.BINOMIAL_PROBIT, false, false, 
ExecType.SPARK);
+       }
+       
+       @Test
+       public void testGLMBinomialSparseSP() {
+               runGLMTest(GLMType.BINOMIAL_PROBIT, false, true, 
ExecType.SPARK);
+       }
+       
        private void runGLMTest( GLMType type, boolean rewrites, boolean 
sparse, ExecType instType)
        {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca5581fc/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
index 729699f..80e4b9f 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
@@ -79,17 +79,25 @@ public class AlgorithmLinregCG extends AutomatedTestBase
                runLinregCGTest(TEST_NAME1, false, true, ExecType.CP);
        }
 
-       /*
+       @Test
+       public void testLinregCGDenseRewritesSP() {
+               runLinregCGTest(TEST_NAME1, true, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testLinregCGSparseRewritesSP() {
+               runLinregCGTest(TEST_NAME1, true, true, ExecType.SPARK);
+       }
+       
        @Test
        public void testLinregCGDenseSP() {
-               runGDFOTest(TEST_NAME1, false, ExecType.SPARK);
+               runLinregCGTest(TEST_NAME1, false, false, ExecType.SPARK);
        }
        
        @Test
        public void testLinregCGSparseSP() {
-               runGDFOTest(TEST_NAME1, true, ExecType.SPARK);
+               runLinregCGTest(TEST_NAME1, false, true, ExecType.SPARK);
        }
-       */
        
        private void runLinregCGTest( String testname, boolean rewrites, 
boolean sparse, ExecType instType)
        {

Reply via email to