[SYSTEMML-2082] Codegen support for ternary ifelse in cell/magg tmpls

This patch adds basic support for ternary ifelse operations in codegen cell and magg templates, along with related tests.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/aa537dad Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/aa537dad Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/aa537dad Branch: refs/heads/master Commit: aa537dad43f2cf21badaedcb8629b27ad301032b Parents: 5457066 Author: Matthias Boehm <[email protected]> Authored: Tue Feb 6 20:06:28 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Tue Feb 6 20:06:28 2018 -0800 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeTernary.java | 23 ++++++++------ .../hops/codegen/template/TemplateCell.java | 12 +++++--- .../functions/codegen/CellwiseTmplTest.java | 18 ++++++++++- .../scripts/functions/codegen/cellwisetmpl18.R | 32 ++++++++++++++++++++ .../functions/codegen/cellwisetmpl18.dml | 30 ++++++++++++++++++ 5 files changed, 99 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java index 155cc8b..dc8ff82 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java @@ -27,7 +27,7 @@ public class CNodeTernary extends CNode { public enum TernaryType { PLUS_MULT, MINUS_MULT, - REPLACE, REPLACE_NAN, + REPLACE, REPLACE_NAN, IFELSE, LOOKUP_RC1, LOOKUP_RVECT1; @@ -52,7 +52,10 @@ public class CNodeTernary extends CNode case REPLACE_NAN: return " double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n"; - + + case IFELSE: + return " double %TMP% = (%IN1% != 0) ? 
%IN2% : %IN3%;\n"; + case LOOKUP_RC1: return sparse ? " double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" : @@ -124,15 +127,14 @@ public class CNodeTernary extends CNode @Override public String toString() { switch(_type) { - case PLUS_MULT: return "t(+*)"; - case MINUS_MULT: return "t(-*)"; - case REPLACE: - case REPLACE_NAN: return "t(rplc)"; - case LOOKUP_RC1: return "u(ixrc1)"; + case PLUS_MULT: return "t(+*)"; + case MINUS_MULT: return "t(-*)"; + case REPLACE: + case REPLACE_NAN: return "t(rplc)"; + case IFELSE: return "t(ifelse)"; + case LOOKUP_RC1: return "u(ixrc1)"; case LOOKUP_RVECT1: return "u(ixrv1)"; - - default: - return super.toString(); + default: return super.toString(); } } @@ -143,6 +145,7 @@ public class CNodeTernary extends CNode case MINUS_MULT: case REPLACE: case REPLACE_NAN: + case IFELSE: case LOOKUP_RC1: _rows = 0; _cols = 0; http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 50b42ea..2b8db2a 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -34,6 +34,7 @@ import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; @@ -168,7 +169,7 @@ public class TemplateCell extends TemplateBase && HopRewriteUtils.isMatrixMultiply(hop) && i==0 ) //skip transpose rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals); else { - CNodeData cdata = 
TemplateUtils.createCNodeData(c, compileLiterals); + CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals); tmp.put(c.getHopID(), cdata); inHops.add(c); } @@ -208,6 +209,7 @@ public class TemplateCell extends TemplateBase //add lookups if required cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); + cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1)); cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2)); //construct ternary cnode, primitive operation derived from OpOp3 @@ -299,11 +301,11 @@ public class TemplateCell extends TemplateBase //prepare indicators for ternary operations boolean isTernaryVectorScalarVector = false; boolean isTernaryMatrixScalarMatrixDense = false; + boolean isTernaryIfElse = (HopRewriteUtils.isTernary(hop, OpOp3.IFELSE) && hop.getDataType().isMatrix()); if( hop instanceof TernaryOp && hop.getInput().size()==3 && hop.dimsKnown() - && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) { + && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX) ) { Hop left = hop.getInput().get(0); Hop right = hop.getInput().get(2); - isTernaryVectorScalarVector = TemplateUtils.isVector(left) && TemplateUtils.isVector(right); isTernaryMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(left, right) && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right); @@ -312,8 +314,8 @@ public class TemplateCell extends TemplateBase //check supported unary, binary, ternary operations return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrix - || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense - || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE)); + || isTernaryVectorScalarVector || 
isTernaryMatrixScalarMatrixDense || isTernaryIfElse + || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE)); } protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, List<CNode> outputs, List<AggOp> aggOps, boolean onlySum) { http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index bd3b36a..2f44f61 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -53,6 +53,7 @@ public class CellwiseTmplTest extends AutomatedTestBase private static final String TEST_NAME15 = TEST_NAME+15; //colMins(2*log(X)) private static final String TEST_NAME16 = TEST_NAME+16; //colSums(2*log(X)); private static final String TEST_NAME17 = TEST_NAME+17; //xor operation + private static final String TEST_NAME18 = TEST_NAME+18; //sum(ifelse(X,Y,Z)) private static final String TEST_DIR = "functions/codegen/"; @@ -66,7 +67,7 @@ public class CellwiseTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for( int i=1; i<=17; i++ ) { + for( int i=1; i<=18; i++ ) { addTestConfiguration( TEST_NAME+i, new TestConfiguration( TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) ); } @@ -304,6 +305,21 @@ public class CellwiseTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME17, true, ExecType.SPARK ); } + @Test + public void testCodegenCellwiseRewrite18() { + testCodegenIntegration( TEST_NAME18, true, ExecType.CP ); + } + + @Test + public void testCodegenCellwise18() { 
+ testCodegenIntegration( TEST_NAME18, false, ExecType.CP ); + } + + @Test + public void testCodegenCellwiseRewrite18_sp() { + testCodegenIntegration( TEST_NAME18, true, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/scripts/functions/codegen/cellwisetmpl18.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl18.R b/src/test/scripts/functions/codegen/cellwisetmpl18.R new file mode 100644 index 0000000..e6a275a --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl18.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(seq(-1000, 198999), 1000, 200, byrow=TRUE); +Y = matrix(seq(0, 199999), 1000, 200, byrow=TRUE); +Z = matrix(seq(1000, 200999), 1000, 200, byrow=TRUE); + +R = as.matrix(sum(as.numeric(ifelse(X,Y,Z)))); + +writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/scripts/functions/codegen/cellwisetmpl18.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl18.dml b/src/test/scripts/functions/codegen/cellwisetmpl18.dml new file mode 100644 index 0000000..c178dd3 --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl18.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(-1000, 198999), 1000, 200); +Y = matrix(seq(0, 199999), 1000, 200); +Z = matrix(seq(1000, 200999), 1000, 200); + +while(FALSE){} + +R = as.matrix(sum(ifelse(X,Y,Z))); + +write(R, $1)
