This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new c6d7a52  [MINOR] Additional lineage parfor remote tests, and cleanups
c6d7a52 is described below

commit c6d7a52e2e4259fa62ba8e0b15cdfe1397baac0f
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Jun 23 22:46:05 2020 +0200

    [MINOR] Additional lineage parfor remote tests, and cleanups
    
    This patch adds an MSVM test that uses REMOTE_SPARK parfor workers to
    the lineage test suite, and fixes missing support for tak+ (aggregate
    ternary) operators in the recompute-by-lineage utility.
---
 scripts/builtin/l2svm.dml                          |  2 +-
 .../sysds/hops/ipa/FunctionCallSizeInfo.java       |  9 ++--
 .../sysds/runtime/lineage/LineageItemUtils.java    | 25 ++++++---
 .../functions/lineage/LineageTraceParforTest.java  |  7 +++
 .../functions/lineage/LineageTraceParforMSVM.dml   | 61 ++++++++++++++++++++++
 5 files changed, 90 insertions(+), 14 deletions(-)

diff --git a/scripts/builtin/l2svm.dml b/scripts/builtin/l2svm.dml
index 3e251ae..f411fb9 100644
--- a/scripts/builtin/l2svm.dml
+++ b/scripts/builtin/l2svm.dml
@@ -72,7 +72,7 @@ m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, 
Boolean intercept = FALSE
 
   # TODO make this a stop condition for l2svm instead of just printing.
   if(num_min + num_max != nrow(Y))
-    print("L2SVM: WARNING invalid number of labels in Y")
+    print("L2SVM: WARNING invalid number of labels in Y: "+num_min+" "+num_max)
 
   # Scale inputs to -1 for negative, and 1 for positive classification
   if(check_min != -1 | check_max != +1)
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java 
b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
index b349a5f..551ce98 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
@@ -233,14 +233,11 @@ public class FunctionCallSizeInfo
                                                                   &&  
h1.getDim1()==h2.getDim1() 
                                                                   &&  
h1.getDim2()==h2.getDim2()
                                                                   &&  
h1.getNnz()==h2.getNnz() );
-                                               //check literal values (equi 
value)
-                                               if( h1 instanceof LiteralOp ) {
-                                                       consistent &= (h2 
instanceof LiteralOp 
+                                               //check literal values (both 
needs to be literals and same value)
+                                               if( h1 instanceof LiteralOp || 
h2 instanceof LiteralOp ) {
+                                                       consistent &= (h1 
instanceof LiteralOp && h2 instanceof LiteralOp
                                                                && 
HopRewriteUtils.isEqualValue((LiteralOp)h1, (LiteralOp)h2));
                                                }
-                                               else if(h2 instanceof 
LiteralOp) {
-                                                       consistent = false; 
//h2 literal, but h1 not
-                                               }
                                        }
                                }
                                if( consistent )
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 467bbc9..e659025 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -278,6 +278,24 @@ public class LineageItemUtils {
                                                        
operands.put(item.getId(), aggunary);
                                                        break;
                                                }
+                                               case AggregateBinary: {
+                                                       Hop input1 = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop input2 = 
operands.get(item.getInputs()[1].getId());
+                                                       Hop aggbinary = 
HopRewriteUtils.createMatrixMultiply(input1, input2);
+                                                       
operands.put(item.getId(), aggbinary);
+                                                       break;
+                                               }
+                                               case AggregateTernary: {
+                                                       Hop input1 = 
operands.get(item.getInputs()[0].getId());
+                                                       Hop input2 = 
operands.get(item.getInputs()[1].getId());
+                                                       Hop input3 = 
operands.get(item.getInputs()[2].getId());
+                                                       Hop aggternary = 
HopRewriteUtils.createSum(
+                                                               
HopRewriteUtils.createBinary(
+                                                               
HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
+                                                               input3, 
OpOp2.MULT));
+                                                       
operands.put(item.getId(), aggternary);
+                                                       break;
+                                               }
                                                case Unary:
                                                case Builtin: {
                                                        Hop input = 
operands.get(item.getInputs()[0].getId());
@@ -308,13 +326,6 @@ public class LineageItemUtils {
                                                        
operands.put(item.getId(), binary);
                                                        break;
                                                }
-                                               case AggregateBinary: {
-                                                       Hop input1 = 
operands.get(item.getInputs()[0].getId());
-                                                       Hop input2 = 
operands.get(item.getInputs()[1].getId());
-                                                       Hop aggbinary = 
HopRewriteUtils.createMatrixMultiply(input1, input2);
-                                                       
operands.put(item.getId(), aggbinary);
-                                                       break;
-                                               }
                                                case Ternary: {
                                                        
operands.put(item.getId(), HopRewriteUtils.createTernary(
                                                                
operands.get(item.getInputs()[0].getId()), 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
index d100a4d..b3e0d73 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
@@ -46,6 +46,7 @@ public class LineageTraceParforTest extends AutomatedTestBase 
{
        protected static final String TEST_NAME3 = "LineageTraceParfor3"; 
//rand - matrix result - remote spark parfor
        protected static final String TEST_NAME4 = "LineageTraceParforSteplm"; 
//rand - steplm
        protected static final String TEST_NAME5 = "LineageTraceParforKmeans"; 
//rand - kmeans
+       protected static final String TEST_NAME6 = "LineageTraceParforMSVM"; 
//rand - msvm remote parfor
        
        protected String TEST_CLASS_DIR = TEST_DIR + 
LineageTraceParforTest.class.getSimpleName() + "/";
        
@@ -63,6 +64,7 @@ public class LineageTraceParforTest extends AutomatedTestBase 
{
                addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}) );
                addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"R"}) );
                addTestConfiguration( TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"R"}) );
+               addTestConfiguration( TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {"R"}) );
        }
        
        @Test
@@ -135,6 +137,11 @@ public class LineageTraceParforTest extends 
AutomatedTestBase {
                testLineageTraceParFor(32, TEST_NAME5);
        }
        
+       @Test
+       public void testLineageTraceMSVM_Remote64() {
+               testLineageTraceParFor(64, TEST_NAME6);
+       }
+       
        private void testLineageTraceParFor(int ncol, String testname) {
                try {
                        System.out.println("------------ BEGIN " + testname + 
"------------");
diff --git a/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml 
b/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml
new file mode 100644
index 0000000..23f39b0
--- /dev/null
+++ b/src/test/scripts/functions/lineage/LineageTraceParforMSVM.dml
@@ -0,0 +1,61 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+msvm2 = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE,
+    Double epsilon = 0.001, Double lambda = 1.0, Integer maxIterations = 100, 
Boolean verbose = FALSE)
+  return(Matrix[Double] model)
+{
+  if(min(Y) < 0)
+    stop("MSVM: Invalid Y input, containing negative values")
+
+  if(verbose)
+    print("Running Multiclass-SVM")
+
+  num_rows_in_w = ncol(X)
+  if(intercept) {
+    num_rows_in_w = num_rows_in_w + 1
+  }
+
+  if(ncol(Y) > 1) 
+    Y = rowMaxs(Y * t(seq(1,ncol(Y))))
+
+  # Assuming number of classes to be max contained in Y
+  w = matrix(0, rows=num_rows_in_w, cols=max(Y))
+
+  parfor(class in 1:max(Y), opt=CONSTRAINED, par=4, mode=REMOTE_SPARK) {
+    Y_local = 2 * (Y == class) - 1
+    w[,class] = l2svm(X=X, Y=Y_local, intercept=intercept,
+        epsilon=epsilon, lambda=lambda, maxIterations=maxIterations, 
+        verbose= verbose, columnId=class)
+  }
+  
+  model = w
+}
+
+nclass = 10;
+
+X = rand(rows=$2, cols=$3, seed=1);
+y = rand(rows=$2, cols=1, min=0, max=nclass, seed=2);
+y = ceil(y);
+
+model = msvm2(X=X, Y=y, intercept=FALSE);
+                                                                       
+write(model, $1);

Reply via email to