This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e4f7d8  [SYSTEMDS-395] Cleanup SVM scripts, new confusionMatrix, msvmPredict
8e4f7d8 is described below

commit 8e4f7d82e0df9ce5c0634d7516de10fb262603ed
Author: Sebastian <[email protected]>
AuthorDate: Sun May 24 00:02:36 2020 +0200

    [SYSTEMDS-395] Cleanup SVM scripts, new confusionMatrix, msvmPredict
    
    - ConfusionMatrix
    - msvmPredict
    
    confusionMatrix builds a confusion matrix from predictions and labels.
    It returns two matrices:
    - A count matrix, containing the count of each prediction/label
      combination.
    - An avg matrix, containing the accuracy of each class, i.e. the
      percentage distribution across labels (the percentage confusion).
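    
    A minimal usage sketch (hypothetical single-column prediction vector P
    and label vector Y, both with values >= 1):
      [cmCount, cmAvg] = confusionMatrix(P=P, Y=Y)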
    
    msvmPredict applies a trained msvm model and returns:
    - YRaw (Y_hat), the raw output scores of the model.
    - Y, the row-wise index of the maximum raw score, i.e. the predicted
      class label per row.
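    
    A minimal usage sketch (hypothetical feature matrix X and weight matrix
    W as returned by msvm):
      [YRaw, Y] = msvmPredict(X=X, W=W)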
    
    Furthermore, this commit makes some consistency changes in L2SVM and MSVM.
    
    Closes #910.
---
 docs/Tasks.txt                                     |   1 +
 scripts/algorithms/l2-svm.dml                      | 114 ++----------
 scripts/builtin/confusionMatrix.dml                |  62 +++++++
 scripts/builtin/kmeans.dml                         |  22 +--
 scripts/builtin/l2svm.dml                          |  97 +++++-----
 scripts/builtin/msvm.dml                           |  42 ++---
 scripts/builtin/msvmPredict.dml                    |  53 ++++++
 scripts/builtin/multiLogRegPredict.dml             |  17 +-
 .../java/org/apache/sysds/common/Builtins.java     |   2 +
 .../builtin/BuiltinConfusionMatrixTest.java        | 195 +++++++++++++++++++++
 .../builtin/BuiltinMulticlassSVMPredictTest.java   | 186 ++++++++++++++++++++
 .../builtin/BuiltinMulticlassSVMTest.java          |  54 +++---
 .../builtin/{l2svm.dml => confusionMatrix.dml}     |   8 +-
 src/test/scripts/functions/builtin/l2svm.dml       |   2 +-
 src/test/scripts/functions/builtin/multisvm.R      |  10 +-
 src/test/scripts/functions/builtin/multisvm.dml    |   4 +-
 .../builtin/{l2svm.dml => multisvmPredict.dml}     |   7 +-
 .../functions/federated/FederatedL2SVMTest.dml     |   2 +-
 .../federated/FederatedL2SVMTestReference.dml      |   2 +-
 19 files changed, 653 insertions(+), 227 deletions(-)

diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 1196566..3c9782f 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -304,6 +304,7 @@ SYSTEMDS-390 New Builtin Functions IV
  * 392 Builtin function for missing value imputation via FDs          OK
  * 393 Builtin to find Connected Components of a graph                OK
  * 394 Builtin for one-hot encoding of matrix (not frame), see table  OK
+ * 395 SVM rework and utils (confusionMatrix, msvmPredict)            OK
 
 Others:
  * Break append instruction to cbind and rbind 
diff --git a/scripts/algorithms/l2-svm.dml b/scripts/algorithms/l2-svm.dml
index 4cbcdb5..04d6524 100644
--- a/scripts/algorithms/l2-svm.dml
+++ b/scripts/algorithms/l2-svm.dml
@@ -21,10 +21,6 @@
 
 # Implements binary-class SVM with squared slack variables
 #
-# Example Usage:
-# Assume L2SVM_HOME is set to the home of the dml script
-# Assume input and output directories are on hdfs as INPUT_DIR and OUTPUT_DIR
-# Assume epsilon = 0.001, lambda = 1, maxiterations = 100
 #
 # INPUT PARAMETERS:
 # 
---------------------------------------------------------------------------------------------
@@ -40,111 +36,31 @@
 # maxiter   Int     100         Maximum number of conjugate gradient iterations
 # model     String  ---         Location to write model
 # fmt       String  "text"      The output format of the output, such as 
"text" or "csv"
-# Log       String  ---         [OPTIONAL] Location to write the log file
 # 
---------------------------------------------------------------------------------------------
 
-# hadoop jar SystemDS.jar -f $L2SVM_HOME/l2-svm.dml -nvargs X=$INPUT_DIR/X 
Y=$INPUT_DIR/Y \
-#   icpt=0 tol=0.001 reg=1 maxiter=100 model=$OUPUT_DIR/w Log=$OUTPUT_DIR/Log 
fmt="text"
-#
+# Example Execution:
+# systemds -f $SYSTEMDS_ROOT/scripts/algorithms/l2-svm.dml \
+#   -nvargs X=$INPUT_DIR/X Y=$INPUT_DIR/Y \
+#   icpt=FALSE tol=0.001 reg=1 maxiter=100 \
+#   model=$OUTPUT_DIR/w fmt="text"
+
 # Note about inputs: 
 # Assumes that labels (entries in Y) are set to either -1 or +1 or 
non-negative integers
 
 fmt = ifdef($fmt, "text")
-intercept = ifdef($icpt, 0)
+intercept = ifdef($icpt, FALSE)
 epsilon = ifdef($tol, 0.001)
 lambda = ifdef($reg, 1.0)
-maxiterations = ifdef($maxiter, 100)
+maxIterations = ifdef($maxiter, 100)
+verbose = ifdef($verbose, FALSE)
 
 X = read($X)
 Y = read($Y)
 
-#check input parameter assertions
-if(nrow(X) < 2)
-  stop("Stopping due to invalid inputs: Not possible to learn a binary class 
classifier without at least 2 rows")
-if(intercept != 0 & intercept != 1)
-  stop("Stopping due to invalid argument: Currently supported intercept 
options are 0 and 1")
-if(epsilon < 0)
-  stop("Stopping due to invalid argument: Tolerance (tol) must be 
non-negative")
-if(lambda < 0)
-  stop("Stopping due to invalid argument: Regularization constant (reg) must 
be non-negative")
-if(maxiterations < 1)
-  stop("Stopping due to invalid argument: Maximum iterations should be a 
positive integer")
-
-#check input lables and transform into -1/1
-check_min = min(Y)
-check_max = max(Y)
-num_min = sum(Y == check_min)
-num_max = sum(Y == check_max)
-if(check_min == check_max)
-  stop("Stopping due to invalid inputs: Y seems to contain exactly one label")
-if(num_min + num_max != nrow(Y))
-  stop("Stopping due to invalid inputs: Y seems to contain more than 2 labels")
-if(check_min != -1 | check_max != 1)
-  Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max - 
check_min)
-
-positive_label = check_max
-negative_label = check_min
-num_samples = nrow(X)
-dimensions = ncol(X)
-num_rows_in_w = dimensions
-
-if (intercept == 1) {
-  ones = matrix(1, rows=num_samples, cols=1)
-  X = cbind(X, ones);
-  num_rows_in_w += 1
-}
-
-w = matrix(0, num_rows_in_w, 1)
-Xw = matrix(0, rows=nrow(X), cols=1)
-g_old = t(X) %*% Y
-s = g_old
-
-debug_str = "# Iter, Obj"
-iter = 0
-continue = TRUE
-
-while(continue & iter < maxiterations) {
-  # minimizing primal obj along direction s
-  step_sz = 0
-  Xd = X %*% s
-  wd = lambda * sum(w * s)
-  dd = lambda * sum(s * s)
-  
-  continue1 = TRUE
-  while(continue1) {
-    tmp_Xw = Xw + step_sz*Xd
-    out = 1 - Y * tmp_Xw
-    sv = out > 0
-    out = out * sv
-    g = wd + step_sz*dd - sum(out * Y * Xd)
-    h = dd + sum(Xd * sv * Xd)
-    step_sz = step_sz - g/h
-    continue1 = (g*g/h >= 0.0000000001);
-  }
-
-  #update weights
-  w += step_sz * s
-  Xw += step_sz * Xd
-  
-  out = 1 - Y * Xw
-  sv = out > 0
-  out = sv * out
-  obj = 0.5 * sum(out * out) + lambda/2 * sum(w * w)
-  g_new = t(X) %*% (out * Y) - lambda * w
-  
-  print("ITER " + iter + ": OBJ=" + obj)
-  debug_str = append(debug_str, iter + "," + obj)
-  
-  tmp = sum(s * g_old)
-  
-  #non-linear CG step
-  be = sum(g_new * g_new)/sum(g_old * g_old)
-  s = be * s + g_new
-  g_old = g_new
-  
-  continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
-  iter = iter + 1
-}
+w = l2svm(X=X, Y=Y, intercept=intercept,
+  epsilon=epsilon, lambda=lambda,
+  maxIterations=maxIterations,
+  verbose=verbose)
 
 extra_model_params = matrix(0, 4, 1)
 extra_model_params[1,1] = positive_label
@@ -154,7 +70,3 @@ extra_model_params[4,1] = dimensions
 
 w = rbind(w, extra_model_params)
 write(w, $model, format=fmt)
-
-logFile = $Log
-if(logFile != " ")
-  write(debug_str, logFile)
diff --git a/scripts/builtin/confusionMatrix.dml 
b/scripts/builtin/confusionMatrix.dml
new file mode 100644
index 0000000..0a62182
--- /dev/null
+++ b/scripts/builtin/confusionMatrix.dml
@@ -0,0 +1,62 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# INPUT PARAMETERS:
+# 
---------------------------------------------------------------------------------------------
+# NAME            TYPE    DEFAULT     MEANING
+# 
---------------------------------------------------------------------------------------------
+# P               Double  ---         vector of predictions
+# Y               Double  ---         vector of gold-standard (true) labels
+# 
---------------------------------------------------------------------------------------------
+# OUTPUT:
+# 
---------------------------------------------------------------------------------------------
+# NAME            TYPE    DEFAULT     MEANING
+# 
---------------------------------------------------------------------------------------------
+# ConfusionSum    Double  ---         The confusion matrix as counts of classifications
+# ConfusionAvg    Double  ---         The confusion matrix as averages per true class
+
+# Output is like:
+#                   True Labels
+#                     1    2
+#                 1   TP | FP
+#   Predictions      ----+----
+#                 2   FN | TN
+# 
+# TP = True Positives
+# FP = False Positives
+# FN = False Negatives
+# TN = True Negatives
+
+m_confusionMatrix = function(Matrix[Double] P, Matrix[Double] Y)
+  return(Matrix[Double] confusionSum, Matrix[Double] confusionAvg)
+{
+  if(ncol(P) > 1  | ncol(Y) > 1)
+    stop("CONFUSION MATRIX: Invalid input, number of cols should be 1 in both P ["+ncol(P)+"] and Y ["+ncol(Y)+"]")
+  if(nrow(P) != nrow(Y))
+    stop("CONFUSION MATRIX: The number of rows has to be equal in both P ["+nrow(P)+"] and Y ["+nrow(Y)+"]")
+  if(min(P) < 1 | min(Y) < 1)
+    stop("CONFUSION MATRIX: All values in P and Y should be above or equal to 1, min(P):" + min(P) + " min(Y):" + min(Y) )
+
+  dim = max(max(Y),max(P))
+  confusionSum = table(P, Y,  dim, dim)
+  # max to avoid division by 0, in case a column contains no entries.
+  confusionAvg = confusionSum / max(1,colSums(confusionSum))
+}
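
For illustration, a small worked example of the two outputs above (hypothetical
inputs; predictions index the rows of the table, true labels the columns):

  P = [1, 1, 2, 2],  Y = [1, 2, 2, 2]
  confusionSum = table(P, Y, 2, 2) = [[1, 1],
                                      [0, 2]]
  colSums(confusionSum) = [1, 3]
  confusionAvg = confusionSum / max(1, colSums(confusionSum))
               = [[1.0, 0.333],
                  [0.0, 0.667]]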
diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index 96591c6..75e5bc7 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -23,21 +23,23 @@
 #
 # INPUT PARAMETERS:
 # ----------------------------------------------------------------------------
-# NAME  TYPE   DEFAULT  MEANING
+# NAME                              TYPE      DEFAULT  MEANING
 # ----------------------------------------------------------------------------
-# X                                 String   ---    Location to read matrix X 
with the input data records
-# k                                 Int      ---    Number of centroids
-# runs                              Int       10    Number of runs (with 
different initial centroids)
-# max_iter                          Int     1000    Maximum number of 
iterations per run
-# eps                               Double 0.000001 Tolerance (epsilon) for 
WCSS change ratio
-# is_verbose                        Boolean FALSE   do not print per-iteration 
stats
-# avg_sample_size_per_centroid      Int       50    Average number of records 
per centroid in data samples
+# X                                 Double    ---      The input Matrix to do 
KMeans on.
+# k                                 Int       ---      Number of centroids
+# runs                              Int       10       Number of runs (with 
different initial centroids)
+# max_iter                          Int       1000     Maximum number of 
iterations per run
+# eps                               Double    0.000001 Tolerance (epsilon) for 
WCSS change ratio
+# is_verbose                        Boolean   FALSE    do not print 
per-iteration stats
+# avg_sample_size_per_centroid      Int       50       Average number of 
records per centroid in data samples
 #
 #
 # RETURN VALUES
 # ----------------------------------------------------------------------------
-# Y     String  "Y.mtx" Location to store the mapping of records to centroids
-# C     String  "C.mtx" Location to store the output matrix with the centroids
+# NAME     TYPE      DEFAULT  MEANING
+# ----------------------------------------------------------------------------
+# Y        String    "Y.mtx"  The mapping of records to centroids
+# C        String    "C.mtx"  The output matrix with the centroids
 # ----------------------------------------------------------------------------
 
 
diff --git a/scripts/builtin/l2svm.dml b/scripts/builtin/l2svm.dml
index fac0be9..3e251ae 100644
--- a/scripts/builtin/l2svm.dml
+++ b/scripts/builtin/l2svm.dml
@@ -27,12 +27,15 @@
 # NAME            TYPE    DEFAULT     MEANING
 # 
---------------------------------------------------------------------------------------------
 # X               Double  ---         matrix X of feature vectors
-# Y               Double  ---         matrix Y of class labels
+# Y               Double  ---         matrix Y of class labels, has to be a single column
 # intercept       Boolean False       No Intercept ( If set to TRUE then a 
constant bias column is added to X)
 # epsilon         Double  0.001       Procedure terminates early if the 
reduction in objective function 
 #                      value is less than epsilon (tolerance) times the 
initial objective function value.
 # lambda          Double  1.0         Regularization parameter (lambda) for L2 
regularization
 # maxiterations   Int     100         Maximum number of conjugate gradient 
iterations
+# verbose         Boolean False       Set to true to print per-iteration loss statements.
+# columnId        Int     -1          The column ID added to the print statements,
+#                                     specifically useful when L2SVM is used inside MSVM.
 # 
---------------------------------------------------------------------------------------------
  
 
@@ -44,65 +47,70 @@
 
 
 m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = 
FALSE,
-Double epsilon = 0.001, Double lambda = 1, Integer maxiterations = 100, 
Boolean verbose = FALSE)
-    return(Matrix[Double] model)
+    Double epsilon = 0.001, Double lambda = 1, Integer maxIterations = 100, 
+    Boolean verbose = FALSE, Integer columnId = -1)
+  return(Matrix[Double] model)
 {
-
   #check input parameter assertions
   if(nrow(X) < 2)
-    stop("Stopping due to invalid inputs: Not possible to learn a binary class 
classifier without at least 2 rows")
+    stop("L2SVM: Stopping due to invalid inputs: Not possible to learn a 
binary class classifier without at least 2 rows")
   if(epsilon < 0)
-    stop("Stopping due to invalid argument: Tolerance (tol) must be 
non-negative")
+    stop("L2SVM: Stopping due to invalid argument: Tolerance (tol) must be 
non-negative")
   if(lambda < 0)
-    stop("Stopping due to invalid argument: Regularization constant (reg) must 
be non-negative")
-  if(maxiterations < 1)
-    stop("Stopping due to invalid argument: Maximum iterations should be a 
positive integer")
-  
+    stop("L2SVM: Stopping due to invalid argument: Regularization constant 
(reg) must be non-negative")
+  if(maxIterations < 1)
+    stop("L2SVM: Stopping due to invalid argument: Maximum iterations should 
be a positive integer")
+  if(ncol(Y) > 1)
+    stop("L2SVM: Stopping due to invalid input with multiple label columns, maybe use MSVM instead?")
+
   #check input lables and transform into -1/1
   check_min = min(Y)
   check_max = max(Y)
-  
+
   num_min = sum(Y == check_min)
   num_max = sum(Y == check_max)
-  
-  
-  if(num_min + num_max != nrow(Y)) print("please check Y, it should contain 
only 2 labels")
-  else{
-    if(check_min != -1 | check_max != +1)
-      Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max - 
check_min)
-  }
-  
-  if(verbose) print('running L2-SVM ');
-  
+
+  # TODO make this a stop condition for l2svm instead of just printing.
+  if(num_min + num_max != nrow(Y))
+    print("L2SVM: WARNING invalid number of labels in Y")
+
+  # Scale inputs to -1 for negative, and 1 for positive classification
+  if(check_min != -1 | check_max != +1)
+    Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max - 
check_min)
+    
+  # If columnId is not -1, we assume l2svm is called from within msvm,
+  # therefore don't print this message here.
+  if(verbose & columnId == -1)
+    print('Running L2-SVM ')
+
   num_samples = nrow(X)
-  dimensions = ncol(X)
+  num_classes = ncol(Y)
   
+  # Add Bias 
+  num_rows_in_w = ncol(X)
   if (intercept) {
     ones  = matrix(1, rows=num_samples, cols=1)
     X = cbind(X, ones);
+    num_rows_in_w += 1
   }
   
-  num_rows_in_w = dimensions
-  if(intercept){
-    num_rows_in_w = num_rows_in_w + 1
-  }
   w = matrix(0, rows=num_rows_in_w, cols=1)
-  
+
   g_old = t(X) %*% Y
   s = g_old
-  
+
   Xw = matrix(0, rows=nrow(X), cols=1)
-  
+
   iter = 0
-  continue = 1
-  while(continue == 1 & iter < maxiterations)  {
+  continue = TRUE
+  while(continue & iter < maxIterations)  {
     # minimizing primal obj along direction s
     step_sz = 0
     Xd = X %*% s
     wd = lambda * sum(w * s)
     dd = lambda * sum(s * s)
-    continue1 = 1
-    while(continue1 == 1){
+    continue1 = TRUE
+    while(continue1){
       tmp_Xw = Xw + step_sz*Xd
       out = 1 - Y * (tmp_Xw)
       sv = (out > 0)
@@ -112,32 +120,31 @@ Double epsilon = 0.001, Double lambda = 1, Integer 
maxiterations = 100, Boolean
       step_sz = step_sz - g/h
       continue1 = (g*g/h >= epsilon)
     }
-  
+
     #update weights
     w = w + step_sz*s
     Xw = Xw + step_sz*Xd
-  
+
     out = 1 - Y * Xw
     sv = (out > 0)
     out = sv * out
     obj = 0.5 * sum(out * out) + lambda/2 * sum(w * w)
     g_new = t(X) %*% (out * Y) - lambda * w
-  
-    
-    if(verbose) print("Iter, Obj "+ iter + ", "+obj)
-  
-    tmp = sum(s * g_old)
-    if(step_sz*tmp < epsilon*obj){
-      continue = 0
+
+    if(verbose) {
+      colstr = ifelse(columnId!=-1, ", Col:"+columnId + " ,", " ,")
+      print("Iter:" + toString(iter) + colstr + " Obj:" + obj)
     }
-  
+
+    tmp = sum(s * g_old)
+    continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
+
     #non-linear CG step
     be = sum(g_new * g_new)/sum(g_old * g_old)
     s = be * s + g_new
     g_old = g_new
-  
+
     iter = iter + 1
   }
-
   model = w
 }
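
For reference, a minimal call of the reworked builtin with the renamed
maxIterations parameter (a sketch with hypothetical inputs X and y):

  model = l2svm(X=X, Y=y, intercept=FALSE, epsilon=0.001,
    lambda=1.0, maxIterations=100, verbose=TRUE)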
diff --git a/scripts/builtin/msvm.dml b/scripts/builtin/msvm.dml
index 86046c6..63dfa5e 100644
--- a/scripts/builtin/msvm.dml
+++ b/scripts/builtin/msvm.dml
@@ -33,7 +33,8 @@
 # epsilon         Double  0.001       Procedure terminates early if the 
reduction in objective function 
 #                                     value is less than epsilon (tolerance) 
times the initial objective function value.
 # lambda          Double  1.0         Regularization parameter (lambda) for L2 
regularization
-# maxiterations   Int     100         Maximum number of conjugate gradient 
iterations
+# maxIterations   Int     100         Maximum number of conjugate gradient 
iterations
+# verbose         Boolean False       Set to true to print while training.
 # 
---------------------------------------------------------------------------------------------
  
 #Output(s)
@@ -42,34 +43,33 @@
 # 
---------------------------------------------------------------------------------------------
 # model           Double   ---        model matrix
 
-m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = 
FALSE, Integer num_classes =10,
-                  Double epsilon = 0.001, Double lambda = 1.0, Integer 
max_iterations = 100, Boolean verbose = FALSE)
+m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = 
FALSE,
+    Double epsilon = 0.001, Double lambda = 1.0, Integer maxIterations = 100, 
Boolean verbose = FALSE)
   return(Matrix[Double] model)
 {
-  if(verbose)
-    print("Built-in Multiclass-SVM started")
+  if(min(Y) < 0)
+    stop("MSVM: Invalid Y input, containing negative values")
 
-  num_samples = nrow(X)
-  num_features = ncol(X)
+  if(verbose)
+    print("Running Multiclass-SVM")
 
-  num_rows_in_w = num_features
+  num_rows_in_w = ncol(X)
   if(intercept) {
     num_rows_in_w = num_rows_in_w + 1
   }
-  
-  w = matrix(0, rows=num_rows_in_w, cols=num_classes)
 
-  parfor(iter_class in 1:num_classes) {
-    Y_local = 2 * (Y == iter_class) - 1
-    if(verbose) {
-      print("iter class: " + iter_class)
-      print("y local: " + toString(Y_local))
-    }
-    w[,iter_class] = l2svm(X=X, Y=Y_local, intercept=intercept,
-      epsilon=epsilon, lambda=lambda, maxiterations=max_iterations)
-  }
+  if(ncol(Y) > 1) 
+    Y = rowMaxs(Y * t(seq(1,ncol(Y))))
 
+  # Assume the number of classes to be the maximum value contained in Y
+  w = matrix(0, rows=num_rows_in_w, cols=max(Y))
+
+  parfor(class in 1:max(Y)) {
+    Y_local = 2 * (Y == class) - 1
+    w[,class] = l2svm(X=X, Y=Y_local, intercept=intercept,
+        epsilon=epsilon, lambda=lambda, maxIterations=maxIterations, 
+        verbose= verbose, columnId=class)
+  }
+  
   model = w
-  if (verbose)
-    print("model["+iter_class+"]: " + toString(model))
 }
diff --git a/scripts/builtin/msvmPredict.dml b/scripts/builtin/msvmPredict.dml
new file mode 100644
index 0000000..2b4fa42
--- /dev/null
+++ b/scripts/builtin/msvmPredict.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# This script applies a trained MSVM model to a matrix of feature vectors
+
+# INPUT PARAMETERS:
+# 
---------------------------------------------------------------------------------------------
+# NAME            TYPE    DEFAULT     MEANING
+# 
---------------------------------------------------------------------------------------------
+# X               Double  ---         matrix X of feature vectors to classify
+# W               Double  ---         matrix of the trained variables
+# 
---------------------------------------------------------------------------------------------
+# OUTPUT:
+# 
---------------------------------------------------------------------------------------------
+# NAME            TYPE    DEFAULT     MEANING
+# 
---------------------------------------------------------------------------------------------
+# YRaw            Double  ---         Raw classification scores from the model,
+#                                     one column per class
+# Y               Double  ---         Predicted class labels, the row-wise index
+#                                     of the maximum raw score
+
+m_msvmPredict = function(Matrix[Double] X, Matrix[Double] W)
+  return(Matrix[Double] YRaw, Matrix[Double] Y)
+{
+  if(ncol(X) != nrow(W)){
+    if(ncol(X) + 1 != nrow(W)){
+      stop("MSVM Predict: Invalid shape of W ["+ncol(W)+","+nrow(W)+"] or X 
["+ncol(X)+","+nrow(X)+"]")
+    }
+    YRaw = X %*% W[1:ncol(X),] + W[ncol(X)+1,]
+    Y = rowIndexMax(YRaw)
+  }
+  else{
+    YRaw = X %*% W
+    Y = rowIndexMax(YRaw)
+  }
+}
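
A minimal end-to-end sketch tying the new builtins together (hypothetical
feature matrix X and label vector y with labels in 1..k):

  model = msvm(X=X, Y=y, intercept=FALSE, epsilon=0.001,
    lambda=1.0, maxIterations=100, verbose=FALSE)
  [YRaw, yhat] = msvmPredict(X=X, W=model)
  [cmCount, cmAvg] = confusionMatrix(P=yhat, Y=y)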
diff --git a/scripts/builtin/multiLogRegPredict.dml 
b/scripts/builtin/multiLogRegPredict.dml
index 5ea9a04..213f7dd 100644
--- a/scripts/builtin/multiLogRegPredict.dml
+++ b/scripts/builtin/multiLogRegPredict.dml
@@ -47,17 +47,17 @@ return(Matrix[Double] M, Matrix[Double] predicted_Y, Double 
accuracy)
   if(min(Y) <= 0)
     stop("class labels should be greater than zero")
     
-  num_records  = nrow (X);
-  num_features = ncol (X);
-  beta =  B [1 : ncol (X),  ];
-  intercept = B [nrow(B),  ];
+  num_records = nrow(X);
+  num_features = ncol(X);
+  beta = B[1:ncol(X), ];
+  intercept = B[nrow(B), ];
 
-  if (nrow (B) == ncol (X))
+  if (nrow(B) == ncol(X))
     intercept = 0.0 * intercept; 
   else
     num_features = num_features + 1;
 
-  ones_rec = matrix (1, rows = num_records, cols = 1);
+  ones_rec = matrix(1, rows = num_records, cols = 1);
   linear_terms = X %*% beta + ones_rec %*% intercept;
 
   M = probabilities(linear_terms); # compute the probablitites on unknown data
@@ -67,10 +67,7 @@ return(Matrix[Double] M, Matrix[Double] predicted_Y, Double 
accuracy)
     accuracy = sum((predicted_Y - Y) == 0) / num_records * 100;
   
   if(verbose)
-  {
-  acc_str = "Accuracy (%): " + accuracy
-  print(acc_str)
-  }
+    print("Accuracy (%): " + accuracy);
 }
 
 probabilities = function (Matrix[double] linear_terms)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index b09fc63..7345077 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -81,6 +81,7 @@ public enum Builtins {
        CUMPROD("cumprod", false),
        CUMSUM("cumsum", false),
        CUMSUMPROD("cumsumprod", false),
+       CONFUSIONMATRIX("confusionMatrix", true),
        DETECTSCHEMA("detectSchema", false),
        DIAG("diag", false),
        DISCOVER_FD("discoverFD", true),
@@ -127,6 +128,7 @@ public enum Builtins {
        MEDIAN("median", false),
        MOMENT("moment", "centralMoment", false),
        MSVM("msvm", true),
+       MSVMPREDICT("msvmPredict", true),
        MULTILOGREG("multiLogReg", true),
        MULTILOGREGPREDICT("multiLogRegPredict", true),
        NCOL("ncol", false),
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinConfusionMatrixTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinConfusionMatrixTest.java
new file mode 100644
index 0000000..bd5ccca
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinConfusionMatrixTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import static org.junit.Assert.fail;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class BuiltinConfusionMatrixTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "confusionMatrix";
+       private final static String TEST_DIR = "functions/builtin/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
BuiltinConfusionMatrixTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B",}));
+       }
+
+       public double eps = 0.00001;
+
+       @Test
+       public void test_01() {
+               double[][] y;
+               double[][] p;
+               HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+
+               // Classification is 100% accurate in all classes if Y == P.
+               y = TestUtils.round(getRandomMatrix(1000, 1, 1, 2, 1.0, 7));
+               p = y;
+
+               res.put(new CellIndex(1, 1), 1.0);
+               res.put(new CellIndex(2, 2), 1.0);
+
+               for(LopProperties.ExecType ex : new ExecType[] 
{LopProperties.ExecType.CP, LopProperties.ExecType.SPARK}) {
+                       runConfusionMatrixTest(y, p, res, ex);
+               }
+       }
+
+       @Test
+       public void test_02() {
+               HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+               res.put(new CellIndex(2, 2), 1.0);
+               runConfusionMatrixTest(new double[][] {{2}}, new double[][] 
{{2}}, res, LopProperties.ExecType.CP);
+       }
+
+       @Test
+       public void test_03() {
+               HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+               res.put(new CellIndex(2, 1), 1.0);
+               runConfusionMatrixTest(new double[][] {{1}}, new double[][] 
{{2}}, res, LopProperties.ExecType.CP);
+       }
+
+       @Test
+       public void test_04() {
+               HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+               res.put(new CellIndex(6, 1), 1.0);
+               runConfusionMatrixTest(new double[][] {{1}}, new double[][] 
{{6}}, res, LopProperties.ExecType.CP);
+       }
+
+       @Test
+       public void test_05() {
+               HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+               res.put(new CellIndex(1, 9), 1.0);
+               runConfusionMatrixTest(new double[][] {{9}}, new double[][] 
{{1}}, res, LopProperties.ExecType.CP);
+       }
+
+       @Test
+       public void test_06() {
+               HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+               double[][] y = new double[][] {{1}, {1}, {1}, {1}};
+               double[][] p = new double[][] {{1}, {2}, {3}, {4}};
+               res.put(new CellIndex(1, 1), 0.25);
+               res.put(new CellIndex(2, 1), 0.25);
+               res.put(new CellIndex(3, 1), 0.25);
+               res.put(new CellIndex(4, 1), 0.25);
+               runConfusionMatrixTest(y, p, res, LopProperties.ExecType.CP);
+       }
+
+       @Test
+       public void test_07() {
+               HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+               double[][] y = new double[][] {{1}, {2}, {3}, {4}};
+               double[][] p = new double[][] {{1}, {1}, {1}, {1}};
+               res.put(new CellIndex(1, 1), 1.0);
+               res.put(new CellIndex(1, 2), 1.0);
+               res.put(new CellIndex(1, 3), 1.0);
+               res.put(new CellIndex(1, 4), 1.0);
+               runConfusionMatrixTest(y, p, res, LopProperties.ExecType.CP);
+       }
+
+       private void runConfusionMatrixTest(double[][] y, double[][] p, 
HashMap<MatrixValue.CellIndex, Double> res,
+               ExecType instType) {
+               ExecMode platformOld = setExecMode(instType);
+
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-nvargs", "P=" + 
input("P"), "Y=" + input("Y"), "out_file=" + output("B")};
+                       writeInputMatrixWithMTD("P", p, false);
+                       writeInputMatrixWithMTD("Y", y, false);
+                       runTest(true, false, null, -1);
+
+                       HashMap<MatrixValue.CellIndex, Double> dmlResult = 
readDMLMatrixFromHDFS("B");
+                       TestUtils.compareMatrices(dmlResult, res, eps, 
"DML_Result", "Expected");
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+
+       // TODO Future, does it make sense to save an empty matrix, since we 
have ways to make an empty matrix?
+       // @Test
+       // public void test_invalid_01(){
+       // // Test if the script fails with input containing no values.
+       // runConfusionMatrixExceptionTest(new double[][]{}, new double[][]{});
+       // }
+
+       @Test
+       public void test_invalid_02() {
+               // Test if the script fails with input containing multiple columns
+               runConfusionMatrixExceptionTest(new double[][] {{1, 2}}, new 
double[][] {{1, 2}});
+               runConfusionMatrixExceptionTest(new double[][] {{1}}, new 
double[][] {{1, 2}});
+               runConfusionMatrixExceptionTest(new double[][] {{1, 2}}, new 
double[][] {{1}});
+       }
+
+       @Test
+       public void test_invalid_03() {
+               // Test if the script fails with inputs containing a different number of rows
+               runConfusionMatrixExceptionTest(new double[][] {{1}, {1}}, new 
double[][] {{1}});
+               runConfusionMatrixExceptionTest(new double[][] {{1}}, new 
double[][] {{1}, {1}});
+       }
+
+       private void runConfusionMatrixExceptionTest(double[][] y, double[][] 
p) {
+               ExecMode platformOld = setExecMode(LopProperties.ExecType.CP);
+
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-nvargs", "P=" + 
input("P"), "Y=" + input("Y"), "out_file=" + output("B")};
+                       writeInputMatrixWithMTD("P", p, false);
+                       writeInputMatrixWithMTD("Y", y, false);
+
+                       // TODO make stop throw exception instead
+                       // 
https://issues.apache.org/jira/projects/SYSTEMML/issues/SYSTEMML-2540
+                       // runTest(true, true, DMLScriptException.class, -1);
+
+                       // Verify that the output file does not exist!
+                       runTest(true, false, null, -1);
+
+                       try {
+                               readDMLMatrixFromHDFS("B");
+                               fail("File should not have been written");
+                       }
+                       catch(AssertionError e) {
+                               // exception expected
+                       }
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMPredictTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMPredictTest.java
new file mode 100644
index 0000000..985771d
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMPredictTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import static org.junit.Assert.fail;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class BuiltinMulticlassSVMPredictTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "multisvmPredict";
+       private final static String TEST_DIR = "functions/builtin/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinMulticlassSVMPredictTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"YRaw", "Y"}));
+       }
+
+       public double eps = 0.00001;
+
+       @Test
+       public void test_01() {
+
+               double[][] x = new double[][] {{0.4, 0.0}, {0.0, 0.2}};
+               double[][] w = new double[][] {{1.0, 0.0}, {0.0, 1.0}};
+
+               HashMap<MatrixValue.CellIndex, Double> res_Y = new HashMap<>();
+               res_Y.put(new CellIndex(1, 1), 1.0);
+               res_Y.put(new CellIndex(2, 1), 2.0);
+
+               HashMap<MatrixValue.CellIndex, Double> res_YRaw = new 
HashMap<>();
+               res_YRaw.put(new CellIndex(1, 1), 0.4);
+               res_YRaw.put(new CellIndex(2, 2), 0.2);
+
+               for(LopProperties.ExecType ex : new ExecType[] 
{LopProperties.ExecType.CP, LopProperties.ExecType.SPARK}) {
+                       runMSVMPredict(x, w, res_YRaw, res_Y, ex);
+               }
+       }
+
+       @Test
+       public void test_02() {
+               double[][] x = new double[][] {{0.4, 0.1}};
+               double[][] w = new double[][] {{1.0, 0.5}, {0.2, 1.0}};
+
+               HashMap<MatrixValue.CellIndex, Double> res_Y = new HashMap<>();
+               res_Y.put(new CellIndex(1, 1), 1.0);
+               
+               HashMap<MatrixValue.CellIndex, Double> res_YRaw = new 
HashMap<>();
+               res_YRaw.put(new CellIndex(1, 1), 0.42);
+               res_YRaw.put(new CellIndex(1, 2), 0.3);
+
+               for(LopProperties.ExecType ex : new ExecType[] 
{LopProperties.ExecType.CP, LopProperties.ExecType.SPARK}) {
+                       runMSVMPredict(x, w, res_YRaw, res_Y, ex);
+               }
+       }
+
+       @Test
+       public void test_03() {
+               // Add bias column
+               double[][] x = new double[][] {{0.4, 0.1}};
+               double[][] w = new double[][] {{1.0, 0.5}, {0.2, 1.0}, 
{1.0,0.5}};
+
+               HashMap<MatrixValue.CellIndex, Double> res_Y = new HashMap<>();
+               res_Y.put(new CellIndex(1, 1), 1.0);
+
+               HashMap<MatrixValue.CellIndex, Double> res_YRaw = new 
HashMap<>();
+               res_YRaw.put(new CellIndex(1, 1), 1.42);
+               res_YRaw.put(new CellIndex(1, 2), 0.8);
+
+               for(LopProperties.ExecType ex : new ExecType[] 
{LopProperties.ExecType.CP, LopProperties.ExecType.SPARK}) {
+                       runMSVMPredict(x, w, res_YRaw, res_Y, ex);
+               }
+       }
+
+       private void runMSVMPredict(double[][] x, double[][] w, 
HashMap<MatrixValue.CellIndex, Double> YRaw,
+               HashMap<MatrixValue.CellIndex, Double> Y, ExecType instType) {
+               ExecMode platformOld = setExecMode(instType);
+
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-nvargs", "X=" + 
input("X"), "W=" + input("W"), "YRaw=" + output("YRaw"),
+                               "Y=" + output("Y")};
+                       writeInputMatrixWithMTD("X", x, false);
+                       writeInputMatrixWithMTD("W", w, false);
+                       runTest(true, false, null, -1);
+
+                       HashMap<MatrixValue.CellIndex, Double> YRaw_res = 
readDMLMatrixFromHDFS("YRaw");
+                       HashMap<MatrixValue.CellIndex, Double> Y_res = 
readDMLMatrixFromHDFS("Y");
+
+                       TestUtils.compareMatrices(YRaw_res, YRaw, eps, 
"DML_Result", "Expected");
+                       TestUtils.compareMatrices(Y_res, Y, eps, "DML_Result", 
"Expected");
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+
+       @Test
+       public void test_invalid_01() {
+               // Test if the script fails when the input contains an incorrect number of columns
+               double[][] x = new double[][] {{1, 2, 3}};
+               double[][] w = new double[][] {{1, -1}, {-1, 1}};
+               runMSVMPredictionExceptionTest(x, w);
+       }
+
+       @Test
+       public void test_invalid_02() {
+               // Test if the script fails when the input contains an incorrect number of rows vs columns
+               double[][] x = new double[][] {{1, 2, 3}};
+               double[][] w = new double[][] {{1, -1, 1, 3 ,3}, {-1, 1, 1, 1 
,12}};
+               runMSVMPredictionExceptionTest(x, w);
+       }
+
+       @Test
+       public void test_invalid_03() {
+               // Add one more column than the bias column.
+               double[][] x = new double[][] {{1, 2, 3}};
+               double[][] w = new double[][] {{1.0, 0.5}, {0.2, 1.0}, 
{1.0,0.5},{1.0,0.5},{1.0,0.5}};
+               runMSVMPredictionExceptionTest(x, w);
+       }
+
+       private void runMSVMPredictionExceptionTest(double[][] x, double[][] w) 
{
+               ExecMode platformOld = setExecMode(LopProperties.ExecType.CP);
+
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-nvargs", "X=" + 
input("X"), "W=" + input("W"), "YRaw=" + output("YRaw"),
+                               "Y=" + output("Y")};
+                       writeInputMatrixWithMTD("X", x, false);
+                       writeInputMatrixWithMTD("W", w, false);
+
+                       // TODO make stop throw exception instead
+                       // 
https://issues.apache.org/jira/projects/SYSTEMML/issues/SYSTEMML-2540
+                       // runTest(true, true, DMLScriptException.class, -1);
+
+                       // Verify that the output file does not exist!
+                       runTest(true, false, null, -1);
+
+                       try {
+                               readDMLMatrixFromHDFS("YRaw");
+                               fail("File should not have been written");
+                       }
+                       catch(AssertionError e) {
+                               // exception expected
+                       }
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMTest.java
index b46ba3f..56ceda1 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMTest.java
@@ -31,8 +31,7 @@ import org.apache.sysds.test.TestUtils;
 
 import java.util.HashMap;
 
-public class BuiltinMulticlassSVMTest extends AutomatedTestBase
-{
+public class BuiltinMulticlassSVMTest extends AutomatedTestBase {
        private final static String TEST_NAME = "multisvm";
        private final static String TEST_DIR = "functions/builtin/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
BuiltinMulticlassSVMTest.class.getSimpleName() + "/";
@@ -43,72 +42,79 @@ public class BuiltinMulticlassSVMTest extends 
AutomatedTestBase
        private final static double spSparse = 0.01;
        private final static double spDense = 0.7;
        private final static int max_iter = 10;
-       private final static int num_classes = 10;
 
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration(TEST_NAME,new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"C"}));
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
        }
 
        @Test
        public void testMSVMDense() {
-               runMSVMTest(false, false, num_classes, eps, 1.0, max_iter, 
LopProperties.ExecType.CP);
+               runMSVMTest(false, false, eps, 1.0, max_iter, 
LopProperties.ExecType.CP);
        }
+
        @Test
        public void testMSVMSparse() {
-               runMSVMTest(true, false, num_classes, eps, 1.0, max_iter, 
LopProperties.ExecType.CP);
+               runMSVMTest(true, false, eps, 1.0, max_iter, 
LopProperties.ExecType.CP);
        }
+
        @Test
        public void testMSVMInterceptSpark() {
-               runMSVMTest(true,true, num_classes, eps, 1.0, max_iter, 
LopProperties.ExecType.SPARK);
+               runMSVMTest(true, true, eps, 1.0, max_iter, 
LopProperties.ExecType.SPARK);
        }
 
        @Test
        public void testMSVMSparseLambda2() {
-               runMSVMTest(true,true, num_classes, eps,2.0, max_iter, 
LopProperties.ExecType.CP);
+               runMSVMTest(true, true, eps, 2.0, max_iter, 
LopProperties.ExecType.CP);
        }
+
        @Test
        public void testMSVMSparseLambda100CP() {
-               runMSVMTest(true,true, num_classes, 1, 100, max_iter, 
LopProperties.ExecType.CP);
+               runMSVMTest(true, true, 1, 100, max_iter, 
LopProperties.ExecType.CP);
        }
+
        @Test
        public void testMSVMSparseLambda100Spark() {
-               runMSVMTest(true,true, num_classes, 1, 100, max_iter, 
LopProperties.ExecType.SPARK);
+               runMSVMTest(true, true, 1, 100, max_iter, 
LopProperties.ExecType.SPARK);
        }
+
        @Test
        public void testMSVMIteration() {
-               runMSVMTest(true,true, num_classes, 1, 2.0, 100, 
LopProperties.ExecType.CP);
+               runMSVMTest(true, true, 1, 2.0, 100, LopProperties.ExecType.CP);
        }
+
        @Test
        public void testMSVMDenseIntercept() {
-               runMSVMTest(false,true, num_classes, eps, 1.0, max_iter, 
LopProperties.ExecType.CP);
+               runMSVMTest(false, true, eps, 1.0, max_iter, 
LopProperties.ExecType.CP);
        }
-       private void runMSVMTest(boolean sparse, boolean  intercept, int 
classes, double eps,
-                                                         double lambda, int 
run, LopProperties.ExecType instType)
-       {
+
+       private void runMSVMTest(boolean sparse, boolean intercept, double eps, 
double lambda, int run,
+               LopProperties.ExecType instType) {
                Types.ExecMode platformOld = setExecMode(instType);
 
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
 
-               try
-               {
+               try {
                        loadTestConfiguration(getTestConfiguration(TEST_NAME));
 
                        double sparsity = sparse ? spSparse : spDense;
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{ "-explain", "-stats",
-                               "-nvargs", "X=" + input("X"), "Y=" + 
input("Y"), "model=" + output("model"),
-                               "inc=" + 
String.valueOf(intercept).toUpperCase(), "num_classes=" + classes, "eps=" + 
eps, "lam=" + lambda, "max=" + run};
+                       programArgs = new String[] {"-nvargs", "X=" + 
input("X"), "Y=" + input("Y"), "model=" + output("model"),
+                               "inc=" + 
String.valueOf(intercept).toUpperCase(), "eps=" + eps, "lam=" + lambda, "max=" 
+ run};
 
                        fullRScriptName = HOME + TEST_NAME + ".R";
-                       rCmd = getRCmd(inputDir(), Boolean.toString(intercept), 
Integer.toString(classes),  Double.toString(eps),
-                               Double.toString(lambda), Integer.toString(run), 
expectedDir());
+                       rCmd = getRCmd(inputDir(),
+                               Boolean.toString(intercept),
+                               Double.toString(eps),
+                               Double.toString(lambda),
+                               Integer.toString(run),
+                               expectedDir());
 
                        double[][] X = getRandomMatrix(rows, colsX, 0, 1, 
sparsity, -1);
-                       double[][] Y = getRandomMatrix(rows, 1, 0, num_classes, 
1, -1);
+                       double[][] Y = getRandomMatrix(rows, 1, 0, 10, 1, -1);
                        Y = TestUtils.round(Y);
 
                        writeInputMatrixWithMTD("X", X, true);
@@ -118,7 +124,7 @@ public class BuiltinMulticlassSVMTest extends 
AutomatedTestBase
                        runRScript(true);
 
                        HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("model");
-                       HashMap<MatrixValue.CellIndex, Double> rfile  = 
readRMatrixFromFS("model");
+                       HashMap<MatrixValue.CellIndex, Double> rfile = 
readRMatrixFromFS("model");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                }
                finally {
diff --git a/src/test/scripts/functions/builtin/l2svm.dml 
b/src/test/scripts/functions/builtin/confusionMatrix.dml
similarity index 87%
copy from src/test/scripts/functions/builtin/l2svm.dml
copy to src/test/scripts/functions/builtin/confusionMatrix.dml
index fb0a786..cf9a48f 100644
--- a/src/test/scripts/functions/builtin/l2svm.dml
+++ b/src/test/scripts/functions/builtin/confusionMatrix.dml
@@ -19,7 +19,9 @@
 #
 #-------------------------------------------------------------
 
-X = read($X)
 Y = read($Y)
-model= l2svm(X=X,  Y=Y, intercept = $inc, epsilon = $eps, lambda = $lam, 
maxiterations = $max )
-write(model, $model)
+P = read($P)
+
+[confusionCount, confusionAVG] = confusionMatrix(Y=Y, P=P)
+
+write(confusionAVG, $out_file)
diff --git a/src/test/scripts/functions/builtin/l2svm.dml 
b/src/test/scripts/functions/builtin/l2svm.dml
index fb0a786..9b9502d 100644
--- a/src/test/scripts/functions/builtin/l2svm.dml
+++ b/src/test/scripts/functions/builtin/l2svm.dml
@@ -21,5 +21,5 @@
 
 X = read($X)
 Y = read($Y)
-model= l2svm(X=X,  Y=Y, intercept = $inc, epsilon = $eps, lambda = $lam, 
maxiterations = $max )
+model= l2svm(X=X,  Y=Y, intercept = $inc, epsilon = $eps, lambda = $lam, 
maxIterations = $max )
 write(model, $model)
diff --git a/src/test/scripts/functions/builtin/multisvm.R 
b/src/test/scripts/functions/builtin/multisvm.R
index 1258596..59f5e5f 100644
--- a/src/test/scripts/functions/builtin/multisvm.R
+++ b/src/test/scripts/functions/builtin/multisvm.R
@@ -31,10 +31,10 @@ if(check_X == 0){
 }else{
        Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
        intercept = as.logical(args[2])
-       num_classes = as.integer(args[3])
-       epsilon = as.double(args[4])
-       lambda = as.double(args[5])
-       max_iterations = as.integer(args[6])
+       num_classes = max(Y)
+       epsilon = as.double(args[3])
+       lambda = as.double(args[4])
+       max_iterations = as.integer(args[5])
  
        num_samples = nrow(X)
        num_features = ncol(X)
@@ -115,5 +115,5 @@ if(check_X == 0){
        }
        #print("R model "); print(w)
        
-       writeMM(as(w, "CsparseMatrix"), paste(args[7], "model", sep=""))
+       writeMM(as(w, "CsparseMatrix"), paste(args[6], "model", sep=""))
 }
diff --git a/src/test/scripts/functions/builtin/multisvm.dml 
b/src/test/scripts/functions/builtin/multisvm.dml
index 7a76df65..b95b56f 100644
--- a/src/test/scripts/functions/builtin/multisvm.dml
+++ b/src/test/scripts/functions/builtin/multisvm.dml
@@ -21,6 +21,6 @@
 
 X = read($X)
 Y = read($Y)
-model = msvm(X=X,  Y=Y,  intercept = $inc, num_classes= $num_classes,
-  epsilon = $eps, lambda = $lam, max_iterations = $max )
+model = msvm(X=X,  Y=Y,  intercept = $inc,
+  epsilon = $eps, lambda = $lam, maxIterations = $max )
 write(model, $model)
diff --git a/src/test/scripts/functions/builtin/l2svm.dml 
b/src/test/scripts/functions/builtin/multisvmPredict.dml
similarity index 87%
copy from src/test/scripts/functions/builtin/l2svm.dml
copy to src/test/scripts/functions/builtin/multisvmPredict.dml
index fb0a786..6bfcc81 100644
--- a/src/test/scripts/functions/builtin/l2svm.dml
+++ b/src/test/scripts/functions/builtin/multisvmPredict.dml
@@ -20,6 +20,7 @@
 #-------------------------------------------------------------
 
 X = read($X)
-Y = read($Y)
-model= l2svm(X=X,  Y=Y, intercept = $inc, epsilon = $eps, lambda = $lam, 
maxiterations = $max )
-write(model, $model)
+W = read($W)
+[YRaw, Y] = msvmPredict(X=X,  W=W)
+write(YRaw, $YRaw)
+write(Y, $Y)
diff --git a/src/test/scripts/functions/federated/FederatedL2SVMTest.dml 
b/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
index 98ccbda..2ee4614 100644
--- a/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
@@ -22,5 +22,5 @@
 X = federated(addresses=list($1, $2),
     ranges=list(list(0, 0), list($5, $4), list($5, 0), list($3, $4)))
 Y = read($6)
-model = l2svm(X=X,  Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, 
maxiterations = 100)
+model = l2svm(X=X,  Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, 
maxIterations = 100)
 write(model, $7)
diff --git 
a/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml 
b/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
index c3ac1ef..0b028d8 100644
--- a/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
@@ -21,5 +21,5 @@
 
 X = rbind(read($1), read($2))
 Y = read($3)
-model = l2svm(X=X,  Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, 
maxiterations = 100)
+model = l2svm(X=X,  Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, 
maxIterations = 100)
 write(model, $4)
