This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit ced2ae8212124b49afb1dff330e2592131e324d0 Author: Matthias Boehm <[email protected]> AuthorDate: Sat Jan 30 17:48:29 2021 +0100 [MINOR] Fix robustness multiLogReg and multiLogRegPredict * Handling of missing values in the feature matrix X of multiLogReg (NaN in double matrices) by replacing NaN with 0, to avoid NaN gradients and intermediates that would ultimately produce NaN models. * Cleanup multiLogReg formatting and constants * Consistent handling of Y modifications (Y<=0) in both multiLogReg and multiLogRegPredict * Register isNA and isNaN as aliases of the is.na and is.nan builtins (isNaN is used by the new missing-value check in multiLogReg) --- scripts/builtin/multiLogReg.dml | 45 +++++++++++++--------- scripts/builtin/multiLogRegPredict.dml | 7 +++- .../java/org/apache/sysds/common/Builtins.java | 4 +- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/scripts/builtin/multiLogReg.dml b/scripts/builtin/multiLogReg.dml index 5afd5e7..1d2d060 100644 --- a/scripts/builtin/multiLogReg.dml +++ b/scripts/builtin/multiLogReg.dml @@ -46,11 +46,10 @@ # betas Double regression betas as output for prediction # ------------------------------------------------------------------------------------------- -m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Integer icpt = 2, Double tol = 0.000001, - Double reg = 1.0, Integer maxi = 100, Integer maxii = 20, Boolean verbose = TRUE) +m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Int icpt = 2, + Double tol=1e-6, Double reg=1.0, Int maxi=100, Int maxii=20, Boolean verbose=TRUE) return(Matrix[Double] betas) { - eta0 = 0.0001; eta1 = 0.25; eta2 = 0.75; @@ -62,6 +61,13 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Integer icpt = 2, D N = nrow (X); D = ncol (X); + # Robustness for datasets with missing values (causing NaN gradients) + numNaNs = sum(isNaN(X)) + if( numNaNs > 0 ) { + print("multiLogReg: matrix X contains "+numNaNs+" missing values, replacing with 0.") + X = replace(target=X, pattern=NaN, replacement=0); + } + # Introduce the intercept, shift and rescale the columns of X if needed if (icpt == 1 | icpt == 
2) { # add the intercept column X = cbind (X, matrix (1, N, 1)); @@ -82,7 +88,7 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Integer icpt = 2, D shift_X = - avg_X_cols * scale_X; shift_X [D, 1] = 0; rowSums_X_sq = (X ^ 2) %*% (scale_X ^ 2) + X %*% (2 * scale_X * shift_X) + sum (shift_X ^ 2); - } + } else { scale_X = matrix (1, D, 1); shift_X = matrix (0, D, 1); @@ -101,9 +107,9 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Integer icpt = 2, D # Convert "Y" into indicator matrix: max_y = max (Y); - if (min (Y) <= 0) { + if (min (Y) <= 0) { # Category labels "0", "-1" etc. are converted into the largest label - Y = Y + (- Y + max_y + 1) * (Y <= 0); + Y = ifelse(Y <= 0, max_y + 1, Y); max_y = max_y + 1; } Y = table (seq (1, N, 1), Y, N, max_y); @@ -192,7 +198,7 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Integer icpt = 2, D } inneriter = inneriter + 1; innerconverge = innerconverge | (inneriter > maxii); - } + } # END TRUST REGION SUB-PROBLEM # compute rho, update B, obtain delta gs = sum (S * Grad); @@ -202,7 +208,7 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Integer icpt = 2, D ssX_B_new = diag (scale_X) %*% B_new; ssX_B_new [D, ] = ssX_B_new [D, ] + t(shift_X) %*% B_new; } - else + else ssX_B_new = B_new; LT = cbind ((X %*% ssX_B_new), matrix (0, N, 1)); @@ -236,10 +242,12 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Integer icpt = 2, D if(verbose) { if (is_trust_boundary_reached == 1) - print ("-- Outer Iteration " + iter + ": Had " + (inneriter - 1) + " CG iterations, trust bound REACHED"); - else print ("-- Outer Iteration " + iter + ": Had " + (inneriter - 1) + " CG iterations"); - print (" -- Obj.Reduction: Actual = " + actred + ", Predicted = " + qk + - " (A/P: " + (round (10000.0 * rho) / 10000.0) + "), Trust Delta = " + delta); + print("-- Outer Iteration " + iter + ": Had " + + (inneriter - 1) + " CG iterations, trust bound REACHED"); + else + print ("-- Outer 
Iteration " + iter + ": Had " + (inneriter - 1) + " CG iterations"); + print (" -- Obj.Reduction: Actual = " + actred + ", Predicted = " + qk + + " (A/P: " + (round (1e4 * rho) / 1e4) + "), Trust Delta = " + delta); } if (is_rho_accepted) { @@ -254,22 +262,21 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Integer icpt = 2, D obj = obj_new; if(verbose) - print (" -- New Objective = " + obj + ", Beta Change Norm = " + snorm + ", Gradient Norm = " + norm_Grad); + print(" -- New Objective = " + obj + ", Beta Change Norm = " + + snorm + ", Gradient Norm = " + norm_Grad); } iter = iter + 1; converge = ((norm_Grad < (tol * norm_Grad_initial)) | (iter > maxi) | - ((is_trust_boundary_reached == 0) & (abs (actred) < (abs (obj) + abs (obj_new)) * 0.00000000000001))); + ((is_trust_boundary_reached == 0) & (abs (actred) < (abs (obj) + abs (obj_new)) * 1e-14))); if (verbose & converge) print ("Termination / Convergence condition satisfied."); } if (icpt == 2) { - B_out = diag (scale_X) %*% B; - B_out [D, ] = B_out [D, ] + t(shift_X) %*% B; + betas = diag (scale_X) %*% B; + betas[D,] = betas[D,] + t(shift_X) %*% B; } else { - B_out = B; + betas = B; } - - betas = B_out } diff --git a/scripts/builtin/multiLogRegPredict.dml b/scripts/builtin/multiLogRegPredict.dml index 3756420..74cbbda 100644 --- a/scripts/builtin/multiLogRegPredict.dml +++ b/scripts/builtin/multiLogRegPredict.dml @@ -43,8 +43,11 @@ m_multiLogRegPredict = function(Matrix[Double] X, Matrix[Double] B, Matrix[Double] Y, Boolean verbose = FALSE) return(Matrix[Double] M, Matrix[Double] predicted_Y, Double accuracy) { - if(min(Y) <= 0) - stop("multiLogRegPredict: class labels should be greater than zero") + if(min(Y) <= 0) { + print("multiLogRegPredict: class labels should be greater than " + + "zero - converting all labels <= 0 to max(Y)+1"); + Y = ifelse(Y <= 0, max(Y) + 1, Y); + } if(ncol(X) < nrow(B)-1) stop("multiLogRegPredict: mismatching ncol(X) and nrow(B): "+ncol(X)+" "+nrow(B)); diff --git 
a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index c76e5ae..2d5659d 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -131,8 +131,8 @@ public enum Builtins { INTERSECT("intersect", true), INVERSE("inv", "inverse", false), IQM("interQuartileMean", false), - ISNA("is.na", false), - ISNAN("is.nan", false), + ISNA("is.na", "isNA", false), + ISNAN("is.nan", "isNaN", false), ISINF("is.infinite", false), KMEANS("kmeans", true), KMEANSPREDICT("kmeansPredict", true),
