This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4fa8b122ed [SYSTEMDS-3153] Fix KNN
4fa8b122ed is described below
commit 4fa8b122eddefffd4d291bcea2011fdcc8c485f1
Author: Christina Dionysio <[email protected]>
AuthorDate: Wed Oct 18 13:44:07 2023 +0200
[SYSTEMDS-3153] Fix KNN
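Fixes the row-sampling method used for KNN-based missing-value imputation: the "dist_sample" method now samples candidate neighbors only from complete (NaN-free) rows, "dist_missing" and "dist_sample" share one helper, and an unknown method now stops with an error instead of only printing a message.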
Closes #1925
---
scripts/builtin/imputeByKNN.dml | 107 ++++++++-------------
.../builtin/part1/BuiltinImputeKNNTest.java | 12 ++-
src/test/scripts/functions/builtin/imputeByKNN.dml | 8 +-
3 files changed, 52 insertions(+), 75 deletions(-)
diff --git a/scripts/builtin/imputeByKNN.dml b/scripts/builtin/imputeByKNN.dml
index 240631be47..13136ff2c9 100644
--- a/scripts/builtin/imputeByKNN.dml
+++ b/scripts/builtin/imputeByKNN.dml
@@ -19,7 +19,6 @@
#
#-------------------------------------------------------------
-
# Imputes missing values, indicated by NaNs, using KNN-based methods
# (k-nearest neighbors by euclidean distance). In order to avoid NaNs in
# distance computation and meaningful nearest neighbor search, we initialize
@@ -50,13 +49,12 @@
# result Imputed dataset
#
------------------------------------------------------------------------------
-m_imputeByKNN = function(Matrix[Double] X, String method="dist", Int seed=-1,
Double sample_frac = 0.1)
+m_imputeByKNN = function(Matrix[Double] X, String method="dist", Int seed=-1,
Double sample_frac=0.1)
return(Matrix[Double] result)
{
#KNN-Imputation Script
- #Create a mask for placeholder and to check for missing values
- masked = is.nan(X)
+ imputedValue = X
#Impute NaN value with temporary mean value of the column
filled_matrix = imputeByMean(X, matrix(0, cols = ncol(X), rows = 1))
@@ -66,103 +64,76 @@ m_imputeByKNN = function(Matrix[Double] X, String
method="dist", Int seed=-1, Do
distance_matrix = dist(filled_matrix)
#Change 0 value so rowIndexMin will ignore that diagonal value
- distance_matrix = replace(target = distance_matrix, pattern = 0,
replacement = 999)
+ distance_matrix = replace(target=distance_matrix, pattern=0,
replacement=999)
#Get the minimum distance row-wise computation
minimum_index = rowIndexMin(distance_matrix)
#Create aligned matrix from minimum index
- aligned = table(minimum_index, seq(1, nrow(X)), odim1 = nrow(X), odim2 =
nrow(X))
+ aligned = table(minimum_index, seq(1, nrow(X)), odim1=nrow(X),
odim2=nrow(X))
#Get the X records that need to be imputed
imputedValue = t(filled_matrix) %*% aligned
-
- #Update the mask value
- masked = t(imputedValue) * masked
+ imputedValue = t(imputedValue)
}
else if(method == "dist_missing") {
#assuming small missing values
- #Split the matrix into containing NaN values (missing records) and not
containing NaN values (M2 records)
- I = (rowSums(is.nan(X))!=0)
- missing = removeEmpty(target=filled_matrix, margin="rows", select=I)
-
- Y = (rowSums(is.nan(X))==0)
- M2 = removeEmpty(target=filled_matrix, margin = "rows", select = Y)
-
- #Calculate the euclidean distance between fully records and missing
records, and then find the min value row wise
- dotM2 = rowSums(M2 * M2) %*% matrix(1, rows = 1, cols = nrow(missing))
- dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols =
nrow(M2)))
- D = sqrt(dotM2 + dotMissing - 2 * (M2 %*% t(missing)))
- minD = rowIndexMin(t(D))
-
- #Get the index location of the missing value
- pos = rowMaxs(is.nan(X))
- missing_indices = seq(1, nrow(pos)) * pos
-
- #Put the replacement value in the missing indices
- I2 = removeEmpty(target=missing_indices, margin="rows")
- R = table(I2,1,minD,odim1 = nrow(X), odim2=1)
-
- #Replace the 0 to avoid error in table()
- R = replace(target = R, pattern = 0, replacement = nrow(X)+1)
-
- #Create aligned matrix from minimum index
- aligned = table(R, seq(1, nrow(X)), odim1 = nrow(X), odim2 = nrow(X))
-
- #Reshape the subset
- reshaped = rbind(M2, matrix(0, rows = nrow(X) - nrow(M2), cols = ncol(X)))
-
- #Get the M2 records that need to be imputed
- imputedValue = t(reshaped) %*% aligned
-
- #Update the mask value
- masked = t(imputedValue) * masked
+ imputedValue = compute_missing_values(X, filled_matrix, seed, 1.0)
}
else if(method == "dist_sample"){
#assuming large missing values
+ imputedValue = compute_missing_values(X, filled_matrix, seed, sample_frac)
+ }
+ else {
+ stop("Method is unknown or not yet implemented")
+ }
+
+ #Impute the value
+ result = replace(target=X, pattern=NaN, replacement=0)
+ result = result + (imputedValue * is.nan(X))
+}
+
+compute_missing_values = function (Matrix[Double] X, Matrix[Double]
filled_matrix, Int seed, Double sample_frac)
+ return (Matrix[Double] imputedValue)
+{
#Split the matrix into containing NaN values (missing records) and not
containing NaN values (M2 records)
- I = rowSums(is.nan(X)) != 0
+ maskNAN = is.nan(X)
+ I = rowSums(maskNAN) != 0
missing = removeEmpty(target=filled_matrix, margin="rows", select=I)
- #Create permutation matrix for sampling sample_frac*nrow(X) rows
- I = rand(rows=nrow(X), cols=1, seed=seed) <= sample_frac;
- subset = removeEmpty(target=filled_matrix, margin="rows", select=I);
+ Y = (rowSums(maskNAN)==0)
+ M2 = removeEmpty(target=X, margin = "rows", select = Y)
+
+ if (sample_frac != 1.0) {
+ #Create permutation matrix for sampling sample_frac*nrow(X) rows
+ I = rand(rows=nrow(M2), cols=1, seed=seed) <= sample_frac;
+ M2 = removeEmpty(target=M2, margin="rows", select=I);
+ }
#Calculate the euclidean distance between fully records and missing
records, and then find the min value row wise
- dotSubset = rowSums(subset * subset) %*% matrix(1, rows = 1, cols =
nrow(missing))
- dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols =
nrow(subset)))
- D = sqrt(dotSubset + dotMissing - 2 * (subset %*% t(missing)))
+ dotM2 = rowSums(M2 * M2) %*% matrix(1, rows = 1, cols = nrow(missing))
+ dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols =
nrow(M2)))
+ D = sqrt(dotM2 + dotMissing - 2 * (M2 %*% t(missing)))
minD = rowIndexMin(t(D))
#Get the index location of the missing value
- pos = rowMaxs(is.nan(X))
+ pos = rowMaxs(maskNAN)
missing_indices = seq(1, nrow(pos)) * pos
#Put the replacement value in the missing indices
I2 = removeEmpty(target=missing_indices, margin="rows")
- R = table(I2,1,minD,odim1 = nrow(X), odim2=1)
+ R = table(I2, 1, minD, odim1=nrow(X), odim2=1)
#Replace the 0 to avoid error in table()
- R = replace(target = R, pattern = 0, replacement = nrow(X)+1)
+ R = replace(target=R, pattern=0, replacement=nrow(X)+1)
#Create aligned matrix from minimum index
- aligned = table(R, seq(1, nrow(X)), odim1 = nrow(X), odim2 = nrow(X))
+ aligned = table(R, seq(1, nrow(X)), odim1=nrow(X), odim2=nrow(X))
#Reshape the subset
- reshaped = rbind(subset, matrix(0, rows = nrow(X) - nrow(subset), cols =
ncol(X)))
+ reshaped = rbind(M2, matrix(0, rows=nrow(X) - nrow(M2), cols=ncol(X)))
#Get the subset records that need to be imputed
imputedValue = t(reshaped) %*% aligned
-
- #Update the mask value
- masked = t(imputedValue) * masked
- }
- else {
- print("Method is unknown or not yet implemented")
- }
-
- #Impute the value
- result = replace(target = X, pattern = NaN, replacement = 0)
- result = result + masked
+ imputedValue = t(imputedValue)
}
-
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
index 627437fe76..2b7c422978 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
@@ -41,34 +41,36 @@ public class BuiltinImputeKNNTest extends AutomatedTestBase
{
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR,
TEST_NAME, new String[] {"B","B2"}));
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR,
TEST_NAME, new String[] {"B","B2","B3"}));
}
@Test
public void testDefaultCP()throws IOException{
- runImputeKNN(true, Types.ExecType.CP);
+ runImputeKNN(Types.ExecType.CP);
}
@Test
public void testDefaultSpark()throws IOException{
- runImputeKNN(true, Types.ExecType.SPARK);
+ runImputeKNN(Types.ExecType.SPARK);
}
- private void runImputeKNN(boolean defaultProb, ExecType instType) throws
IOException {
+ private void runImputeKNN(ExecType instType) throws IOException {
ExecMode platform_old = setExecMode(instType);
try {
loadTestConfiguration(getTestConfiguration(TEST_NAME));
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-args", DATASET_DIR+"Salaries.csv",
- "dist", "dist_missing", output("B"), output("B2")};
+ "dist", "dist_missing", "dist_sample", "42", "0.9",
output("B"), output("B2"), output("B3")};
runTest(true, false, null, -1);
//Compare matrices, check if the sum of the imputed value is
roughly the same
double sum1 = readDMLMatrixFromOutputDir("B").get(new
CellIndex(1,1));
double sum2 = readDMLMatrixFromOutputDir("B2").get(new
CellIndex(1,1));
+ double sum3 = readDMLMatrixFromOutputDir("B3").get(new
CellIndex(1,1));
Assert.assertEquals(sum1, sum2, eps);
+ Assert.assertEquals(sum2, sum3, eps);
}
finally {
rtplatform = platform_old;
diff --git a/src/test/scripts/functions/builtin/imputeByKNN.dml
b/src/test/scripts/functions/builtin/imputeByKNN.dml
index 299c1dda30..0e87026e2b 100644
--- a/src/test/scripts/functions/builtin/imputeByKNN.dml
+++ b/src/test/scripts/functions/builtin/imputeByKNN.dml
@@ -28,15 +28,19 @@ mask = is.nan(X)
#Perform the KNN imputation
result = imputeByKNN(X = X, method = $2)
result2 = imputeByKNN(X = X, method = $3)
+result3 = imputeByKNN(X = X, method = $4, seed = $5, sample_frac = $6)
#Get the imputed value
I = (mask[,2] == 1);
value = removeEmpty(target = result, margin = "rows", select = I)
value2 = removeEmpty(target = result2, margin = "rows", select = I)
+value3 = removeEmpty(target = result3, margin = "rows", select = I)
#Get the sum of the imputed value
value = colSums(value[,2])
value2 = colSums(value2[,2])
+value3 = colSums(value3[,2])
-write(value, $4)
-write(value2, $5)
\ No newline at end of file
+write(value, $7)
+write(value2, $8)
+write(value3, $9)