[SYSTEMML-1699] Fix IPA nnz propagation w/ multiple function calls Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a625c642 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a625c642 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a625c642
Branch: refs/heads/master Commit: a625c6423655210fa91b41162734c9214d3aaa0d Parents: 72645d3 Author: Matthias Boehm <[email protected]> Authored: Sat Jun 17 14:41:35 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 17 14:41:35 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/ipa/FunctionCallSizeInfo.java | 19 ++--- .../sysml/hops/ipa/InterProceduralAnalysis.java | 2 +- .../functions/misc/IPANnzPropagationTest.java | 83 ++++++++++++++++++++ .../functions/misc/IPANnzPropagation1.dml | 30 +++++++ .../functions/misc/IPANnzPropagation2.dml | 31 ++++++++ .../functions/misc/ZPackageSuite.java | 1 + 6 files changed, 156 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java index 402e780..fb668b5 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java @@ -55,8 +55,8 @@ public class FunctionCallSizeInfo //indicators for which function arguments of valid functions it //is safe to propagate the number of non-zeros - //(mapping from function keys to set of function input HopIDs) - private final Map<String, Set<Long>> _fcandSafeNNZ; + //(mapping from function keys to set of function input positions) + private final Map<String, Set<Integer>> _fcandSafeNNZ; //indicators which literal function arguments can be safely //propagated into and replaced in the respective functions @@ -90,7 +90,7 @@ public class FunctionCallSizeInfo _fgraph = fgraph; _fcand = new HashSet<String>(); _fcandUnary = new HashSet<String>(); - _fcandSafeNNZ =
new HashMap<String, Set<Long>>(); + _fcandSafeNNZ = new HashMap<String, Set<Integer>>(); _fSafeLiterals = new HashMap<String, Set<Integer>>(); constructFunctionCallSizeInfo(); @@ -176,12 +176,12 @@ public class FunctionCallSizeInfo * number of non-zeros. * * @param fkey function key - * @param inputHopID hop ID of the input + * @param pos function input position * @return true if nnz can safely be propagated */ - public boolean isSafeNnz(String fkey, long inputHopID) { + public boolean isSafeNnz(String fkey, int pos) { return _fcandSafeNNZ.containsKey(fkey) - && _fcandSafeNNZ.get(fkey).contains(inputHopID); + && _fcandSafeNNZ.get(fkey).contains(pos); } /** @@ -254,12 +254,13 @@ public class FunctionCallSizeInfo //(considered for valid functions only) for( String fkey : _fcand ) { FunctionOp first = _fgraph.getFunctionCalls(fkey).get(0); - HashSet<Long> tmp = new HashSet<Long>(); - for( Hop input : first.getInput() ) { + HashSet<Integer> tmp = new HashSet<Integer>(); + for( int j=0; j<first.getInput().size(); j++ ) { //if nnz known it is safe to propagate those nnz because for multiple calls //we checked of equivalence and hence all calls have the same nnz + Hop input = first.getInput().get(j); if( input.getNnz()>=0 ) - tmp.add(input.getHopID()); + tmp.add(j); } _fcandSafeNNZ.put(fkey, tmp); } http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java index 7d371ac..b813685 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java @@ -541,7 +541,7 @@ public class InterProceduralAnalysis MatrixObject mo = new MatrixObject(ValueType.DOUBLE, null); MatrixCharacteristics mc = new
MatrixCharacteristics( input.getDim1(), input.getDim2(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), - fcallSizes.isSafeNnz(fkey, input.getHopID())?input.getNnz():-1 ); + fcallSizes.isSafeNnz(fkey, i)?input.getNnz():-1 ); MatrixFormatMetaData meta = new MatrixFormatMetaData(mc,null,null); mo.setMetaData(meta); vars.put(dat.getName(), mo); http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/test/java/org/apache/sysml/test/integration/functions/misc/IPANnzPropagationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPANnzPropagationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPANnzPropagationTest.java new file mode 100644 index 0000000..2809ecf --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPANnzPropagationTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.test.integration.functions.misc; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Test; + + +public class IPANnzPropagationTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "IPANnzPropagation1"; + private final static String TEST_NAME2 = "IPANnzPropagation2"; + private final static String TEST_DIR = "functions/misc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + IPANnzPropagationTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{})); + } + + @Test + public void testNnzPropgationPositive() { + runIPANnzPropgationTest(TEST_NAME1); + } + + @Test + public void testNnzPropgationNegative() { + runIPANnzPropgationTest(TEST_NAME2); + } + + + private void runIPANnzPropgationTest(String testname) + { + // Save old settings + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + RUNTIME_PLATFORM platformOld = rtplatform; + + try + { + // Setup test + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-stats", "-explain", "recompile_hops"}; + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; + + runTest(true, false, null, -1); + + //check for propagated nnz + checkNumCompiledSparkInst(testname.equals(TEST_NAME1) ?
0 : 1); + checkNumExecutedSparkInst(0); + } + finally { + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + rtplatform = platformOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/test/scripts/functions/misc/IPANnzPropagation1.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/IPANnzPropagation1.dml b/src/test/scripts/functions/misc/IPANnzPropagation1.dml new file mode 100644 index 0000000..919dc32 --- /dev/null +++ b/src/test/scripts/functions/misc/IPANnzPropagation1.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# +#------------------------------------------------------------- + +foo = function(matrix[double] X) return (double sum) { + if( 1==1 ) {} + sum = sum(X); +} + +X = rand(rows=1000, cols=1000000000, sparsity=1e-6) +s1 = foo(X); +s2 = foo(X); +print(s1+" "+s2); http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/test/scripts/functions/misc/IPANnzPropagation2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/IPANnzPropagation2.dml b/src/test/scripts/functions/misc/IPANnzPropagation2.dml new file mode 100644 index 0000000..86990ca --- /dev/null +++ b/src/test/scripts/functions/misc/IPANnzPropagation2.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# +#------------------------------------------------------------- + +foo = function(matrix[double] X) return (double sum) { + if( 1==1 ) {} + sum = sum(X); +} + +X = rand(rows=1000, cols=1000000000, sparsity=1e-6) +s1 = foo(X); +X = rand(rows=1000, cols=1000000000, sparsity=1e-7) +s2 = foo(X); +print(s1+" "+s2); http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index 9917bb3..e352e6d 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -36,6 +36,7 @@ import org.junit.runners.Suite; InvalidFunctionSignatureTest.class, IPAConstantFoldingScalarVariablePropagationTest.class, IPALiteralReplacementTest.class, + IPANnzPropagationTest.class, IPAScalarRecursionTest.class, IPAScalarVariablePropagationTest.class, IPAUnknownRecursionTest.class,
