[MINOR] Fix analysis of sparse-safeness for codegen cell/magg ops Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c70cb116 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c70cb116 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c70cb116
Branch: refs/heads/master Commit: c70cb1166f4ec6c79d10248727a3eb7b85f70360 Parents: 78a3808 Author: Matthias Boehm <mboe...@gmail.com> Authored: Sun Oct 22 18:57:35 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sun Oct 22 18:57:35 2017 -0700 ---------------------------------------------------------------------- .../apache/sysml/hops/codegen/template/TemplateCell.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c70cb116/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index c9b0734..4f3d4f4 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -322,10 +322,12 @@ public class TemplateCell extends TemplateBase protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, List<CNode> outputs, List<AggOp> aggOps, boolean onlySum) { boolean ret = true; for( int i=0; i<outputs.size() && ret; i++ ) { - ret &= (HopRewriteUtils.isBinary(roots.get(i), OpOp2.MULT) - && roots.get(i).getInput().contains(mainInput)) - || (HopRewriteUtils.isBinary(roots.get(i), OpOp2.DIV) - && roots.get(i).getInput().get(0) == mainInput) + Hop root = (roots.get(i) instanceof AggUnaryOp || roots.get(i) + instanceof AggBinaryOp) ? roots.get(i).getInput().get(0) : roots.get(i); + ret &= (HopRewriteUtils.isBinarySparseSafe(root) + && root.getInput().contains(mainInput)) + || (HopRewriteUtils.isBinary(root, OpOp2.DIV) + && root.getInput().get(0) == mainInput) || (TemplateUtils.rIsSparseSafeOnly(outputs.get(i), BinType.MULT) && TemplateUtils.rContainsInput(outputs.get(i), mainInput.getHopID())); if( onlySum )