This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a04783eed2 [SYSTEMDS-1965] Extended constant folding (support for ternary/nary ops)
a04783eed2 is described below

commit a04783eed26414bf56425f55d083f3e38afd2472
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Jul 19 12:10:30 2023 +0200

    [SYSTEMDS-1965] Extended constant folding (support for ternary/nary ops)
    
    This patch extends the existing constant folding with support for ternary
    (e.g., ifelse, +*, -*) and n-ary (e.g., nmin, nmax, n+) operations.
    Furthermore, it adds new tests and re-activates previously ignored tests
    of constant folding in functions during IPA.
---
 .../sysds/hops/rewrite/RewriteConstantFolding.java | 15 ++++++-
 ...nstantFoldingScalarVariablePropagationTest.java | 34 +++++++++++-----
 ...PAConstantFoldingScalarVariablePropagation2.dml |  2 +-
 ...PAConstantFoldingScalarVariablePropagation3.dml | 46 ++++++++++++++++++++++
 4 files changed, 85 insertions(+), 12 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
index de5b4feacc..6980e5b661 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
@@ -34,7 +34,9 @@ import org.apache.sysds.lops.compile.Dag;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.OpOp1;
 import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
 import org.apache.sysds.common.Types.OpOpData;
+import org.apache.sysds.common.Types.OpOpN;
 import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
 import org.apache.sysds.runtime.controlprogram.Program;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -96,7 +98,8 @@ public class RewriteConstantFolding extends HopRewriteRule
                
                //fold binary op if both are literals / unary op if literal
                if( root.getDataType() == DataType.SCALAR //scalar output
-                       && ( isApplicableBinaryOp(root) || isApplicableUnaryOp(root) ) )
+                       && ( isApplicableUnaryOp(root) || isApplicableBinaryOp(root)
+                               || isApplicableTernaryOp(root) || isApplicableNaryOp(root) ) )
                { 
                        literal = evalScalarOperation(root); 
                }
@@ -212,6 +215,16 @@ public class RewriteConstantFolding extends HopRewriteRule
                                && hop.getDataType() == DataType.SCALAR);
        }
        
+       private static boolean isApplicableTernaryOp( Hop hop ) {
+               return HopRewriteUtils.isTernary(hop, OpOp3.IFELSE, OpOp3.MINUS_MULT, OpOp3.PLUS_MULT)
+                               && hop.getInput().stream().allMatch(h -> h instanceof LiteralOp);
+       }
+       
+       private static boolean isApplicableNaryOp( Hop hop ) {
+               return HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS)
+                       && hop.getInput().stream().allMatch(h -> h instanceof LiteralOp);
+       }
+       
        private static boolean isApplicableFalseConjunctivePredicate( Hop hop ) {
                ArrayList<Hop> in = hop.getInput();
                return (   HopRewriteUtils.isBinary(hop, OpOp2.AND) && hop.getDataType().isScalar()
diff --git a/src/test/java/org/apache/sysds/test/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java b/src/test/java/org/apache/sysds/test/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
index 74b60d5ff8..843db6c874 100644
--- a/src/test/java/org/apache/sysds/test/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
@@ -25,7 +25,8 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
 import org.junit.Test;
 
 /**
@@ -47,16 +48,17 @@ public class IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
 {
        private final static String TEST_NAME1 = "IPAConstantFoldingScalarVariablePropagation1";
        private final static String TEST_NAME2 = "IPAConstantFoldingScalarVariablePropagation2";
+       private final static String TEST_NAME3 = "IPAConstantFoldingScalarVariablePropagation3";
+       
        private final static String TEST_DIR = "functions/misc/";
        private final static String TEST_CLASS_DIR = TEST_DIR + IPAConstantFoldingScalarVariablePropagationTest.class.getSimpleName() + "/";
 
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               TestConfiguration conf1 = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{});
-               TestConfiguration conf2 = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{});
-               addTestConfiguration(TEST_NAME1, conf1);
-               addTestConfiguration(TEST_NAME2, conf2);
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{}));
+               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{}));
+               addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{}));
        }
 
        @Test
@@ -69,20 +71,26 @@ public class IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
                runIPAScalarVariablePropagationTest(TEST_NAME1, false);
        }
 
-       // TODO: this test is ignored because  sourcing functions from another script does not allow named variables, with default values.
        @Test
-       @Ignore
        public void testConstantFoldingScalarPropagation2IPASecondChance() {
                runIPAScalarVariablePropagationTest(TEST_NAME2, true);
        }
 
-       // TODO: this test is ignored because  sourcing functions from another script does not allow named variables, with default values.
        @Test
-       @Ignore
        public void testConstantFoldingScalarPropagation2NoIPASecondChance() {
                runIPAScalarVariablePropagationTest(TEST_NAME2, false);
        }
 
+       @Test
+       public void testConstantFoldingScalarPropagation3IPASecondChance() {
+               runIPAScalarVariablePropagationTest(TEST_NAME3, true);
+       }
+       
+       @Test
+       public void testConstantFoldingScalarPropagation3NoIPASecondChance() {
+               runIPAScalarVariablePropagationTest(TEST_NAME3, false);
+       }
+       
        /**
         * Test for static rewrites + IPA second chance compilation to allow
         * for scalar propagation (IPA) of constant-folded DAG of literals
@@ -106,7 +114,7 @@ public class IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
                        loadTestConfiguration(config);
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[]{"-stats"};
+                       programArgs = new String[]{"-explain","-stats"};
                        OptimizerUtils.IPA_NUM_REPETITIONS = IPA_SECOND_CHANCE ? 2 : 1;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                        rtplatform = ExecMode.HYBRID;
@@ -118,6 +126,12 @@ public class IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
                        // (MB: originally, this required a second chance, but not anymore)
                        checkNumCompiledSparkInst(0);
                        checkNumExecutedSparkInst(0);
+                       
+                       //check successful constant folding of entire expressions
+                       if( testname.equals(TEST_NAME3) && IPA_SECOND_CHANCE ) {
+                               Assert.assertTrue(Statistics.getCPHeavyHitterCount("floor")==2);
+                               Assert.assertTrue(Statistics.getCPHeavyHitterCount("castvti")==2);
+                       }
                }
                finally {
                        // Reset
diff --git a/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation2.dml b/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation2.dml
index ec1a7fcbf2..6c3632f8eb 100644
--- a/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation2.dml
+++ b/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation2.dml
@@ -42,7 +42,7 @@ Wf = 3  # filter width
 stride = 1
 pad = 1  # For same dimensions, (Hf - stride) / 2
 F1 = 32  # num conv filters in conv1
-[Wc1, bc1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+[Wc1, bc1] = conv2d::init(F1, C, Hf, Wf, -1)  # inputs: (N, C*Hin*Win)
 
 # Create data structure to store gradients computed in parallel
 doutc1_agg = matrix(0, rows=num_batches, cols=batch_size*F1*Hin*Win)
diff --git a/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation3.dml b/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation3.dml
new file mode 100644
index 0000000000..c6bb8445ed
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation3.dml
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(Int Hin, Int Win, Int Hf, Int Wf,
+               Int strideh, Int stridew, Int padh, Int padw)
+  return(Integer Hout, Integer Wout)
+{
+  Hout = as.integer(floor((Hin + 2*padh - Hf)/strideh + 1))
+  while(FALSE){} #prevent inlining
+  Wout = as.integer(floor((Win + 2*padw - Wf)/stridew + 1 + sqrt(6/(padh+padw))))
+
+}
+
+Hin = 224  # input height
+Win = 224  # input width
+Hf = 3  # filter height
+Wf = 3  # filter width
+stride = 1
+pad = 1  # For same dimensions, (Hf - stride) / 2
+
+[Hout1, Wout1] = foo(Hin, Win, Hf, Wf, stride, stride, pad, pad);
+
+while(FALSE){} #DAG cut
+
+[Hout2, Wout2] = foo(Hin, Win, Hf, Wf, stride, stride, pad, pad);
+
+print(Hout1+" "+Wout1+" vs "+Hout2+" "+Wout2)
+#check no ops of foo -> constant folding

Reply via email to