Repository: incubator-systemml Updated Branches: refs/heads/master 5db91308f -> 149562eca
[SYSTEMML-1515] Generalized codegen cell template (sideways row vectors) This patch generalizes the existing codegen cell template by allowing sideways row vectors, i.e., matrix-row vector binary operations, in addition to sideways column vectors and matrices. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b70ee453 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b70ee453 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b70ee453 Branch: refs/heads/master Commit: b70ee45330a457e3f0ee61c499306fb2518997c7 Parents: 5db9130 Author: Matthias Boehm <[email protected]> Authored: Thu Apr 13 14:18:47 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Apr 14 12:46:36 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 6 +-- .../sysml/hops/codegen/cplan/CNodeUnary.java | 6 ++- .../hops/codegen/template/TemplateCell.java | 42 +++++--------------- .../hops/codegen/template/TemplateRow.java | 24 +++-------- .../hops/codegen/template/TemplateUtils.java | 20 ++++++++-- .../functions/codegen/CellwiseTmplTest.java | 24 ++++++++++- .../scripts/functions/codegen/cellwisetmpl14.R | 31 +++++++++++++++ .../functions/codegen/cellwisetmpl14.dml | 27 +++++++++++++ 8 files changed, 118 insertions(+), 62 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b70ee453/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index fdb8d9d..1f0644b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -41,8 +41,6 @@ import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct; import org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; -import org.apache.sysml.hops.codegen.cplan.CNodeUnary; -import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.codegen.template.TemplateBase; import org.apache.sysml.hops.codegen.template.TemplateBase.CloseType; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; @@ -641,9 +639,7 @@ public class SpoofCompiler private static void rFindAndRemoveLookup(CNode node, CNodeData mainInput) { for( int i=0; i<node.getInput().size(); i++ ) { CNode tmp = node.getInput().get(i); - if( tmp instanceof CNodeUnary && (((CNodeUnary)tmp).getType()==UnaryType.LOOKUP_R - || ((CNodeUnary)tmp).getType()==UnaryType.LOOKUP_RC) - && tmp.getInput().get(0) instanceof CNodeData + if( TemplateUtils.isLookup(tmp) && tmp.getInput().get(0) instanceof CNodeData && ((CNodeData)tmp.getInput().get(0)).getHopID()==mainInput.getHopID() ) { node.getInput().set(i, tmp.getInput().get(0)); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b70ee453/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index 30752a2..7808421 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -28,7 +28,7 @@ import org.apache.sysml.parser.Expression.DataType; public class CNodeUnary extends CNode { public enum UnaryType { - LOOKUP_R, LOOKUP_RC, LOOKUP0, //codegen specific + LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific ROW_SUMS, ROW_MINS, ROW_MAXS, //codegen specific VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG, VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN, @@ -73,6 +73,8 @@ public class CNodeUnary extends CNode return " double %TMP% = FastMath.exp(%IN1%);\n"; case LOOKUP_R: return " double %TMP% = getValue(%IN1%, rowIndex);\n"; + case LOOKUP_C: + return " double %TMP% = getValue(%IN1%, colIndex);\n"; case LOOKUP_RC: return " double %TMP% = getValue(%IN1%, rowIndex*n+colIndex);\n"; case LOOKUP0: @@ -207,6 +209,7 @@ public class CNodeUnary extends CNode case VECT_FLOOR: case VECT_SIGN: return "u(v"+_type.name().toLowerCase()+")"; case LOOKUP_R: return "u(ixr)"; + case LOOKUP_C: return "u(ixc)"; case LOOKUP_RC: return "u(ixrc)"; case LOOKUP0: return "u(ix0)"; case POW2: return "^2"; @@ -237,6 +240,7 @@ public class CNodeUnary extends CNode case ROW_MAXS: case EXP: case LOOKUP_R: + case LOOKUP_C: case LOOKUP_RC: case LOOKUP0: case POW2: http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b70ee453/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 95f6643..d5ac99c 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -165,10 +165,7 @@ public class TemplateCell extends TemplateBase if(hop instanceof UnaryOp) { CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); - if( TemplateUtils.isColVector(cdata1) ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); + cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); String primitiveOpName = ((UnaryOp)hop).getOp().name(); out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName)); @@ -180,17 +177,9 @@ public class TemplateCell extends TemplateBase CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); String primitiveOpName = bop.getOp().name(); - //cdata1 is vector - if( TemplateUtils.isColVector(cdata1) ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); - - //cdata2 is vector - if( TemplateUtils.isColVector(cdata2) ) - cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); - else if( cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix() ) - cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_RC); + //add lookups if required + cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); + cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1)); if( bop.getOp()==OpOp2.POW && cdata2.isLiteral() && cdata2.getVarname().equals("2") ) out = new CNodeUnary(cdata1, UnaryType.POW2); @@ -206,17 +195,9 @@ public class TemplateCell extends TemplateBase CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID()); - //cdata1 is vector - if( TemplateUtils.isColVector(cdata1) ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); - - //cdata3 is vector - if( TemplateUtils.isColVector(cdata3) ) - cdata3 = new CNodeUnary(cdata3, UnaryType.LOOKUP_R); - else if( cdata3 instanceof CNodeData && hop.getInput().get(2).getDataType().isMatrix() ) - cdata3 = new CNodeUnary(cdata3, UnaryType.LOOKUP_RC); + //add lookups if required + cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); + cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2)); //construct ternary cnode, primitive operation derived from OpOp3 out = new CNodeTernary(cdata1, cdata2, cdata3, @@ -225,10 +206,7 @@ public class TemplateCell extends TemplateBase else if( hop instanceof ParameterizedBuiltinOp ) { CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID()); - if( TemplateUtils.isColVector(cdata1) ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); + cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID()); CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID()); @@ -290,8 +268,8 @@ public class TemplateCell extends TemplateBase isBinaryMatrixScalar = (ldt.isScalar() || rdt.isScalar()); isBinaryMatrixVector = hop.dimsKnown() - && ((ldt.isMatrix() && TemplateUtils.isVectorOrScalar(right) && !TemplateUtils.isBinaryMatrixRowVector(hop)) - || (rdt.isMatrix() && TemplateUtils.isVectorOrScalar(left) && !TemplateUtils.isBinaryMatrixRowVector(hop)) ); + && ((ldt.isMatrix() && TemplateUtils.isVectorOrScalar(right)) + || (rdt.isMatrix() && TemplateUtils.isVectorOrScalar(left)) ); isBinaryMatrixMatrixDense = hop.dimsKnown() && HopRewriteUtils.isEqualSize(left, right) && ldt.isMatrix() && rdt.isMatrix() && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b70ee453/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index ca9776d..3af8be4 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -227,10 +227,7 @@ public class TemplateRow extends TemplateBase } else //general scalar case { - if( TemplateUtils.isColVector(cdata1) ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); + cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); String primitiveOpName = ((UnaryOp)hop).getOp().toString(); out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName)); @@ -271,17 +268,9 @@ public class TemplateRow extends TemplateBase CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID()); - //cdata1 is vector - if( TemplateUtils.isColVector(cdata1) ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); - - //cdata3 is vector - if( TemplateUtils.isColVector(cdata3) ) - cdata3 = new CNodeUnary(cdata3, UnaryType.LOOKUP_R); - else if( cdata3 instanceof CNodeData && hop.getInput().get(2).getDataType().isMatrix() ) - cdata3 = new CNodeUnary(cdata3, UnaryType.LOOKUP_RC); + //add lookups if required + cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); + cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2)); //construct ternary cnode, primitive operation derived from OpOp3 out = new CNodeTernary(cdata1, cdata2, cdata3, @@ -290,10 +279,7 @@ public class TemplateRow extends TemplateBase else if( hop instanceof ParameterizedBuiltinOp ) { CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID()); - if( TemplateUtils.isColVector(cdata1) ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) - cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); + cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID()); CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID()); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b70ee453/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 8811cb8..0a19b56 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -74,6 +74,17 @@ public class TemplateUtils && hop.getNumRows() == 1 && hop.getNumCols() != 1); } + public static CNode wrapLookupIfNecessary(CNode node, Hop hop) { + CNode ret = node; + if( isColVector(node) ) + ret = new CNodeUnary(node, UnaryType.LOOKUP_R); + else if( isRowVector(node) ) + ret = new CNodeUnary(node, UnaryType.LOOKUP_C); + else if( node instanceof CNodeData && hop.getDataType().isMatrix() ) + ret = new CNodeUnary(node, UnaryType.LOOKUP_RC); + return ret; + } + public static boolean isMatrix(Hop hop) { return (hop.getDataType() == DataType.MATRIX && hop.getDim1() != 1 && hop.getDim2()!=1); } @@ -256,9 +267,12 @@ public class TemplateUtils } public static boolean isLookup(CNode node) { - return (node instanceof CNodeUnary - && (((CNodeUnary)node).getType()==UnaryType.LOOKUP_R - || ((CNodeUnary)node).getType()==UnaryType.LOOKUP_RC)); + return isUnary(node, UnaryType.LOOKUP_R, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC); + } + + public static boolean isUnary(CNode node, UnaryType...types) { + return node instanceof CNodeUnary + && ArrayUtils.contains(types, ((CNodeUnary)node).getType()); } public static CNodeData createCNodeData(Hop hop, boolean compileLiterals) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b70ee453/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index fc41837..fbd456f 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -49,7 +49,7 @@ public class CellwiseTmplTest extends AutomatedTestBase private static final String TEST_NAME11 = TEST_NAME+11; //replace((0 / (X - 500))+1, 0/0, 7) private static final String TEST_NAME12 = TEST_NAME+12; //((X/3) %% 0.6) + ((X/3) %/% 0.6) private static final String TEST_NAME13 = TEST_NAME+13; //min(X + 7 * Y) large - + private static final String TEST_NAME14 = TEST_NAME+14; //-2 * X + t(Y); t(Y) is rowvector private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/"; @@ -62,7 +62,7 @@ public class CellwiseTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for( int i=1; i<=13; i++ ) { + for( int i=1; i<=14; i++ ) { addTestConfiguration( TEST_NAME+i, new TestConfiguration( TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) ); } @@ -133,6 +133,11 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwiseRewrite13() { testCodegenIntegration( TEST_NAME13, true, ExecType.CP ); } + + @Test + public void testCodegenCellwiseRewrite14() { + testCodegenIntegration( TEST_NAME14, true, ExecType.CP ); + } @Test public void testCodegenCellwise1() { @@ -199,6 +204,11 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwise13() { testCodegenIntegration( TEST_NAME13, false, ExecType.CP ); } + + @Test + public void testCodegenCellwise14() { + testCodegenIntegration( TEST_NAME14, false, ExecType.CP ); + } @Test public void testCodegenCellwiseRewrite1_sp() { @@ -235,6 +245,16 @@ public class CellwiseTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME12, true, ExecType.SPARK ); } + @Test + public void testCodegenCellwiseRewrite13_sp() { + testCodegenIntegration( TEST_NAME13, true, ExecType.SPARK ); + } + + @Test + public void testCodegenCellwiseRewrite14_sp() { + testCodegenIntegration( TEST_NAME14, true, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldRewrites = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b70ee453/src/test/scripts/functions/codegen/cellwisetmpl14.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl14.R b/src/test/scripts/functions/codegen/cellwisetmpl14.R new file mode 100644 index 0000000..d649da8 --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl14.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(seq(7, 2200006), 1100, 2000, byrow=TRUE); +Y = seq(1, 2000); + +R = -2 * X + (matrix(1,nrow(X),1) %*% t(Y)); + +writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b70ee453/src/test/scripts/functions/codegen/cellwisetmpl14.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl14.dml b/src/test/scripts/functions/codegen/cellwisetmpl14.dml new file mode 100644 index 0000000..1c0597b --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl14.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(7, 2200006), 1100, 2000); +Y = seq(1, 2000); + +R = -2 * X + t(Y); + +write(R, $1)
