Repository: incubator-systemml Updated Branches: refs/heads/master 8f7cf77be -> 2e48d951b
[SYSTEMML-1447] Extended code generator (replace in rowagg/cell tmpls) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2e48d951 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2e48d951 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2e48d951 Branch: refs/heads/master Commit: 2e48d951b825fe4ef85dc13f6d69934b8cadfe46 Parents: 8f7cf77 Author: Matthias Boehm <[email protected]> Authored: Fri Mar 31 17:17:55 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Mar 31 18:21:17 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/ParameterizedBuiltinOp.java | 5 ++++ .../sysml/hops/codegen/cplan/CNodeTernary.java | 16 ++++++++-- .../hops/codegen/template/TemplateCell.java | 25 +++++++++++++--- .../hops/codegen/template/TemplateRowAgg.java | 18 +++++++++++- .../hops/codegen/template/TemplateUtils.java | 3 ++ .../functions/codegen/CellwiseTmplTest.java | 22 ++++++++++++-- .../scripts/functions/codegen/cellwisetmpl11.R | 31 ++++++++++++++++++++ .../functions/codegen/cellwisetmpl11.dml | 27 +++++++++++++++++ 8 files changed, 138 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java index fa51948..1d6828c 100644 --- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java @@ -154,6 +154,11 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop getInput().get(_paramIndexMap.get("target")) : null; } + public Hop getParameterHop(String name) { + return _paramIndexMap.containsKey(name) ? + getInput().get(_paramIndexMap.get(name)) : null; + } + @Override public void setMaxNumThreads( int k ) { _maxNumThreads = k; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java index eb26eff..a8bbcb2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java @@ -28,6 +28,7 @@ public class CNodeTernary extends CNode { public enum TernaryType { PLUS_MULT, MINUS_MULT, + REPLACE, REPLACE_NAN, LOOKUP_RC1; public static boolean contains(String value) { @@ -40,10 +41,17 @@ public class CNodeTernary extends CNode public String getTemplate(boolean sparse) { switch (this) { case PLUS_MULT: - return " double %TMP% = %IN1% + %IN2% * %IN3%;\n" ; + return " double %TMP% = %IN1% + %IN2% * %IN3%;\n"; case MINUS_MULT: - return " double %TMP% = %IN1% - %IN2% * %IN3%;\n" ; + return " double %TMP% = %IN1% - %IN2% * %IN3%;\n"; + + case REPLACE: + return " double %TMP% = (%IN1% == %IN2% || (Double.isNaN(%IN1%) " + + "&& Double.isNaN(%IN2%))) ? %IN3% : %IN1%;\n"; + + case REPLACE_NAN: + return " double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n"; case LOOKUP_RC1: return " double %TMP% = %IN1%[rowIndex*%IN2%+%IN3%-1];\n"; @@ -101,6 +109,8 @@ public class CNodeTernary extends CNode switch(_type) { case PLUS_MULT: return "t(+*)"; case MINUS_MULT: return "t(-*)"; + case REPLACE: + case REPLACE_NAN: return "t(rplc)"; case LOOKUP_RC1: return "u(ixrc1)"; default: return super.toString(); @@ -112,6 +122,8 @@ public class CNodeTernary extends CNode switch(_type) { case PLUS_MULT: case MINUS_MULT: + case REPLACE: + case REPLACE_NAN: case LOOKUP_RC1: _rows = 0; _cols = 0; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 87ec899..447f6d6 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -34,8 +34,10 @@ import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; @@ -157,7 +159,7 @@ public class TemplateCell extends TemplateBase else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); - String primitiveOpName = ((UnaryOp)hop).getOp().toString(); + String primitiveOpName = ((UnaryOp)hop).getOp().name(); out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName)); } else if(hop instanceof BinaryOp) @@ -165,7 +167,7 @@ public class TemplateCell extends TemplateBase BinaryOp bop = (BinaryOp) hop; CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); - String primitiveOpName = bop.getOp().toString(); + String primitiveOpName = bop.getOp().name(); //cdata1 is vector if( TemplateUtils.isColVector(cdata1) ) @@ -207,7 +209,21 @@ public class TemplateCell extends TemplateBase //construct ternary cnode, primitive operation derived from OpOp3 out = new CNodeTernary(cdata1, cdata2, cdata3, - TernaryType.valueOf(top.getOp().toString())); + TernaryType.valueOf(top.getOp().name())); + } + else if( hop instanceof ParameterizedBuiltinOp ) + { + CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID()); + if( TemplateUtils.isColVector(cdata1) ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); + else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); + + CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID()); + CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID()); + TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? + TernaryType.REPLACE_NAN : TernaryType.REPLACE; + out = new CNodeTernary(cdata1, cdata2, cdata3, ttype); } else if( hop instanceof IndexingOp ) { @@ -285,7 +301,8 @@ public class TemplateCell extends TemplateBase //check supported unary, binary, ternary operations return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrixDense - || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense); + || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense + || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE)); } /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java index f8f1508..2883893 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java @@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; @@ -78,7 +79,8 @@ public class TemplateRowAgg extends TemplateBase return !isClosed() && ( (hop instanceof BinaryOp && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) ) - || (hop instanceof UnaryOp && TemplateCell.isValidOperation(hop)) + || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) + && TemplateCell.isValidOperation(hop)) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol) || (hop instanceof AggBinaryOp && hop.getDim1()>1 && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); @@ -255,6 +257,20 @@ public class TemplateRowAgg extends TemplateBase out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString())); } + else if( hop instanceof ParameterizedBuiltinOp ) + { + CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID()); + if( TemplateUtils.isColVector(cdata1) ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); + else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC); + + CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID()); + CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID()); + TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? + TernaryType.REPLACE_NAN : TernaryType.REPLACE; + out = new CNodeTernary(cdata1, cdata2, cdata3, ttype); + } else if( hop instanceof IndexingOp ) { CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 3f5fed9..b959638 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -30,6 +30,7 @@ import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; @@ -105,6 +106,8 @@ public class TemplateUtils return BinType.contains(((BinaryOp)h).getOp().name()); else if(h instanceof TernaryOp) return TernaryType.contains(((TernaryOp)h).getOp().name()); + else if(h instanceof ParameterizedBuiltinOp) + return TernaryType.contains(((ParameterizedBuiltinOp)h).getOp().name()); return false; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index 066b761..10aa038 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -46,6 +46,8 @@ public class CellwiseTmplTest extends AutomatedTestBase private static final String TEST_NAME8 = TEST_NAME+8; private static final String TEST_NAME9 = TEST_NAME+9; //sum((X + 7 * Y)^2) private static final String TEST_NAME10 = TEST_NAME+10; //min/max(X + 7 * Y) + private static final String TEST_NAME11 = TEST_NAME+11; //replace((0 / (X - 500))+1, 0/0, 7); + private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/"; @@ -58,7 +60,7 @@ public class CellwiseTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for( int i=1; i<=10; i++ ) { + for( int i=1; i<=11; i++ ) { addTestConfiguration( TEST_NAME+i, new TestConfiguration( TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) ); } @@ -114,6 +116,11 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwiseRewrite10() { testCodegenIntegration( TEST_NAME10, true, ExecType.CP ); } + + @Test + public void testCodegenCellwiseRewrite11() { + testCodegenIntegration( TEST_NAME11, true, ExecType.CP ); + } @Test public void testCodegenCellwise1() { @@ -165,6 +172,11 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwise10() { testCodegenIntegration( TEST_NAME10, false, ExecType.CP ); } + + @Test + public void testCodegenCellwise11() { + testCodegenIntegration( TEST_NAME11, false, ExecType.CP ); + } @Test public void testCodegenCellwiseRewrite1_sp() { @@ -191,6 +203,11 @@ public class CellwiseTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME10, true, ExecType.SPARK ); } + @Test + public void testCodegenCellwiseRewrite11_sp() { + testCodegenIntegration( TEST_NAME11, true, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { @@ -247,7 +264,8 @@ public class CellwiseTmplTest extends AutomatedTestBase Assert.assertTrue(!heavyHittersContainsSubString("tsmm")); else if( testname.equals(TEST_NAME10) ) //ensure min/max is fused Assert.assertTrue(!heavyHittersContainsSubString("uamin","uamax")); - + else if( testname.equals(TEST_NAME11) ) //ensure replace is fused + Assert.assertTrue(!heavyHittersContainsSubString("replace")); } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrites; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/scripts/functions/codegen/cellwisetmpl11.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl11.R b/src/test/scripts/functions/codegen/cellwisetmpl11.R new file mode 100644 index 0000000..33531ba --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl11.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(seq(7, 1006), 500, 2, byrow=TRUE); + +Y = (0 / (X - 500))+1; +R = replace(Y, is.nan(Y), 7); + +writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/scripts/functions/codegen/cellwisetmpl11.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl11.dml b/src/test/scripts/functions/codegen/cellwisetmpl11.dml new file mode 100644 index 0000000..c77da08 --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl11.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(7, 1006), 500, 2); + +Y = (0 / (X - 500))+1; +R = replace(target=Y, pattern=0/0, replacement=7); + +write(R, $1)
