Repository: systemml Updated Branches: refs/heads/master 5a155f3d2 -> 5069f9781
[SYSTEMML-2117] Fix missing codegen row vector pow2/mult2 rewrites This patch adds the missing codegen plan rewrites for vector operations, pow(X,2) to pow2(X) and mult(X,2) to mult2(X), which were so far only applied to scalar operations. Furthermore, it removes imports and a log instance that became unnecessary after a recent cleanup of the validation phase. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5069f978 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5069f978 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5069f978 Branch: refs/heads/master Commit: 5069f978174640903740cd3daf3430dd7b4ad7b8 Parents: 5a155f3 Author: Matthias Boehm <[email protected]> Authored: Fri May 18 17:28:34 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri May 18 17:28:34 2018 -0700 ---------------------------------------------------------------------- .../hops/codegen/template/CPlanOpRewriter.java | 25 ++++++++++++++++---- .../org/apache/sysml/parser/DMLProgram.java | 4 ---- 2 files changed, 21 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/5069f978/src/main/java/org/apache/sysml/hops/codegen/template/CPlanOpRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanOpRewriter.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanOpRewriter.java index 8ec750c..060e1aa 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanOpRewriter.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanOpRewriter.java @@ -67,10 +67,13 @@ public class CPlanOpRewriter node.getInput().set(i, rSimplifyCNode(node.getInput().get(i))); //apply all node-local simplification rewrites - node = rewriteRowCountNnz(node); //rowSums(X!=0) -> rowNnz(X) - node = 
rewriteRowSumSq(node); //rowSums(X^2) -> rowSumSqs(X) - node = rewriteBinaryPow2(node); //x^2 -> x*x - node = rewriteBinaryMult2(node); //x*2 -> x+x; + node = rewriteRowCountNnz(node); //rowSums(X!=0) -> rowNnz(X) + node = rewriteRowSumSq(node); //rowSums(X^2) -> rowSumSqs(X) + node = rewriteBinaryPow2(node); //x^2 -> x*x + node = rewriteBinaryPow2Vect(node); //X^2 -> X*X + node = rewriteBinaryMult2(node); //x*2 -> x+x; + node = rewriteBinaryMult2Vect(node); //X*2 -> X+X; + return node; } @@ -97,6 +100,13 @@ public class CPlanOpRewriter new CNodeUnary(node.getInput().get(0), UnaryType.POW2) : node; } + private static CNode rewriteBinaryPow2Vect(CNode node) { + return (TemplateUtils.isBinary(node, BinType.VECT_POW_SCALAR) + && node.getInput().get(1).isLiteral() + && node.getInput().get(1).getVarname().equals("2")) ? + new CNodeUnary(node.getInput().get(0), UnaryType.VECT_POW2) : node; + } + private static CNode rewriteBinaryMult2(CNode node) { return (TemplateUtils.isBinary(node, BinType.MULT) && node.getInput().get(1).isLiteral() @@ -104,6 +114,13 @@ public class CPlanOpRewriter new CNodeUnary(node.getInput().get(0), UnaryType.MULT2) : node; } + private static CNode rewriteBinaryMult2Vect(CNode node) { + return (TemplateUtils.isBinary(node, BinType.VECT_MULT) + && node.getInput().get(1).isLiteral() + && node.getInput().get(1).getVarname().equals("2")) ? 
+ new CNodeUnary(node.getInput().get(0), UnaryType.VECT_MULT2) : node; + } + private static CNodeTpl rewriteRemoveOuterNeq0(CNodeTpl tpl) { if( tpl instanceof CNodeOuterProduct ) rFindAndRemoveBinaryMS(tpl.getOutput(), (CNodeData) http://git-wip-us.apache.org/repos/asf/systemml/blob/5069f978/src/main/java/org/apache/sysml/parser/DMLProgram.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLProgram.java b/src/main/java/org/apache/sysml/parser/DMLProgram.java index b800d3c..f3ebeff 100644 --- a/src/main/java/org/apache/sysml/parser/DMLProgram.java +++ b/src/main/java/org/apache/sysml/parser/DMLProgram.java @@ -22,9 +22,6 @@ package org.apache.sysml.parser; import java.util.ArrayList; import java.util.HashMap; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - import org.apache.sysml.runtime.controlprogram.Program; @@ -35,7 +32,6 @@ public class DMLProgram private HashMap<String,DMLProgram> _namespaces; public static final String DEFAULT_NAMESPACE = ".defaultNS"; public static final String INTERNAL_NAMESPACE = "_internal"; // used for multi-return builtin functions - private static final Log LOG = LogFactory.getLog(DMLProgram.class.getName()); public DMLProgram(){ _blocks = new ArrayList<>();
