This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 8e4f7d8 [SYSTEMDS-395] Cleanup SVM scripts, new confusionMatrix,
msvmPredict
8e4f7d8 is described below
commit 8e4f7d82e0df9ce5c0634d7516de10fb262603ed
Author: Sebastian <[email protected]>
AuthorDate: Sun May 24 00:02:36 2020 +0200
[SYSTEMDS-395] Cleanup SVM scripts, new confusionMatrix, msvmPredict
- ConfusionMatrix
- msvmPredict
Make confusion matrices, based on predictions and labels.
It returns two matrices:
- A count matrix, containing counts of each case in the matrix.
- An avg matrix, returning the accuracy of each class, and thereby
the percentage distribution across labels, aka the percentage
confusion.
msvmPredict applies the trained msvm model and returns
- Y_hat, the output from the model, that is the raw output from the
model
- Y, the row-wise index of the max of the raw output, i.e. the
predicted class labels.
Furthermore some consistency changes in L2SVM and MSVM.
Closes #910.
---
docs/Tasks.txt | 1 +
scripts/algorithms/l2-svm.dml | 114 ++----------
scripts/builtin/confusionMatrix.dml | 62 +++++++
scripts/builtin/kmeans.dml | 22 +--
scripts/builtin/l2svm.dml | 97 +++++-----
scripts/builtin/msvm.dml | 42 ++---
scripts/builtin/msvmPredict.dml | 53 ++++++
scripts/builtin/multiLogRegPredict.dml | 17 +-
.../java/org/apache/sysds/common/Builtins.java | 2 +
.../builtin/BuiltinConfusionMatrixTest.java | 195 +++++++++++++++++++++
.../builtin/BuiltinMulticlassSVMPredictTest.java | 186 ++++++++++++++++++++
.../builtin/BuiltinMulticlassSVMTest.java | 54 +++---
.../builtin/{l2svm.dml => confusionMatrix.dml} | 8 +-
src/test/scripts/functions/builtin/l2svm.dml | 2 +-
src/test/scripts/functions/builtin/multisvm.R | 10 +-
src/test/scripts/functions/builtin/multisvm.dml | 4 +-
.../builtin/{l2svm.dml => multisvmPredict.dml} | 7 +-
.../functions/federated/FederatedL2SVMTest.dml | 2 +-
.../federated/FederatedL2SVMTestReference.dml | 2 +-
19 files changed, 653 insertions(+), 227 deletions(-)
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 1196566..3c9782f 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -304,6 +304,7 @@ SYSTEMDS-390 New Builtin Functions IV
* 392 Builtin function for missing value imputation via FDs OK
* 393 Builtin to find Connected Components of a graph OK
* 394 Builtin for one-hot encoding of matrix (not frame), see table OK
+ * 395 SVM rework and utils (confusionMatrix, msvmPredict) OK
Others:
* Break append instruction to cbind and rbind
diff --git a/scripts/algorithms/l2-svm.dml b/scripts/algorithms/l2-svm.dml
index 4cbcdb5..04d6524 100644
--- a/scripts/algorithms/l2-svm.dml
+++ b/scripts/algorithms/l2-svm.dml
@@ -21,10 +21,6 @@
# Implements binary-class SVM with squared slack variables
#
-# Example Usage:
-# Assume L2SVM_HOME is set to the home of the dml script
-# Assume input and output directories are on hdfs as INPUT_DIR and OUTPUT_DIR
-# Assume epsilon = 0.001, lambda = 1, maxiterations = 100
#
# INPUT PARAMETERS:
#
---------------------------------------------------------------------------------------------
@@ -40,111 +36,31 @@
# maxiter Int 100 Maximum number of conjugate gradient iterations
# model String --- Location to write model
# fmt String "text" The output format of the output, such as
"text" or "csv"
-# Log String --- [OPTIONAL] Location to write the log file
#
---------------------------------------------------------------------------------------------
-# hadoop jar SystemDS.jar -f $L2SVM_HOME/l2-svm.dml -nvargs X=$INPUT_DIR/X
Y=$INPUT_DIR/Y \
-# icpt=0 tol=0.001 reg=1 maxiter=100 model=$OUPUT_DIR/w Log=$OUTPUT_DIR/Log
fmt="text"
-#
+# Example Execution:
+# systemds -f $SYSTEMDS_ROOT/scripts/algorithms/l2-svm.dml \
+# -nvargs X=$INPUT_DIR/X Y=$INPUT_DIR/Y \
+# icpt=FALSE tol=0.001 reg=1 maxiter=100 \
+#     model=$OUTPUT_DIR/w fmt="text"
+
# Note about inputs:
# Assumes that labels (entries in Y) are set to either -1 or +1 or
non-negative integers
fmt = ifdef($fmt, "text")
-intercept = ifdef($icpt, 0)
+intercept = ifdef($icpt, FALSE)
epsilon = ifdef($tol, 0.001)
lambda = ifdef($reg, 1.0)
-maxiterations = ifdef($maxiter, 100)
+maxIterations = ifdef($maxiter, 100)
+verbose = ifdef($verbose, FALSE)
X = read($X)
Y = read($Y)
-#check input parameter assertions
-if(nrow(X) < 2)
- stop("Stopping due to invalid inputs: Not possible to learn a binary class
classifier without at least 2 rows")
-if(intercept != 0 & intercept != 1)
- stop("Stopping due to invalid argument: Currently supported intercept
options are 0 and 1")
-if(epsilon < 0)
- stop("Stopping due to invalid argument: Tolerance (tol) must be
non-negative")
-if(lambda < 0)
- stop("Stopping due to invalid argument: Regularization constant (reg) must
be non-negative")
-if(maxiterations < 1)
- stop("Stopping due to invalid argument: Maximum iterations should be a
positive integer")
-
-#check input lables and transform into -1/1
-check_min = min(Y)
-check_max = max(Y)
-num_min = sum(Y == check_min)
-num_max = sum(Y == check_max)
-if(check_min == check_max)
- stop("Stopping due to invalid inputs: Y seems to contain exactly one label")
-if(num_min + num_max != nrow(Y))
- stop("Stopping due to invalid inputs: Y seems to contain more than 2 labels")
-if(check_min != -1 | check_max != 1)
- Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max -
check_min)
-
-positive_label = check_max
-negative_label = check_min
-num_samples = nrow(X)
-dimensions = ncol(X)
-num_rows_in_w = dimensions
-
-if (intercept == 1) {
- ones = matrix(1, rows=num_samples, cols=1)
- X = cbind(X, ones);
- num_rows_in_w += 1
-}
-
-w = matrix(0, num_rows_in_w, 1)
-Xw = matrix(0, rows=nrow(X), cols=1)
-g_old = t(X) %*% Y
-s = g_old
-
-debug_str = "# Iter, Obj"
-iter = 0
-continue = TRUE
-
-while(continue & iter < maxiterations) {
- # minimizing primal obj along direction s
- step_sz = 0
- Xd = X %*% s
- wd = lambda * sum(w * s)
- dd = lambda * sum(s * s)
-
- continue1 = TRUE
- while(continue1) {
- tmp_Xw = Xw + step_sz*Xd
- out = 1 - Y * tmp_Xw
- sv = out > 0
- out = out * sv
- g = wd + step_sz*dd - sum(out * Y * Xd)
- h = dd + sum(Xd * sv * Xd)
- step_sz = step_sz - g/h
- continue1 = (g*g/h >= 0.0000000001);
- }
-
- #update weights
- w += step_sz * s
- Xw += step_sz * Xd
-
- out = 1 - Y * Xw
- sv = out > 0
- out = sv * out
- obj = 0.5 * sum(out * out) + lambda/2 * sum(w * w)
- g_new = t(X) %*% (out * Y) - lambda * w
-
- print("ITER " + iter + ": OBJ=" + obj)
- debug_str = append(debug_str, iter + "," + obj)
-
- tmp = sum(s * g_old)
-
- #non-linear CG step
- be = sum(g_new * g_new)/sum(g_old * g_old)
- s = be * s + g_new
- g_old = g_new
-
- continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
- iter = iter + 1
-}
+w = l2svm(X=X, Y=Y, intercept=intercept,
+  epsilon=epsilon, lambda=lambda,
+ maxIterations=maxIterations,
+ verbose=verbose)
extra_model_params = matrix(0, 4, 1)
extra_model_params[1,1] = positive_label
@@ -154,7 +70,3 @@ extra_model_params[4,1] = dimensions
w = rbind(w, extra_model_params)
write(w, $model, format=fmt)
-
-logFile = $Log
-if(logFile != " ")
- write(debug_str, logFile)
diff --git a/scripts/builtin/confusionMatrix.dml
b/scripts/builtin/confusionMatrix.dml
new file mode 100644
index 0000000..0a62182
--- /dev/null
+++ b/scripts/builtin/confusionMatrix.dml
@@ -0,0 +1,62 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# INPUT PARAMETERS:
+#
---------------------------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+#
---------------------------------------------------------------------------------------------
+# P Double --- vector of Predictions
+# Y Double --- vector of Golden standard One Hot Encoded
+#
---------------------------------------------------------------------------------------------
+# OUTPUT:
+#
---------------------------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+#
---------------------------------------------------------------------------------------------
+# ConfusionSum Double --- The Confusion Matrix Sums of
classifications
+# ConfusionAvg Double --- The Confusion Matrix averages of each
true class
+
+# Output is like:
+# True Labels
+# 1 2
+# 1 TP | FP
+# Predictions ----+----
+# 2 FN | TN
+#
+# TP = True Positives
+# FP = False Positives
+# FN = False Negatives
+# TN = True Negatives
+
+m_confusionMatrix = function(Matrix[Double] P, Matrix[Double] Y)
+ return(Matrix[Double] confusionSum, Matrix[Double] confusionAvg)
+{
+ if(ncol(P) > 1 | ncol(Y) > 1)
+ stop("CONFUSION MATRIX: Invalid input number of cols should be 1 in both P
["+ncol(P)+"] and Y ["+ncol(Y)+"]")
+ if(nrow(P) != nrow(Y))
+ stop("CONFUSION MATRIX: The number of rows have to be equal in both P
["+nrow(P)+"] and Y ["+nrow(Y)+"]")
+ if(min(P) < 1 | min(Y) < 1)
+    stop("CONFUSION MATRIX: All Values in P and Y should be above or equal to
1, min(P):" + min(P) + " min(Y):" + min(Y) )
+
+ dim = max(max(Y),max(P))
+ confusionSum = table(P, Y, dim, dim)
+  # max to avoid division by 0, in case a column contains no entries.
+ confusionAvg = confusionSum / max(1,colSums(confusionSum))
+}
diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index 96591c6..75e5bc7 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -23,21 +23,23 @@
#
# INPUT PARAMETERS:
# ----------------------------------------------------------------------------
-# NAME TYPE DEFAULT MEANING
+# NAME TYPE DEFAULT MEANING
# ----------------------------------------------------------------------------
-# X String --- Location to read matrix X
with the input data records
-# k Int --- Number of centroids
-# runs Int 10 Number of runs (with
different initial centroids)
-# max_iter Int 1000 Maximum number of
iterations per run
-# eps Double 0.000001 Tolerance (epsilon) for
WCSS change ratio
-# is_verbose Boolean FALSE do not print per-iteration
stats
-# avg_sample_size_per_centroid Int 50 Average number of records
per centroid in data samples
+# X Double --- The input Matrix to do
KMeans on.
+# k Int --- Number of centroids
+# runs Int 10 Number of runs (with
different initial centroids)
+# max_iter Int 1000 Maximum number of
iterations per run
+# eps Double 0.000001 Tolerance (epsilon) for
WCSS change ratio
+# is_verbose Boolean FALSE do not print
per-iteration stats
+# avg_sample_size_per_centroid Int 50 Average number of
records per centroid in data samples
#
#
# RETURN VALUES
# ----------------------------------------------------------------------------
-# Y String "Y.mtx" Location to store the mapping of records to centroids
-# C String "C.mtx" Location to store the output matrix with the centroids
+# NAME TYPE DEFAULT MEANING
+# ----------------------------------------------------------------------------
+# Y String "Y.mtx" The mapping of records to centroids
+# C String "C.mtx" The output matrix with the centroids
# ----------------------------------------------------------------------------
diff --git a/scripts/builtin/l2svm.dml b/scripts/builtin/l2svm.dml
index fac0be9..3e251ae 100644
--- a/scripts/builtin/l2svm.dml
+++ b/scripts/builtin/l2svm.dml
@@ -27,12 +27,15 @@
# NAME TYPE DEFAULT MEANING
#
---------------------------------------------------------------------------------------------
# X Double --- matrix X of feature vectors
-# Y Double --- matrix Y of class labels
+# Y Double --- matrix Y of class labels have to be a
single column
# intercept Boolean False No Intercept ( If set to TRUE then a
constant bias column is added to X)
# epsilon Double 0.001 Procedure terminates early if the
reduction in objective function
# value is less than epsilon (tolerance) times the
initial objective function value.
# lambda Double 1.0 Regularization parameter (lambda) for L2
regularization
# maxiterations Int 100 Maximum number of conjugate gradient
iterations
+# verbose Boolean False Set to true if one wants print
statements updating on loss.
+# column_id Int -1 The column ID used if one wants to add an
ID to the print statement,
+# Specifically useful when L2SVM is used in
MSVM.
#
---------------------------------------------------------------------------------------------
@@ -44,65 +47,70 @@
m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept =
FALSE,
-Double epsilon = 0.001, Double lambda = 1, Integer maxiterations = 100,
Boolean verbose = FALSE)
- return(Matrix[Double] model)
+ Double epsilon = 0.001, Double lambda = 1, Integer maxIterations = 100,
+ Boolean verbose = FALSE, Integer columnId = -1)
+ return(Matrix[Double] model)
{
-
#check input parameter assertions
if(nrow(X) < 2)
- stop("Stopping due to invalid inputs: Not possible to learn a binary class
classifier without at least 2 rows")
+ stop("L2SVM: Stopping due to invalid inputs: Not possible to learn a
binary class classifier without at least 2 rows")
if(epsilon < 0)
- stop("Stopping due to invalid argument: Tolerance (tol) must be
non-negative")
+ stop("L2SVM: Stopping due to invalid argument: Tolerance (tol) must be
non-negative")
if(lambda < 0)
- stop("Stopping due to invalid argument: Regularization constant (reg) must
be non-negative")
- if(maxiterations < 1)
- stop("Stopping due to invalid argument: Maximum iterations should be a
positive integer")
-
+ stop("L2SVM: Stopping due to invalid argument: Regularization constant
(reg) must be non-negative")
+ if(maxIterations < 1)
+ stop("L2SVM: Stopping due to invalid argument: Maximum iterations should
be a positive integer")
+ if(ncol(Y) < 1)
+ stop("L2SVM: Stopping due to invalid multiple label columns, maybe use
MSVM instead?")
+
#check input lables and transform into -1/1
check_min = min(Y)
check_max = max(Y)
-
+
num_min = sum(Y == check_min)
num_max = sum(Y == check_max)
-
-
- if(num_min + num_max != nrow(Y)) print("please check Y, it should contain
only 2 labels")
- else{
- if(check_min != -1 | check_max != +1)
- Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max -
check_min)
- }
-
- if(verbose) print('running L2-SVM ');
-
+
+ # TODO make this a stop condition for l2svm instead of just printing.
+ if(num_min + num_max != nrow(Y))
+ print("L2SVM: WARNING invalid number of labels in Y")
+
+ # Scale inputs to -1 for negative, and 1 for positive classification
+ if(check_min != -1 | check_max != +1)
+ Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max -
check_min)
+
+ # If column_id is -1 then we assume that the fundamental algorithm is MSVM,
+ # Therefore don't print message.
+ if(verbose & columnId == -1)
+ print('Running L2-SVM ')
+
num_samples = nrow(X)
- dimensions = ncol(X)
+ num_classes = ncol(Y)
+ # Add Bias
+ num_rows_in_w = ncol(X)
if (intercept) {
ones = matrix(1, rows=num_samples, cols=1)
X = cbind(X, ones);
+ num_rows_in_w += 1
}
- num_rows_in_w = dimensions
- if(intercept){
- num_rows_in_w = num_rows_in_w + 1
- }
w = matrix(0, rows=num_rows_in_w, cols=1)
-
+
g_old = t(X) %*% Y
s = g_old
-
+
Xw = matrix(0, rows=nrow(X), cols=1)
-
+
iter = 0
- continue = 1
- while(continue == 1 & iter < maxiterations) {
+ continue = TRUE
+ while(continue & iter < maxIterations) {
# minimizing primal obj along direction s
step_sz = 0
Xd = X %*% s
wd = lambda * sum(w * s)
dd = lambda * sum(s * s)
- continue1 = 1
- while(continue1 == 1){
+ continue1 = TRUE
+ while(continue1){
tmp_Xw = Xw + step_sz*Xd
out = 1 - Y * (tmp_Xw)
sv = (out > 0)
@@ -112,32 +120,31 @@ Double epsilon = 0.001, Double lambda = 1, Integer
maxiterations = 100, Boolean
step_sz = step_sz - g/h
continue1 = (g*g/h >= epsilon)
}
-
+
#update weights
w = w + step_sz*s
Xw = Xw + step_sz*Xd
-
+
out = 1 - Y * Xw
sv = (out > 0)
out = sv * out
obj = 0.5 * sum(out * out) + lambda/2 * sum(w * w)
g_new = t(X) %*% (out * Y) - lambda * w
-
-
- if(verbose) print("Iter, Obj "+ iter + ", "+obj)
-
- tmp = sum(s * g_old)
- if(step_sz*tmp < epsilon*obj){
- continue = 0
+
+ if(verbose) {
+ colstr = ifelse(columnId!=-1, ", Col:"+columnId + " ,", " ,")
+ print("Iter:" + toString(iter) + colstr + " Obj:" + obj)
}
-
+
+ tmp = sum(s * g_old)
+ continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
+
#non-linear CG step
be = sum(g_new * g_new)/sum(g_old * g_old)
s = be * s + g_new
g_old = g_new
-
+
iter = iter + 1
}
-
model = w
}
diff --git a/scripts/builtin/msvm.dml b/scripts/builtin/msvm.dml
index 86046c6..63dfa5e 100644
--- a/scripts/builtin/msvm.dml
+++ b/scripts/builtin/msvm.dml
@@ -33,7 +33,8 @@
# epsilon Double 0.001 Procedure terminates early if the
reduction in objective function
# value is less than epsilon (tolerance)
times the initial objective function value.
# lambda Double 1.0 Regularization parameter (lambda) for L2
regularization
-# maxiterations Int 100 Maximum number of conjugate gradient
iterations
+# maxIterations Int 100 Maximum number of conjugate gradient
iterations
+# verbose Boolean False Set to true to print while training.
#
---------------------------------------------------------------------------------------------
#Output(s)
@@ -42,34 +43,33 @@
#
---------------------------------------------------------------------------------------------
# model Double --- model matrix
-m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept =
FALSE, Integer num_classes =10,
- Double epsilon = 0.001, Double lambda = 1.0, Integer
max_iterations = 100, Boolean verbose = FALSE)
+m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept =
FALSE,
+ Double epsilon = 0.001, Double lambda = 1.0, Integer maxIterations = 100,
Boolean verbose = FALSE)
return(Matrix[Double] model)
{
- if(verbose)
- print("Built-in Multiclass-SVM started")
+ if(min(Y) < 0)
+ stop("MSVM: Invalid Y input, containing negative values")
- num_samples = nrow(X)
- num_features = ncol(X)
+ if(verbose)
+ print("Running Multiclass-SVM")
- num_rows_in_w = num_features
+ num_rows_in_w = ncol(X)
if(intercept) {
num_rows_in_w = num_rows_in_w + 1
}
-
- w = matrix(0, rows=num_rows_in_w, cols=num_classes)
- parfor(iter_class in 1:num_classes) {
- Y_local = 2 * (Y == iter_class) - 1
- if(verbose) {
- print("iter class: " + iter_class)
- print("y local: " + toString(Y_local))
- }
- w[,iter_class] = l2svm(X=X, Y=Y_local, intercept=intercept,
- epsilon=epsilon, lambda=lambda, maxiterations=max_iterations)
- }
+ if(ncol(Y) > 1)
+ Y = rowMaxs(Y * t(seq(1,ncol(Y))))
+ # Assuming number of classes to be max contained in Y
+ w = matrix(0, rows=num_rows_in_w, cols=max(Y))
+
+ parfor(class in 1:max(Y)) {
+ Y_local = 2 * (Y == class) - 1
+ w[,class] = l2svm(X=X, Y=Y_local, intercept=intercept,
+ epsilon=epsilon, lambda=lambda, maxIterations=maxIterations,
+ verbose= verbose, columnId=class)
+ }
+
model = w
- if (verbose)
- print("model["+iter_class+"]: " + toString(model))
}
diff --git a/scripts/builtin/msvmPredict.dml b/scripts/builtin/msvmPredict.dml
new file mode 100644
index 0000000..2b4fa42
--- /dev/null
+++ b/scripts/builtin/msvmPredict.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# This script helps in applying a trained MSVM
+
+# INPUT PARAMETERS:
+#
---------------------------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+#
---------------------------------------------------------------------------------------------
+# X Double --- matrix X of feature vectors to classify
+# W Double --- matrix of the trained variables
+#
---------------------------------------------------------------------------------------------
+# OUTPUT:
+#
---------------------------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+#
---------------------------------------------------------------------------------------------
+# Y^ Double --- Classification Labels Raw, meaning not
modified to clean
+# Labels of 1's and -1's
+# Y Double --- Classification Labels, the predicted class
per row (row index of the max in YRaw).
+
+m_msvmPredict = function(Matrix[Double] X, Matrix[Double] W)
+ return(Matrix[Double] YRaw, Matrix[Double] Y)
+{
+ if(ncol(X) != nrow(W)){
+ if(ncol(X) + 1 != nrow(W)){
+ stop("MSVM Predict: Invalid shape of W ["+ncol(W)+","+nrow(W)+"] or X
["+ncol(X)+","+nrow(X)+"]")
+ }
+ YRaw = X %*% W[1:ncol(X),] + W[ncol(X)+1,]
+ Y = rowIndexMax(YRaw)
+ }
+ else{
+ YRaw = X %*% W
+ Y = rowIndexMax(YRaw)
+ }
+}
diff --git a/scripts/builtin/multiLogRegPredict.dml
b/scripts/builtin/multiLogRegPredict.dml
index 5ea9a04..213f7dd 100644
--- a/scripts/builtin/multiLogRegPredict.dml
+++ b/scripts/builtin/multiLogRegPredict.dml
@@ -47,17 +47,17 @@ return(Matrix[Double] M, Matrix[Double] predicted_Y, Double
accuracy)
if(min(Y) <= 0)
stop("class labels should be greater than zero")
- num_records = nrow (X);
- num_features = ncol (X);
- beta = B [1 : ncol (X), ];
- intercept = B [nrow(B), ];
+ num_records = nrow(X);
+ num_features = ncol(X);
+ beta = B[1:ncol(X), ];
+ intercept = B[nrow(B), ];
- if (nrow (B) == ncol (X))
+ if (nrow(B) == ncol(X))
intercept = 0.0 * intercept;
else
num_features = num_features + 1;
- ones_rec = matrix (1, rows = num_records, cols = 1);
+ ones_rec = matrix(1, rows = num_records, cols = 1);
linear_terms = X %*% beta + ones_rec %*% intercept;
M = probabilities(linear_terms); # compute the probablitites on unknown data
@@ -67,10 +67,7 @@ return(Matrix[Double] M, Matrix[Double] predicted_Y, Double
accuracy)
accuracy = sum((predicted_Y - Y) == 0) / num_records * 100;
if(verbose)
- {
- acc_str = "Accuracy (%): " + accuracy
- print(acc_str)
- }
+ print("Accuracy (%): " + accuracy);
}
probabilities = function (Matrix[double] linear_terms)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index b09fc63..7345077 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -81,6 +81,7 @@ public enum Builtins {
CUMPROD("cumprod", false),
CUMSUM("cumsum", false),
CUMSUMPROD("cumsumprod", false),
+ CONFUSIONMATRIX("confusionMatrix", true),
DETECTSCHEMA("detectSchema", false),
DIAG("diag", false),
DISCOVER_FD("discoverFD", true),
@@ -127,6 +128,7 @@ public enum Builtins {
MEDIAN("median", false),
MOMENT("moment", "centralMoment", false),
MSVM("msvm", true),
+ MSVMPREDICT("msvmPredict", true),
MULTILOGREG("multiLogReg", true),
MULTILOGREGPREDICT("multiLogRegPredict", true),
NCOL("ncol", false),
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinConfusionMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinConfusionMatrixTest.java
new file mode 100644
index 0000000..bd5ccca
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinConfusionMatrixTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import static org.junit.Assert.fail;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class BuiltinConfusionMatrixTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "confusionMatrix";
+ private final static String TEST_DIR = "functions/builtin/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
BuiltinConfusionMatrixTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B",}));
+ }
+
+ public double eps = 0.00001;
+
+ @Test
+ public void test_01() {
+ double[][] y;
+ double[][] p;
+ HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+
+ // Classification is 100% accurate in all classes if Y == P.
+ y = TestUtils.round(getRandomMatrix(1000, 1, 1, 2, 1.0, 7));
+ p = y;
+
+ res.put(new CellIndex(1, 1), 1.0);
+ res.put(new CellIndex(2, 2), 1.0);
+
+ for(LopProperties.ExecType ex : new ExecType[]
{LopProperties.ExecType.CP, LopProperties.ExecType.SPARK}) {
+ runConfusionMatrixTest(y, p, res, ex);
+ }
+ }
+
+ @Test
+ public void test_02() {
+ HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+ res.put(new CellIndex(2, 2), 1.0);
+ runConfusionMatrixTest(new double[][] {{2}}, new double[][]
{{2}}, res, LopProperties.ExecType.CP);
+ }
+
+ @Test
+ public void test_03() {
+ HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+ res.put(new CellIndex(2, 1), 1.0);
+ runConfusionMatrixTest(new double[][] {{1}}, new double[][]
{{2}}, res, LopProperties.ExecType.CP);
+ }
+
+ @Test
+ public void test_04() {
+ HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+ res.put(new CellIndex(6, 1), 1.0);
+ runConfusionMatrixTest(new double[][] {{1}}, new double[][]
{{6}}, res, LopProperties.ExecType.CP);
+ }
+
+ @Test
+ public void test_05() {
+ HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+ res.put(new CellIndex(1, 9), 1.0);
+ runConfusionMatrixTest(new double[][] {{9}}, new double[][]
{{1}}, res, LopProperties.ExecType.CP);
+ }
+
+ @Test
+ public void test_06() {
+ HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+ double[][] y = new double[][] {{1}, {1}, {1}, {1}};
+ double[][] p = new double[][] {{1}, {2}, {3}, {4}};
+ res.put(new CellIndex(1, 1), 0.25);
+ res.put(new CellIndex(2, 1), 0.25);
+ res.put(new CellIndex(3, 1), 0.25);
+ res.put(new CellIndex(4, 1), 0.25);
+ runConfusionMatrixTest(y, p, res, LopProperties.ExecType.CP);
+ }
+
+ @Test
+ public void test_07() {
+ HashMap<MatrixValue.CellIndex, Double> res = new HashMap<>();
+ double[][] y = new double[][] {{1}, {2}, {3}, {4}};
+ double[][] p = new double[][] {{1}, {1}, {1}, {1}};
+ res.put(new CellIndex(1, 1), 1.0);
+ res.put(new CellIndex(1, 2), 1.0);
+ res.put(new CellIndex(1, 3), 1.0);
+ res.put(new CellIndex(1, 4), 1.0);
+ runConfusionMatrixTest(y, p, res, LopProperties.ExecType.CP);
+ }
+
+ private void runConfusionMatrixTest(double[][] y, double[][] p,
HashMap<MatrixValue.CellIndex, Double> res,
+ ExecType instType) {
+ ExecMode platformOld = setExecMode(instType);
+
+ try {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs", "P=" +
input("P"), "Y=" + input("Y"), "out_file=" + output("B")};
+ writeInputMatrixWithMTD("P", p, false);
+ writeInputMatrixWithMTD("Y", y, false);
+ runTest(true, false, null, -1);
+
+ HashMap<MatrixValue.CellIndex, Double> dmlResult =
readDMLMatrixFromHDFS("B");
+ TestUtils.compareMatrices(dmlResult, res, eps,
"DML_Result", "Expected");
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+
+ // TODO Future, does it make sense to save an empty matrix, since we
have ways to make an empty matrix?
+ // @Test
+ // public void test_invalid_01(){
+ // // Test if the script fails with input containing no values.
+ // runConfusionMatrixExceptionTest(new double[][]{}, new double[][]{});
+ // }
+
+ @Test
+ public void test_invalid_02() {
+ // Test if the script fails with input contain multiple columns
+ runConfusionMatrixExceptionTest(new double[][] {{1, 2}}, new
double[][] {{1, 2}});
+ runConfusionMatrixExceptionTest(new double[][] {{1}}, new
double[][] {{1, 2}});
+ runConfusionMatrixExceptionTest(new double[][] {{1, 2}}, new
double[][] {{1}});
+ }
+
+ @Test
+ public void test_invalid_03() {
+ // Test if the script fails with input contains different
amount of rows
+ runConfusionMatrixExceptionTest(new double[][] {{1}, {1}}, new
double[][] {{1}});
+ runConfusionMatrixExceptionTest(new double[][] {{1}}, new
double[][] {{1}, {1}});
+ }
+
+ private void runConfusionMatrixExceptionTest(double[][] y, double[][]
p) {
+ ExecMode platformOld = setExecMode(LopProperties.ExecType.CP);
+
+ try {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs", "P=" +
input("P"), "Y=" + input("Y"), "out_file=" + output("B")};
+ writeInputMatrixWithMTD("P", p, false);
+ writeInputMatrixWithMTD("Y", y, false);
+
+ // TODO make stop throw exception instead
+ //
https://issues.apache.org/jira/projects/SYSTEMML/issues/SYSTEMML-2540
+ // runTest(true, true, DMLScriptException.class, -1);
+
+ // Verify that the outputfile is not existing!
+ runTest(true, false, null, -1);
+
+ try {
+ readDMLMatrixFromHDFS("B");
+ fail("File should not have been written");
+ }
+ catch(AssertionError e) {
+ // exception expected
+ }
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMPredictTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMPredictTest.java
new file mode 100644
index 0000000..985771d
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMPredictTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import static org.junit.Assert.fail;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class BuiltinMulticlassSVMPredictTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "multisvmPredict";
+ private final static String TEST_DIR = "functions/builtin/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
BuiltinConfusionMatrixTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"YRaw", "Y"}));
+ }
+
+ public double eps = 0.00001;
+
+ @Test
+ public void test_01() {
+
+ double[][] x = new double[][] {{0.4, 0.0}, {0.0, 0.2}};
+ double[][] w = new double[][] {{1.0, 0.0}, {0.0, 1.0}};
+
+ HashMap<MatrixValue.CellIndex, Double> res_Y = new HashMap<>();
+ res_Y.put(new CellIndex(1, 1), 1.0);
+ res_Y.put(new CellIndex(2, 1), 2.0);
+
+ HashMap<MatrixValue.CellIndex, Double> res_YRaw = new
HashMap<>();
+ res_YRaw.put(new CellIndex(1, 1), 0.4);
+ res_YRaw.put(new CellIndex(2, 2), 0.2);
+
+ for(LopProperties.ExecType ex : new ExecType[]
{LopProperties.ExecType.CP, LopProperties.ExecType.SPARK}) {
+ runMSVMPredict(x, w, res_YRaw, res_Y, ex);
+ }
+ }
+
+ @Test
+ public void test_02() {
+ double[][] x = new double[][] {{0.4, 0.1}};
+ double[][] w = new double[][] {{1.0, 0.5}, {0.2, 1.0}};
+
+ HashMap<MatrixValue.CellIndex, Double> res_Y = new HashMap<>();
+ res_Y.put(new CellIndex(1, 1), 1.0);
+
+ HashMap<MatrixValue.CellIndex, Double> res_YRaw = new
HashMap<>();
+ res_YRaw.put(new CellIndex(1, 1), 0.42);
+ res_YRaw.put(new CellIndex(1, 2), 0.3);
+
+ for(LopProperties.ExecType ex : new ExecType[]
{LopProperties.ExecType.CP, LopProperties.ExecType.SPARK}) {
+ runMSVMPredict(x, w, res_YRaw, res_Y, ex);
+ }
+ }
+
+ @Test
+ public void test_03() {
+		// Add bias column
+ double[][] x = new double[][] {{0.4, 0.1}};
+ double[][] w = new double[][] {{1.0, 0.5}, {0.2, 1.0},
{1.0,0.5}};
+
+ HashMap<MatrixValue.CellIndex, Double> res_Y = new HashMap<>();
+ res_Y.put(new CellIndex(1, 1), 1.0);
+
+ HashMap<MatrixValue.CellIndex, Double> res_YRaw = new
HashMap<>();
+ res_YRaw.put(new CellIndex(1, 1), 1.42);
+ res_YRaw.put(new CellIndex(1, 2), 0.8);
+
+ for(LopProperties.ExecType ex : new ExecType[]
{LopProperties.ExecType.CP, LopProperties.ExecType.SPARK}) {
+ runMSVMPredict(x, w, res_YRaw, res_Y, ex);
+ }
+ }
+
+ private void runMSVMPredict(double[][] x, double[][] w,
HashMap<MatrixValue.CellIndex, Double> YRaw,
+ HashMap<MatrixValue.CellIndex, Double> Y, ExecType instType) {
+ ExecMode platformOld = setExecMode(instType);
+
+ try {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs", "X=" +
input("X"), "W=" + input("W"), "YRaw=" + output("YRaw"),
+ "Y=" + output("Y")};
+ writeInputMatrixWithMTD("X", x, false);
+ writeInputMatrixWithMTD("W", w, false);
+ runTest(true, false, null, -1);
+
+ HashMap<MatrixValue.CellIndex, Double> YRaw_res =
readDMLMatrixFromHDFS("YRaw");
+ HashMap<MatrixValue.CellIndex, Double> Y_res =
readDMLMatrixFromHDFS("Y");
+
+ TestUtils.compareMatrices(YRaw_res, YRaw, eps,
"DML_Result", "Expected");
+ TestUtils.compareMatrices(Y_res, Y, eps, "DML_Result",
"Expected");
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+
+ @Test
+ public void test_invalid_01() {
+		// Test if the script fails when the input contains an incorrect number
of columns
+ double[][] x = new double[][] {{1, 2, 3}};
+ double[][] w = new double[][] {{1, -1}, {-1, 1}};
+ runMSVMPredictionExceptionTest(x, w);
+ }
+
+ @Test
+ public void test_invalid_02() {
+		// Test if the script fails when the input contains an incorrect number
of rows vs. columns
+ double[][] x = new double[][] {{1, 2, 3}};
+ double[][] w = new double[][] {{1, -1, 1, 3 ,3}, {-1, 1, 1, 1
,12}};
+ runMSVMPredictionExceptionTest(x, w);
+ }
+
+ @Test
+ public void test_invalid_03() {
+		// Add one column more than the bias column.
+ double[][] x = new double[][] {{1, 2, 3}};
+ double[][] w = new double[][] {{1.0, 0.5}, {0.2, 1.0},
{1.0,0.5},{1.0,0.5},{1.0,0.5}};
+ runMSVMPredictionExceptionTest(x, w);
+ }
+
+ private void runMSVMPredictionExceptionTest(double[][] x, double[][] w)
{
+ ExecMode platformOld = setExecMode(LopProperties.ExecType.CP);
+
+ try {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs", "X=" +
input("X"), "W=" + input("W"), "YRaw=" + output("YRaw"),
+ "Y=" + output("Y")};
+ writeInputMatrixWithMTD("X", x, false);
+ writeInputMatrixWithMTD("W", w, false);
+
+ // TODO make stop throw exception instead
+ //
https://issues.apache.org/jira/projects/SYSTEMML/issues/SYSTEMML-2540
+ // runTest(true, true, DMLScriptException.class, -1);
+
+			// Verify that the output file was not written!
+ runTest(true, false, null, -1);
+
+ try {
+ readDMLMatrixFromHDFS("YRaw");
+ fail("File should not have been written");
+ }
+ catch(AssertionError e) {
+ // exception expected
+ }
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMTest.java
index b46ba3f..56ceda1 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMulticlassSVMTest.java
@@ -31,8 +31,7 @@ import org.apache.sysds.test.TestUtils;
import java.util.HashMap;
-public class BuiltinMulticlassSVMTest extends AutomatedTestBase
-{
+public class BuiltinMulticlassSVMTest extends AutomatedTestBase {
private final static String TEST_NAME = "multisvm";
private final static String TEST_DIR = "functions/builtin/";
private static final String TEST_CLASS_DIR = TEST_DIR +
BuiltinMulticlassSVMTest.class.getSimpleName() + "/";
@@ -43,72 +42,79 @@ public class BuiltinMulticlassSVMTest extends
AutomatedTestBase
private final static double spSparse = 0.01;
private final static double spDense = 0.7;
private final static int max_iter = 10;
- private final static int num_classes = 10;
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"C"}));
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
}
@Test
public void testMSVMDense() {
- runMSVMTest(false, false, num_classes, eps, 1.0, max_iter,
LopProperties.ExecType.CP);
+ runMSVMTest(false, false, eps, 1.0, max_iter,
LopProperties.ExecType.CP);
}
+
@Test
public void testMSVMSparse() {
- runMSVMTest(true, false, num_classes, eps, 1.0, max_iter,
LopProperties.ExecType.CP);
+ runMSVMTest(true, false, eps, 1.0, max_iter,
LopProperties.ExecType.CP);
}
+
@Test
public void testMSVMInterceptSpark() {
- runMSVMTest(true,true, num_classes, eps, 1.0, max_iter,
LopProperties.ExecType.SPARK);
+ runMSVMTest(true, true, eps, 1.0, max_iter,
LopProperties.ExecType.SPARK);
}
@Test
public void testMSVMSparseLambda2() {
- runMSVMTest(true,true, num_classes, eps,2.0, max_iter,
LopProperties.ExecType.CP);
+ runMSVMTest(true, true, eps, 2.0, max_iter,
LopProperties.ExecType.CP);
}
+
@Test
public void testMSVMSparseLambda100CP() {
- runMSVMTest(true,true, num_classes, 1, 100, max_iter,
LopProperties.ExecType.CP);
+ runMSVMTest(true, true, 1, 100, max_iter,
LopProperties.ExecType.CP);
}
+
@Test
public void testMSVMSparseLambda100Spark() {
- runMSVMTest(true,true, num_classes, 1, 100, max_iter,
LopProperties.ExecType.SPARK);
+ runMSVMTest(true, true, 1, 100, max_iter,
LopProperties.ExecType.SPARK);
}
+
@Test
public void testMSVMIteration() {
- runMSVMTest(true,true, num_classes, 1, 2.0, 100,
LopProperties.ExecType.CP);
+ runMSVMTest(true, true, 1, 2.0, 100, LopProperties.ExecType.CP);
}
+
@Test
public void testMSVMDenseIntercept() {
- runMSVMTest(false,true, num_classes, eps, 1.0, max_iter,
LopProperties.ExecType.CP);
+ runMSVMTest(false, true, eps, 1.0, max_iter,
LopProperties.ExecType.CP);
}
- private void runMSVMTest(boolean sparse, boolean intercept, int
classes, double eps,
- double lambda, int
run, LopProperties.ExecType instType)
- {
+
+ private void runMSVMTest(boolean sparse, boolean intercept, double eps,
double lambda, int run,
+ LopProperties.ExecType instType) {
Types.ExecMode platformOld = setExecMode(instType);
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- try
- {
+ try {
loadTestConfiguration(getTestConfiguration(TEST_NAME));
double sparsity = sparse ? spSparse : spDense;
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[]{ "-explain", "-stats",
- "-nvargs", "X=" + input("X"), "Y=" +
input("Y"), "model=" + output("model"),
- "inc=" +
String.valueOf(intercept).toUpperCase(), "num_classes=" + classes, "eps=" +
eps, "lam=" + lambda, "max=" + run};
+ programArgs = new String[] {"-nvargs", "X=" +
input("X"), "Y=" + input("Y"), "model=" + output("model"),
+ "inc=" +
String.valueOf(intercept).toUpperCase(), "eps=" + eps, "lam=" + lambda, "max="
+ run};
fullRScriptName = HOME + TEST_NAME + ".R";
- rCmd = getRCmd(inputDir(), Boolean.toString(intercept),
Integer.toString(classes), Double.toString(eps),
- Double.toString(lambda), Integer.toString(run),
expectedDir());
+ rCmd = getRCmd(inputDir(),
+ Boolean.toString(intercept),
+ Double.toString(eps),
+ Double.toString(lambda),
+ Integer.toString(run),
+ expectedDir());
double[][] X = getRandomMatrix(rows, colsX, 0, 1,
sparsity, -1);
- double[][] Y = getRandomMatrix(rows, 1, 0, num_classes,
1, -1);
+ double[][] Y = getRandomMatrix(rows, 1, 0, 10, 1, -1);
Y = TestUtils.round(Y);
writeInputMatrixWithMTD("X", X, true);
@@ -118,7 +124,7 @@ public class BuiltinMulticlassSVMTest extends
AutomatedTestBase
runRScript(true);
HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromHDFS("model");
- HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromFS("model");
+ HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromFS("model");
TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
}
finally {
diff --git a/src/test/scripts/functions/builtin/l2svm.dml
b/src/test/scripts/functions/builtin/confusionMatrix.dml
similarity index 87%
copy from src/test/scripts/functions/builtin/l2svm.dml
copy to src/test/scripts/functions/builtin/confusionMatrix.dml
index fb0a786..cf9a48f 100644
--- a/src/test/scripts/functions/builtin/l2svm.dml
+++ b/src/test/scripts/functions/builtin/confusionMatrix.dml
@@ -19,7 +19,9 @@
#
#-------------------------------------------------------------
-X = read($X)
Y = read($Y)
-model= l2svm(X=X, Y=Y, intercept = $inc, epsilon = $eps, lambda = $lam,
maxiterations = $max )
-write(model, $model)
+P = read($P)
+
+[confusionCount, confusionAVG] = confusionMatrix(Y=Y, P=P)
+
+write(confusionAVG, $out_file)
diff --git a/src/test/scripts/functions/builtin/l2svm.dml
b/src/test/scripts/functions/builtin/l2svm.dml
index fb0a786..9b9502d 100644
--- a/src/test/scripts/functions/builtin/l2svm.dml
+++ b/src/test/scripts/functions/builtin/l2svm.dml
@@ -21,5 +21,5 @@
X = read($X)
Y = read($Y)
-model= l2svm(X=X, Y=Y, intercept = $inc, epsilon = $eps, lambda = $lam,
maxiterations = $max )
+model= l2svm(X=X, Y=Y, intercept = $inc, epsilon = $eps, lambda = $lam,
maxIterations = $max )
write(model, $model)
diff --git a/src/test/scripts/functions/builtin/multisvm.R
b/src/test/scripts/functions/builtin/multisvm.R
index 1258596..59f5e5f 100644
--- a/src/test/scripts/functions/builtin/multisvm.R
+++ b/src/test/scripts/functions/builtin/multisvm.R
@@ -31,10 +31,10 @@ if(check_X == 0){
}else{
Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
intercept = as.logical(args[2])
- num_classes = as.integer(args[3])
- epsilon = as.double(args[4])
- lambda = as.double(args[5])
- max_iterations = as.integer(args[6])
+ num_classes = max(Y)
+ epsilon = as.double(args[3])
+ lambda = as.double(args[4])
+ max_iterations = as.integer(args[5])
num_samples = nrow(X)
num_features = ncol(X)
@@ -115,5 +115,5 @@ if(check_X == 0){
}
#print("R model "); print(w)
- writeMM(as(w, "CsparseMatrix"), paste(args[7], "model", sep=""))
+ writeMM(as(w, "CsparseMatrix"), paste(args[6], "model", sep=""))
}
diff --git a/src/test/scripts/functions/builtin/multisvm.dml
b/src/test/scripts/functions/builtin/multisvm.dml
index 7a76df65..b95b56f 100644
--- a/src/test/scripts/functions/builtin/multisvm.dml
+++ b/src/test/scripts/functions/builtin/multisvm.dml
@@ -21,6 +21,6 @@
X = read($X)
Y = read($Y)
-model = msvm(X=X, Y=Y, intercept = $inc, num_classes= $num_classes,
- epsilon = $eps, lambda = $lam, max_iterations = $max )
+model = msvm(X=X, Y=Y, intercept = $inc,
+ epsilon = $eps, lambda = $lam, maxIterations = $max )
write(model, $model)
diff --git a/src/test/scripts/functions/builtin/l2svm.dml
b/src/test/scripts/functions/builtin/multisvmPredict.dml
similarity index 87%
copy from src/test/scripts/functions/builtin/l2svm.dml
copy to src/test/scripts/functions/builtin/multisvmPredict.dml
index fb0a786..6bfcc81 100644
--- a/src/test/scripts/functions/builtin/l2svm.dml
+++ b/src/test/scripts/functions/builtin/multisvmPredict.dml
@@ -20,6 +20,7 @@
#-------------------------------------------------------------
X = read($X)
-Y = read($Y)
-model= l2svm(X=X, Y=Y, intercept = $inc, epsilon = $eps, lambda = $lam,
maxiterations = $max )
-write(model, $model)
+W = read($W)
+[YRaw, Y] = msvmPredict(X=X, W=W)
+write(YRaw, $YRaw)
+write(Y, $Y)
diff --git a/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
b/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
index 98ccbda..2ee4614 100644
--- a/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
@@ -22,5 +22,5 @@
X = federated(addresses=list($1, $2),
ranges=list(list(0, 0), list($5, $4), list($5, 0), list($3, $4)))
Y = read($6)
-model = l2svm(X=X, Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1,
maxiterations = 100)
+model = l2svm(X=X, Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1,
maxIterations = 100)
write(model, $7)
diff --git
a/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
b/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
index c3ac1ef..0b028d8 100644
--- a/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
@@ -21,5 +21,5 @@
X = rbind(read($1), read($2))
Y = read($3)
-model = l2svm(X=X, Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1,
maxiterations = 100)
+model = l2svm(X=X, Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1,
maxIterations = 100)
write(model, $4)