This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 627825c  [SYSTEMDS-2920] Fix spark rexpand instruction (variable max parameter)
627825c is described below

commit 627825c25d5a5938a772a78ce037c57e68611998
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Apr 1 22:33:20 2021 +0200

    [SYSTEMDS-2920] Fix spark rexpand instruction (variable max parameter)
    
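    The parameterized builtin instructions receive their input arguments as
    untyped parameters and internally assume that some of them are matrix
    inputs and others are scalar literals. In CP this is fine, because
    instruction patching replaces the parameter markers, but in Spark the
    unreplaced markers cause parse errors (e.g., Double.parseDouble applied
    to a scalar variable name). This patch introduces more robust handling
    for the concrete case of the rexpand 'max' parameter, which is the only
    potentially variable parameter: if it names a scalar variable, its value
    is now obtained from the execution context instead of being parsed as a
    literal.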
---
 .../spark/ParameterizedBuiltinSPInstruction.java   | 14 ++--
 .../rewrite/RewriteCTableToRExpandTest.java        | 95 ++++++++++++++--------
 .../rewrite/RewriteCTableToRExpandRightVarMax.dml  | 28 +++++++
 3 files changed, 96 insertions(+), 41 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 8e71764..9975925 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -402,8 +402,12 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        // get input rdd handle
                        JavaPairRDD<MatrixIndexes, MatrixBlock> in = 
sec.getBinaryMatrixBlockRDDHandleForVariable(rddInVar);
                        DataCharacteristics mcIn = 
sec.getDataCharacteristics(rddInVar);
-                       double maxVal = Double.parseDouble(params.get("max"));
-                       long lmaxVal = UtilFunctions.toLong(maxVal);
+                       
+                       // parse untyped parameters, w/ robust handling for 
'max'
+                       String maxValName = params.get("max");
+                       long lmaxVal = 
maxValName.startsWith(Lop.SCALAR_VAR_NAME_PREFIX) ?
+                               ec.getScalarInput(maxValName, ValueType.FP64, 
false).getLongValue() :
+                               
UtilFunctions.toLong(Double.parseDouble(maxValName));
                        boolean dirRows = params.get("dir").equals("rows");
                        boolean cast = Boolean.parseBoolean(params.get("cast"));
                        boolean ignore = 
Boolean.parseBoolean(params.get("ignore"));
@@ -420,7 +424,7 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        // execute rexpand rows/cols operation (no shuffle 
required because outputs are
                        // block-aligned with the input, i.e., one input block 
generates n output blocks)
                        JavaPairRDD<MatrixIndexes, MatrixBlock> out = in
-                               .flatMapToPair(new RDDRExpandFunction(maxVal, 
dirRows, cast, ignore, blen));
+                               .flatMapToPair(new RDDRExpandFunction(lmaxVal, 
dirRows, cast, ignore, blen));
 
                        // store output rdd handle
                        sec.setRDDHandleForVariable(output.getName(), out);
@@ -655,13 +659,13 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                implements PairFlatMapFunction<Tuple2<MatrixIndexes, 
MatrixBlock>, MatrixIndexes, MatrixBlock> {
                private static final long serialVersionUID = 
-6153643261956222601L;
 
-               private final double _maxVal;
+               private final long _maxVal;
                private final boolean _dirRows;
                private final boolean _cast;
                private final boolean _ignore;
                private final long _blen;
 
-               public RDDRExpandFunction(double maxVal, boolean dirRows, 
boolean cast, boolean ignore, long blen) {
+               public RDDRExpandFunction(long maxVal, boolean dirRows, boolean 
cast, boolean ignore, long blen) {
                        _maxVal = maxVal;
                        _dirRows = dirRows;
                        _cast = cast;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCTableToRExpandTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCTableToRExpandTest.java
index f64516f..4df292a 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCTableToRExpandTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCTableToRExpandTest.java
@@ -29,12 +29,15 @@ import org.junit.Assert;
 
 public class RewriteCTableToRExpandTest extends AutomatedTestBase 
 {
-       private static final String TEST_NAME1 = 
"RewriteCTableToRExpandLeftPos";
-       private static final String TEST_NAME2 = 
"RewriteCTableToRExpandRightPos"; 
-       private static final String TEST_NAME3 = 
"RewriteCTableToRExpandLeftNeg"; 
-       private static final String TEST_NAME4 = 
"RewriteCTableToRExpandRightNeg"; 
-       private static final String TEST_NAME5 = 
"RewriteCTableToRExpandLeftUnknownPos";
-       private static final String TEST_NAME6 = 
"RewriteCTableToRExpandRightUnknownPos";
+       private static final String[] TEST_NAMES = new String[] {
+               "RewriteCTableToRExpandLeftPos",
+               "RewriteCTableToRExpandRightPos",
+               "RewriteCTableToRExpandLeftNeg",
+               "RewriteCTableToRExpandRightNeg",
+               "RewriteCTableToRExpandLeftUnknownPos",
+               "RewriteCTableToRExpandRightUnknownPos",
+               "RewriteCTableToRExpandRightVarMax"
+       };
        
        private static final String TEST_DIR = "functions/rewrite/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteCTableToRExpandTest.class.getSimpleName() + "/";
@@ -50,86 +53,108 @@ public class RewriteCTableToRExpandTest extends 
AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) );
+               for(int i=0; i<TEST_NAMES.length; i++)
+                       addTestConfiguration( TEST_NAMES[i],
+                               new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAMES[i], new String[] { "R" }) );
        }
 
        @Test
        public void testRewriteCTableRExpandLeftPositiveDenseCrop()  {
-               testRewriteCTableRExpand( TEST_NAME1, CropType.CROP );
+               testRewriteCTableRExpand( 1, CropType.CROP );
        }
        
        @Test
        public void testRewriteCTableRExpandLeftPositiveDensePad()  {
-               testRewriteCTableRExpand( TEST_NAME1, CropType.PAD );
+               testRewriteCTableRExpand( 1, CropType.PAD );
        }
        
        @Test
        public void testRewriteCTableRExpandRightPositiveDenseCrop()  {
-               testRewriteCTableRExpand( TEST_NAME2, CropType.CROP );
+               testRewriteCTableRExpand( 2, CropType.CROP );
        }
        
        @Test
        public void testRewriteCTableRExpandRightPositiveDensePad()  {
-               testRewriteCTableRExpand( TEST_NAME2, CropType.PAD );
+               testRewriteCTableRExpand( 2, CropType.PAD );
        }
        
        @Test
        public void testRewriteCTableRExpandLeftNegativeDenseCrop()  {
-               testRewriteCTableRExpand( TEST_NAME3, CropType.CROP );
+               testRewriteCTableRExpand( 3, CropType.CROP );
        }
        
        @Test
        public void testRewriteCTableRExpandLeftNegativeDensePad()  {
-               testRewriteCTableRExpand( TEST_NAME3, CropType.PAD );
+               testRewriteCTableRExpand( 3, CropType.PAD );
        }
        
        @Test
        public void testRewriteCTableRExpandRightNegativeDenseCrop()  {
-               testRewriteCTableRExpand( TEST_NAME4, CropType.CROP );
+               testRewriteCTableRExpand( 4, CropType.CROP );
        }
        
        @Test
        public void testRewriteCTableRExpandRightNegativeDensePad()  {
-               testRewriteCTableRExpand( TEST_NAME4, CropType.PAD );
+               testRewriteCTableRExpand( 4, CropType.PAD );
        }
        
        @Test
        public void testRewriteCTableRExpandLeftUnknownDenseCrop()  {
-               testRewriteCTableRExpand( TEST_NAME5, CropType.CROP );
+               testRewriteCTableRExpand( 5, CropType.CROP );
        }
        
        @Test
        public void testRewriteCTableRExpandLeftUnknownDensePad()  {
-               testRewriteCTableRExpand( TEST_NAME5, CropType.PAD );
+               testRewriteCTableRExpand( 5, CropType.PAD );
        }
        
        @Test
        public void testRewriteCTableRExpandRightUnknownDenseCrop()  {
-               testRewriteCTableRExpand( TEST_NAME6, CropType.CROP );
+               testRewriteCTableRExpand( 6, CropType.CROP );
        }
        
        @Test
        public void testRewriteCTableRExpandRightUnknownDensePad()  {
-               testRewriteCTableRExpand( TEST_NAME6, CropType.PAD );
+               testRewriteCTableRExpand( 6, CropType.PAD );
        }
        
-       private void testRewriteCTableRExpand( String testname, CropType type )
-       {       
+       @Test
+       public void testRewriteCTableRExpandVarMaxCropCP()  {
+               testRewriteCTableRExpand( 7, CropType.CROP, ExecMode.HYBRID );
+       }
+       
+       @Test
+       public void testRewriteCTableRExpandVarMaxPadCP()  {
+               testRewriteCTableRExpand( 7, CropType.PAD, ExecMode.HYBRID );
+       }
+       
+       @Test
+       public void testRewriteCTableRExpandVarMaxCropSP()  {
+               testRewriteCTableRExpand( 7, CropType.CROP, ExecMode.SPARK );
+       }
+       
+       @Test
+       public void testRewriteCTableRExpandVarMaxPadSP()  {
+               testRewriteCTableRExpand( 7, CropType.PAD, ExecMode.SPARK );
+       }
+       
+       private void testRewriteCTableRExpand( int test, CropType type ) {
+               testRewriteCTableRExpand(test, type, ExecMode.HYBRID);
+       }
+       
+       private void testRewriteCTableRExpand(int test, CropType type, ExecMode 
mode)
+       {
+               String testname = TEST_NAMES[test-1];
                TestConfiguration config = getTestConfiguration(testname);
                loadTestConfiguration(config);
 
                int outDim = maxVal + ((type==CropType.CROP) ? -7 : 7);
-               boolean unknownTests = ( testname.equals(TEST_NAME5) || 
testname.equals(TEST_NAME6) );
-                       
+               boolean unknownTests = (test >= 5);
                
                ExecMode platformOld = rtplatform;
-               if( unknownTests )
-                       rtplatform = ExecMode.SINGLE_NODE;
+               if( unknownTests & test != 7 )
+                       mode = ExecMode.SINGLE_NODE;
+               setExecMode(mode);
                
                try 
                {
@@ -148,21 +173,19 @@ public class RewriteCTableToRExpandTest extends 
AutomatedTestBase
                        runTest(true, false, null, -1); 
                        
                        //compare output meta data
-                       boolean left = (testname.equals(TEST_NAME1) || 
testname.equals(TEST_NAME3) 
-                               || testname.equals(TEST_NAME5) || 
testname.equals(TEST_NAME6));
-                       boolean pos = (testname.equals(TEST_NAME1) || 
testname.equals(TEST_NAME2));
+                       boolean left = (test == 1 || test == 3 || test == 5 || 
test == 6 || test == 7);
+                       boolean pos = (test == 1 || test == 2);
                        int rrows = (left && pos) ? rows : outDim;
                        int rcols = (!left && pos) ? rows : outDim;
                        if( !unknownTests )
                                checkDMLMetaDataFile("R", new 
MatrixCharacteristics(rrows, rcols, 1, 1));
                        
                        //check for applied rewrite
-                       
Assert.assertEquals(Boolean.valueOf(testname.equals(TEST_NAME1) 
-                               || testname.equals(TEST_NAME2) || unknownTests),
+                       Assert.assertEquals(Boolean.valueOf(test==1 || test==2 
|| unknownTests),
                                
Boolean.valueOf(heavyHittersContainsSubString("rexpand")));
                }
                finally {
-                       rtplatform = platformOld;
+                       resetExecMode(platformOld);
                }
        }
 }
diff --git 
a/src/test/scripts/functions/rewrite/RewriteCTableToRExpandRightVarMax.dml 
b/src/test/scripts/functions/rewrite/RewriteCTableToRExpandRightVarMax.dml
new file mode 100644
index 0000000..ee3a0f8
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteCTableToRExpandRightVarMax.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+N = sum(seq(1,10))/2
+A2 = rand(rows=N, cols=100, min=1, max=10);
+R = table(A2[,1], seq(1,nrow(A2)), N, nrow(A2));
+
+write(R, $3);

Reply via email to