Repository: systemml Updated Branches: refs/heads/master 5069f9781 -> 1b1c3fea3
[SYSTEMML-2332] Simplification and performance improvement of the ALS-predict script This patch simplifies the ALS-predict script by removing unnecessary, and even counter-productive, script-level "optimizations" which stem from a time when we did not have sparsity-exploiting fused operators. In a scenario with 100Kx100 and 20Kx100 factors and 10K queries, this modification improved the end-to-end runtime (incl. matrix read and write) from 20.7s to 2.6s. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1b1c3fea Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1b1c3fea Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1b1c3fea Branch: refs/heads/master Commit: 1b1c3fea355fc39d4db8a8229fbd2d36c11e4258 Parents: 5069f97 Author: Matthias Boehm <[email protected]> Authored: Fri May 18 21:32:55 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri May 18 21:32:55 2018 -0700 ---------------------------------------------------------------------- scripts/algorithms/ALS_predict.dml | 49 +++++++++++---------------------- 1 file changed, 16 insertions(+), 33 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1b1c3fea/scripts/algorithms/ALS_predict.dml ---------------------------------------------------------------------- diff --git a/scripts/algorithms/ALS_predict.dml b/scripts/algorithms/ALS_predict.dml index 0af8301..a4e6bd4 100644 --- a/scripts/algorithms/ALS_predict.dml +++ b/scripts/algorithms/ALS_predict.dml @@ -19,7 +19,6 @@ # #------------------------------------------------------------- -# # THIS SCRIPT COMPUTES THE RATING/SCORE FOR A GIVEN LIST OF PAIRS: (USER-ID, ITEM-ID) USING 2 FACTOR MATRICES L AND R # WE ASSUME THAT ALL USERS HAVE RATED AT LEAST ONCE AND ALL ITEMS HAVE BEEN RATED AT LEAST ONCE. 
# INPUT PARAMETERS: @@ -27,7 +26,7 @@ # NAME TYPE DEFAULT MEANING # --------------------------------------------------------------------------------------------- # X String --- The input user-id/item-id list -# Y String --- The output user-id/item-id/score +# Y String --- The output user-id/item-id/score # L String --- Location of the factor matrix L: user-id x feature-id # R String --- Location to the factor matrix R: feature-id x item-id # Vrows Integer --- The number of rows in the original matrix @@ -37,16 +36,16 @@ # OUTPUT: Matrix Y containing the predicted ratings for users and items specified in input matrix X # # HOW TO INVOKE THIS SCRIPT - EXAMPLE: -# hadoop jar SystemML.jar -f ALS-predict.dml -nvargs Vrows=100000 Vcols=10000 X=INPUT_DIR/X L=OUTPUT_DIR/L R=OUTPUT_DIR/R -# Y=OUTPUT_DIR/Y fmt=csv +# hadoop jar SystemML.jar -f ALS-predict.dml -nvargs Vrows=100000 Vcols=10000 \ +# X=INPUT_DIR/X L=OUTPUT_DIR/L R=OUTPUT_DIR/R Y=OUTPUT_DIR/Y fmt=csv -fileX = $X; -fileY = $Y; -fileL = $L; -fileR = $R; -Vrows = $Vrows; -Vcols = $Vcols; -fmtO = ifdef ($fmt, "text"); # $fmt="text"; +fileX = $X; +fileY = $Y; +fileL = $L; +fileR = $R; +Vrows = $Vrows; +Vcols = $Vcols; +fmtO = ifdef ($fmt, "text"); X = read (fileX); L = read (fileL); @@ -56,8 +55,8 @@ R = read (fileR); n = nrow (X); m = ncol (X); -if (m != 2){ - stop("The input matrix must have 2 columns: user-id and item-id"); +if (m != 2) { + stop("The input matrix must have 2 columns: user-id and item-id"); } Lrows = nrow (L); @@ -66,33 +65,17 @@ Rcols = ncol (R); X_user_max = max (X[,1]); X_item_max = max (X[,2]); -# initializing Y matrix -Y = matrix(0, rows = n, cols = 3); - if (X_user_max > Vrows | X_item_max > Vcols ) { - stop ("Predictions cannot be provided. Maximum user-id (item-id) exceed the number of rows (columns) of V."); + stop ("Predictions cannot be provided. 
Maximum user-id (item-id) exceed the number of rows (columns) of V."); } if (Lrows != Vrows | Rcols != Vcols) { - stop ("Predictions cannot be provided. Number of rows of L (columns of R) does not match the number of rows (column) of V."); + stop ("Predictions cannot be provided. Number of rows of L (columns of R) does not match the number of rows (column) of V."); } - # user2item table -ones = matrix (1, rows = n, cols = 1); -UI = table (X[,1], X[,2], ones, Vrows, Vcols); - -# summing up over all items for all users -U = rowSums (UI) - -# replacing all rows > 1 with 1 -U = U >= 1; - -# selecting users from factor L -U_prime = L * U; - -V_prime = (U_prime %*% R); +UI = table (X[,1], X[,2], Vrows, Vcols); # Applying items filter -V_prime = UI * V_prime; +V_prime = UI * (L %*% R); write(V_prime, fileY, format = fmtO);
