Repository: incubator-systemml
Updated Branches:
  refs/heads/master 686363208 -> 19e21744c
[SYSTEMML-1590] Fix codegen handling of unsupported row aggregates Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/19e21744 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/19e21744 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/19e21744 Branch: refs/heads/master Commit: 19e21744c86adbedf6098906808c2c6327659cfe Parents: 6863632 Author: Matthias Boehm <mboe...@gmail.com> Authored: Sun May 7 20:49:38 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sun May 7 20:49:46 2017 -0700 ---------------------------------------------------------------------- .../hops/codegen/template/TemplateRow.java | 6 ++-- .../functions/codegen/RowAggTmplTest.java | 18 ++++++++++- .../scripts/functions/codegen/rowAggPattern16.R | 33 ++++++++++++++++++++ .../functions/codegen/rowAggPattern16.dml | 27 ++++++++++++++++ 4 files changed, 81 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 3af8be4..3f947c8 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -78,7 +78,8 @@ public class TemplateRow extends TemplateBase || (hop instanceof AggBinaryOp && hop.getDim2()==1 && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol - && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1); + && 
hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1 + && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)); } @Override @@ -89,7 +90,8 @@ public class TemplateRow extends TemplateBase || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) ) || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) && TemplateCell.isValidOperation(hop)) - || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol) + || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol + && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) || (hop instanceof AggBinaryOp && hop.getDim1()>1 && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index 4037edb..b7f82a7 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -51,6 +51,7 @@ public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME13 = TEST_NAME+"13"; //rowSums(X)+rowSums(Y) private static final String TEST_NAME14 = TEST_NAME+"14"; //colSums(max(floor(round(abs(min(sign(X+Y),1)))),7)) private static final String TEST_NAME15 = TEST_NAME+"15"; //systemml nn - softmax backward (partially) + private static final String TEST_NAME16 = TEST_NAME+"16"; //Y=X-rowIndexMax(X); R=Y/rowSums(Y) private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -62,7 +63,7 @@ 
public class RowAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=15; i++) + for(int i=1; i<=16; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @@ -291,6 +292,21 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME15, false, ExecType.SPARK ); } + @Test + public void testCodegenRowAggRewrite16CP() { + testCodegenIntegration( TEST_NAME16, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg16CP() { + testCodegenIntegration( TEST_NAME16, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg16SP() { + testCodegenIntegration( TEST_NAME16, false, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/scripts/functions/codegen/rowAggPattern16.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern16.R b/src/test/scripts/functions/codegen/rowAggPattern16.R new file mode 100644 index 0000000..a4e9184 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern16.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + + +X = matrix(seq(1,1500), 150, 10, byrow=TRUE); + +Y1 = X - max.col(X, ties.method="last") +R = Y1 / rowSums(Y1) + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19e21744/src/test/scripts/functions/codegen/rowAggPattern16.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern16.dml b/src/test/scripts/functions/codegen/rowAggPattern16.dml new file mode 100644 index 0000000..e0558f6 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern16.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,1500), rows=150, cols=10); + +Y1 = X - rowIndexMax(X) +R = Y1 / rowSums(Y1) + +write(R, $1)