This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new d573a1f  [SYSTEMDS-3265] Fix gridSearch for multi-class classification
d573a1f is described below

commit d573a1f15e94da053b9b2153eee3b15c4c25fcac
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Fri Dec 31 20:11:54 2021 +0100

    [SYSTEMDS-3265] Fix gridSearch for multi-class classification
    
    This patch generalizes the existing grid search second-order builtin
    function to properly handle multi-column models as used during
    multi-class classification. Models are now temporarily reshaped to
    vectors (which is a no-op for dense models) and returned in
    linearized form. The caller can then reshape the returned vector back
    into a matrix, given the number of classes, and use it.
---
 scripts/builtin/gridSearch.dml                     | 13 ++++++-----
 .../builtin/part1/BuiltinGridSearchTest.java       | 26 +++++++++++++++++-----
 .../functions/builtin/GridSearchMLogreg.dml        |  4 +++-
 3 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
index eab7756..a8c0986 100644
--- a/scripts/builtin/gridSearch.dml
+++ b/scripts/builtin/gridSearch.dml
@@ -30,8 +30,8 @@
 # y            Matrix[Double]     ---        Input Matrix of vectors.
 # train        String             ---        Name ft of the train function to call via ft(trainArgs)
 # predict      String             ---        Name fp of the loss function to call via fp((predictArgs,B))
-# numB         Integer            ---        Maximum number of parameters in model B (pass the maximum because the
-#                                            size of B may vary with parameters like icpt
+# numB         Integer            ---        Maximum number of parameters in model B (pass the max because the size
+#                                            may vary with parameters like icpt or multi-class classification)
 # params       List[String]       ---        List of varied hyper-parameter names
 # paramValues  List[Unknown]      ---        List of matrices providing the parameter values as
 #                                            columnvectors for position-aligned hyper-parameters in 'params'
@@ -52,6 +52,7 @@
 # NAME         TYPE                         MEANING
 # ----------------------------------------------------------------------------------------------------------------------
 # B            Matrix[Double]               Matrix[Double]the trained model with minimal loss (by the 'predict' function)
+#                                           Multi-column models are returned as a column-major linearized column vector
 # opt          Matrix[Double]               one-row frame w/ optimal hyperparameters (by 'params' position)
 #-----------------------------------------------------------------------------------------------------------------------
 
@@ -127,10 +128,10 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String
         ltrainArgs['X'] = rbind(tmpX);
         ltrainArgs['y'] = rbind(tmpy);
         lbeta = t(eval(train, ltrainArgs));
-        cvbeta[,1:ncol(lbeta)] = cvbeta[,1:ncol(lbeta)] + lbeta;
+        cvbeta[,1:length(lbeta)] = cvbeta[,1:length(lbeta)] + matrix(lbeta, 1, length(lbeta));
         lpredictArgs[1] = as.matrix(testX);
         lpredictArgs[2] = as.matrix(testy);
-        cvloss += eval(predict, append(lpredictArgs,t(lbeta)));
+        cvloss += eval(predict, append(lpredictArgs, t(lbeta)));
       }
       Rbeta[i,] = cvbeta / cvk; # model averaging
       Rloss[i,] = cvloss / cvk;
@@ -145,8 +146,8 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, String train, String
         ltrainArgs[as.scalar(params[j])] = as.scalar(HP[i,j]);
       # b) core training/scoring and write-back
       lbeta = t(eval(train, ltrainArgs))
-      Rbeta[i,1:ncol(lbeta)] = lbeta;
-      Rloss[i,] = eval(predict, append(predictArgs,t(lbeta)));
+      Rbeta[i,1:length(lbeta)] = matrix(lbeta, 1, length(lbeta));
+      Rloss[i,] = eval(predict, append(predictArgs, t(lbeta)));
     }
   }
 
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
index a8d1310..6cc6411 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGridSearchTest.java
@@ -38,8 +38,8 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
        private final static String TEST_DIR = "functions/builtin/";
        private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinGridSearchTest.class.getSimpleName() + "/";
        
-       private final static int rows = 400;
-       private final static int cols = 20;
+       private final static int _rows = 400;
+       private final static int _cols = 20;
        private boolean _codegen = false;
        
        @Override
@@ -106,7 +106,22 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
                runGridSearch(TEST_NAME4, ExecMode.HYBRID, false);
        }
        
-       private void runGridSearch(String testname, ExecMode et, boolean codegen)
+       @Test
+       public void testGridSearchMLogreg4CP() {
+               runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE, 10, 4, false);
+       }
+       
+       @Test
+       public void testGridSearchMLogreg4Hybrid() {
+               runGridSearch(TEST_NAME2, ExecMode.HYBRID, 10, 4, false);
+       }
+       
+       
+       private void runGridSearch(String testname, ExecMode et, boolean codegen) {
+               runGridSearch(testname, et, _cols, 2, codegen); //binary classification
+       }
+       
+       private void runGridSearch(String testname, ExecMode et, int cols, int nc, boolean codegen)
        {
                ExecMode modeOld = setExecMode(et);
                _codegen = codegen;
@@ -117,8 +132,9 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
        
                        fullDMLScriptName = HOME + testname + ".dml";
                        programArgs = new String[] {"-args", input("X"), input("y"), output("R")};
-                       double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8, 7);
-                       double[][] y = getRandomMatrix(rows, 1, 1, 2, 1, 1);
+                       double max = testname.equals(TEST_NAME2) ? nc : 2;
+                       double[][] X = getRandomMatrix(_rows, cols, 0, 1, 0.8, 7);
+                       double[][] y = getRandomMatrix(_rows, 1, 1, max, 1, 1);
                        writeInputMatrixWithMTD("X", X, true);
                        writeInputMatrixWithMTD("y", y, true);
                        
diff --git a/src/test/scripts/functions/builtin/GridSearchMLogreg.dml b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
index ec2bf9d..ce54d5d 100644
--- a/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
+++ b/src/test/scripts/functions/builtin/GridSearchMLogreg.dml
@@ -26,6 +26,7 @@ accuracy = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return
 
 X = read($1);
 y = round(read($2));
+nc = max(y);
 
 N = 200;
 Xtrain = X[1:N,];
@@ -36,10 +37,11 @@ ytest = y[(N+1):nrow(X),];
 params = list("icpt", "reg", "maxii");
 paramRanges = list(seq(0,2),10^seq(1,-6), 10^seq(1,3));
 trainArgs = list(X=Xtrain, Y=ytrain, icpt=-1, reg=-1, tol=1e-9, maxi=100, maxii=-1, verbose=FALSE);
-[B1,opt] = gridSearch(X=Xtrain, y=ytrain, train="multiLogReg", predict="accuracy", numB=ncol(X)+1,
+[B1,opt] = gridSearch(X=Xtrain, y=ytrain, train="multiLogReg", predict="accuracy", numB=(ncol(X)+1)*(nc-1),
   params=params, paramValues=paramRanges, trainArgs=trainArgs, verbose=TRUE);
 B2 = multiLogReg(X=Xtrain, Y=ytrain, verbose=TRUE);
 
+B1 = matrix(B1, nrow(B1)/(nc-1), (nc-1), FALSE)
 l1 = accuracy(Xtest, ytest, B1);
 l2 = accuracy(Xtest, ytest, B2);
 R = as.scalar(l1 < l2);

Reply via email to