Repository: systemml Updated Branches: refs/heads/master a6bca8851 -> 0aaf11d82
[MINOR] Various script simplifications Kmeans predict Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/56c81cbd Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/56c81cbd Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/56c81cbd Branch: refs/heads/master Commit: 56c81cbde3a71ccc49505d2f4f5a89bb0661fd9b Parents: a6bca88 Author: Matthias Boehm <mboe...@gmail.com> Authored: Mon Jul 16 13:51:02 2018 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Mon Jul 16 19:36:41 2018 -0700 ---------------------------------------------------------------------- scripts/algorithms/Kmeans-predict.dml | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/56c81cbd/scripts/algorithms/Kmeans-predict.dml ---------------------------------------------------------------------- diff --git a/scripts/algorithms/Kmeans-predict.dml b/scripts/algorithms/Kmeans-predict.dml index 17e673a..5f045d0 100644 --- a/scripts/algorithms/Kmeans-predict.dml +++ b/scripts/algorithms/Kmeans-predict.dml @@ -121,7 +121,6 @@ if (fileC != " ") { print ("Reading C..."); C = read (fileC); num_clusters = nrow (C); - ones_C = matrix (1, rows = num_clusters, cols = 1); print ("Computing the predicted Y..."); D = -2 * (X %*% t(C)) + t(rowSums (C ^ 2)); prY = rowIndexMin (D); @@ -133,22 +132,19 @@ if (fileC != " ") { print ("Reading the predicted Y..."); prY = read (filePrY); num_clusters = max (prY); - ones_C = matrix (1, rows = num_clusters, cols = 1); } if (fileX != " ") { print ("Computing the WCSS..."); # Compute projection matrix from clusters to records - P = matrix (0, rows = nrow (X), cols = num_clusters); - P [, 1 : max (prY)] = table (seq (1, nrow (X), 1), prY); + P = table (seq (1, nrow (X), 1), prY, nrow(X), num_clusters); # Compute the means, as opposed to the centroids cluster_sizes = t(colSums (P)); - record_of_ones = matrix (1, rows = 1, cols = ncol (X)); - M = (t(P) %*% X) / ((cluster_sizes + (cluster_sizes == 0)) %*% record_of_ones); + M = (t(P) %*% X) / (cluster_sizes + (cluster_sizes == 0)); # Compute the WCSS for the means wcss_means = sum ((X - P %*% M) ^ 2); wcss_means_pc = 100.0 * wcss_means / total_ss; - bcss_means = sum (cluster_sizes * rowSums ((M - ones_C %*% total_mean) ^ 2)); + bcss_means = sum (cluster_sizes * rowSums ((M - total_mean) ^ 2)); bcss_means_pc = 100.0 * bcss_means / total_ss; # Output results print ("Total Sum of Squares (TSS) = " + total_ss); @@ -166,7 +162,7 @@ if (fileC != " ") { # Compute the WCSS for the centroids wcss_centroids = sum ((X - P %*% C) ^ 2); wcss_centroids_pc = 100.0 * wcss_centroids / total_ss; - bcss_centroids = sum (cluster_sizes * rowSums ((C - ones_C %*% total_mean) ^ 2)); + bcss_centroids = sum (cluster_sizes * rowSums ((C - total_mean) ^ 2)); bcss_centroids_pc = 100.0 * bcss_centroids / total_ss; # Output results print ("WCSS for centroids: " + (round (10000.0 * wcss_centroids_pc) / 10000.0) + "% of TSS = " + wcss_centroids); @@ -323,15 +319,12 @@ return (Matrix[double] row_ids, Matrix[double] col_ids, Matrix[double] margins, Matrix[double] max_counts, Matrix[double] rounded_percentages) { margins = rowSums (counts); - select_positive = diag (margins > 0); - select_positive = removeEmpty (target = select_positive, margin = "rows"); + select_positive = removeEmpty (target = diag (margins > 0), margin = "rows"); row_ids = select_positive %*% seq (1, nrow (margins), 1); pos_counts = select_positive %*% counts; pos_margins = select_positive %*% margins; max_counts = rowMaxs (pos_counts); - one_per_column = matrix (1, rows = 1, cols = ncol (pos_counts)); - max_counts_ppred = max_counts %*% one_per_column; - is_max_count = (pos_counts == max_counts_ppred); + is_max_count = (pos_counts == max_counts); aggr_is_max_count = t(cumsum (t(is_max_count))); col_ids = rowSums (aggr_is_max_count == 0) + 1; rounded_percentages = round (1000000.0 * max_counts / pos_margins) / 10000.0;