This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new d1bc4eb84f [SYSTEMDS-3153] Missing value imputation via KNN-based
methods
d1bc4eb84f is described below
commit d1bc4eb84fb2d9b89434d3b5368e53e8fbf55f5f
Author: regaleo605 <[email protected]>
AuthorDate: Tue Aug 15 20:39:53 2023 +0200
[SYSTEMDS-3153] Missing value imputation via KNN-based methods
LDE project SoSe'23.
Closes #1879.
---
scripts/builtin/imputeByKNN.dml | 186 +++++++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../builtin/part1/BuiltinImputeKNNTest.java | 77 +++++++++
src/test/scripts/functions/builtin/imputeByKNN.dml | 44 +++++
4 files changed, 308 insertions(+)
diff --git a/scripts/builtin/imputeByKNN.dml b/scripts/builtin/imputeByKNN.dml
new file mode 100644
index 0000000000..b560c94799
--- /dev/null
+++ b/scripts/builtin/imputeByKNN.dml
@@ -0,0 +1,186 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Imputes missing values, indicated by NaNs, using KNN-based methods
+# (k-nearest neighbors by euclidean distance). To avoid NaNs in the distance
+# computation and to enable a meaningful nearest-neighbor search, the missing
+# values are first initialized with the column means. Currently, only the
+# column with the most missing values is actually imputed.
+#
+# ------------------------------------------------------------------------------
+# INPUT:
+# ------------------------------------------------------------------------------
+# X Matrix with missing values, which are represented as NaNs
+# method Method used for imputing missing values with different performance
+# and accuracy tradeoffs:
+# 'dist' (default): Compute all-pairs distances and impute the
+# missing values by closest. O(N^2 * #features)
+# 'dist_missing': Compute distances between data and records with
+# missing values. O(N*M * #features), assuming
+# that the number of records with MV is M<<N.
+# 'dist_sample': Compute distances between sample of data and
+# records with missing values. O(S*M * #features)
+# with M<<N and S<<N, but suboptimal imputation.
+# seed Root seed value for random/sample calls for deterministic behavior
+# -1 for true randomization
+# ------------------------------------------------------------------------------
+#
+# OUTPUT:
+# ------------------------------------------------------------------------------
+# result Imputed dataset
+# ------------------------------------------------------------------------------
+
+m_imputeByKNN = function(Matrix[Double] X, String method="dist", Int seed=-1)
+  return(Matrix[Double] result)
+{
+  # Imputes the column of X with the most missing values (NaNs) by copying the
+  # value of the nearest neighbor (euclidean distance on mean-imputed data).
+  #   X       matrix with missing values encoded as NaN
+  #   method  "dist" (all-pairs distances), "dist_missing" (complete records
+  #           vs. records with missing values) or "dist_sample" (random sample
+  #           of the complete records vs. records with missing values)
+  #   seed    seed passed to rand() for the "dist_sample" subset
+  #   result  X with the selected column imputed; all other cells unchanged
+
+  #TODO fix seed handling (only root seed)
+  #TODO fix imputation for all columns with missing values
+
+  # Reject unsupported methods up front; previously this case only printed a
+  # message and returned X with every NaN replaced by 1.
+  if( (method != "dist") & (method != "dist_missing") & (method != "dist_sample") ) {
+    stop("imputeByKNN: method '" + method + "' is unknown or not yet implemented")
+  }
+
+  # 0/1 mask of the missing cells; without missing values X is returned as is
+  nanMask = is.nan(X)
+  result = X
+
+  if( sum(nanMask) > 0 ) {
+    # Only the column with the most missing values is imputed
+    missing_col = as.scalar(rowIndexMax(colSums(nanMask)))
+
+    # 0/1 indicator of records with at least one missing value. An indicator
+    # (instead of the NaN count per record) keeps the row positions computed
+    # below correct for records with more than one missing value.
+    hasMissing = (rowSums(nanMask) != 0)
+
+    # Temporary mean imputation so that distances are free of NaNs
+    filled_matrix = imputeByMean(X, matrix(0, cols = ncol(X), rows = 1))
+
+    if( method == "dist" ) {
+      # All-pairs distances on the mean-imputed data
+      distance_matrix = dist(filled_matrix)
+
+      # Mask zero distances (the diagonal, and exact duplicates) with a value
+      # above every real distance so rowIndexMin never selects the record
+      # itself; a fixed constant fails once real distances exceed it.
+      sentinel = max(distance_matrix) + 1
+      distance_matrix = replace(target = distance_matrix, pattern = 0, replacement = sentinel)
+
+      # Nearest-neighbor row index for every record with missing values
+      candidates = filled_matrix
+      nn_index = removeEmpty(target = rowIndexMin(distance_matrix), margin = "rows", select = hasMissing)
+    }
+    else {
+      # Split into records with missing values and complete records
+      missing = removeEmpty(target = filled_matrix, margin = "rows", select = hasMissing)
+      candidates = removeEmpty(target = filled_matrix, margin = "rows", select = (hasMissing == 0))
+
+      if( method == "dist_sample" ) {
+        # Random 0/1 selection of roughly half of the complete records
+        keep = ceiling(rand(rows = nrow(candidates), cols = 1, min = 0, max = 1, sparsity = 0.5, seed = seed))
+
+        # Fall back to all complete records if the sample came out empty
+        if( sum(keep) < 1 ) {
+          keep = matrix(1, rows = nrow(candidates), cols = 1)
+        }
+        candidates = removeEmpty(target = candidates, margin = "rows", select = keep)
+      }
+
+      # Euclidean distances between candidate records (rows of D) and records
+      # with missing values (columns of D), then the closest candidate per
+      # record with missing values
+      dotCandidates = rowSums(candidates * candidates) %*% matrix(1, rows = 1, cols = nrow(missing))
+      dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols = nrow(candidates)))
+      D = sqrt(dotCandidates + dotMissing - 2 * (candidates %*% t(missing)))
+      nn_index = rowIndexMin(t(D))
+    }
+
+    # 0/1 lookup matrix: column i marks the nearest neighbor of the i-th
+    # record with missing values (both table inputs have the same length)
+    numMissing = nrow(nn_index)
+    indices = table(nn_index, seq(1, numMissing), odim1 = nrow(candidates), odim2 = numMissing)
+
+    # Value of the nearest neighbor in the column to impute
+    imputedValue = t(indices) %*% candidates[, missing_col]
+
+    # Scatter the values back to the row positions of the records with
+    # missing values
+    missing_rows = removeEmpty(target = seq(1, nrow(X)) * hasMissing, margin = "rows")
+    R = table(missing_rows, 1, imputedValue, odim1 = nrow(X), odim2 = 1)
+
+    # Write the imputed values only into the NaN cells of the selected column;
+    # NaNs in all other columns are left untouched
+    observed = replace(target = X[, missing_col], pattern = NaN, replacement = 0)
+    result[, missing_col] = observed + nanMask[, missing_col] * R
+  }
+}
+
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 883e57aa22..2243eeb963 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -166,6 +166,7 @@ public enum Builtins {
IMG_INVERT("img_invert", true),
IMG_POSTERIZE("img_posterize", true),
IMPURITY_MEASURES("impurityMeasures", true),
+ IMPUTE_BY_KNN("imputeByKNN", true),
IMPUTE_BY_MEAN("imputeByMean", true),
IMPUTE_BY_MEAN_APPLY("imputeByMeanApply", true),
IMPUTE_BY_MEDIAN("imputeByMedian", true),
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
new file mode 100644
index 0000000000..627437fe76
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Runs the imputeByKNN test script on the Salaries dataset with the methods
+ * "dist" and "dist_missing" and checks that the sums of the imputed values
+ * written by the script agree within a tolerance.
+ */
+public class BuiltinImputeKNNTest extends AutomatedTestBase {
+
+	private final static String TEST_NAME = "imputeByKNN";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinImputeKNNTest.class.getSimpleName() + "/";
+
+	/** Absolute tolerance when comparing the two sums of imputed values. */
+	private static final double EPS = 10;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B","B2"}));
+	}
+
+	@Test
+	public void testDefaultCP() throws IOException {
+		runImputeKNN(Types.ExecType.CP);
+	}
+
+	@Test
+	public void testDefaultSpark() throws IOException {
+		runImputeKNN(Types.ExecType.SPARK);
+	}
+
+	/**
+	 * Executes the test script for the given execution type and compares the
+	 * outputs "B" and "B2" at cell (1,1).
+	 *
+	 * @param instType execution type used to configure the execution mode
+	 * @throws IOException declared for the output reading calls
+	 */
+	private void runImputeKNN(ExecType instType) throws IOException {
+		ExecMode platform_old = setExecMode(instType);
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", DATASET_DIR+"Salaries.csv",
+				"dist", "dist_missing", output("B"), output("B2")};
+
+			runTest(true, false, null, -1);
+
+			//Compare matrices, check if the sum of the imputed value is roughly the same
+			double sum1 = readDMLMatrixFromOutputDir("B").get(new CellIndex(1,1));
+			double sum2 = readDMLMatrixFromOutputDir("B2").get(new CellIndex(1,1));
+			Assert.assertEquals(sum1, sum2, EPS);
+		}
+		finally {
+			// Undo everything setExecMode changed, not only the rtplatform field
+			resetExecMode(platform_old);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/imputeByKNN.dml
b/src/test/scripts/functions/builtin/imputeByKNN.dml
new file mode 100644
index 0000000000..f02ac90eac
--- /dev/null
+++ b/src/test/scripts/functions/builtin/imputeByKNN.dml
@@ -0,0 +1,44 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Load the input frame; the string "20" is read as a missing value
+F = read($1, data_type="frame", format="csv", header=TRUE, naStrings= ["20"]);
+
+# Keep columns 4, 5 and 7 as a numeric matrix
+M = cbind(as.matrix(F[,4:5]), as.matrix(F[,7]))
+nanAll = is.nan(M)
+
+# Drop the records whose first column is missing and remember where the
+# remaining missing values are
+keepRows = (nanAll[,1] != 1)
+data = removeEmpty(target = M, margin = "rows", select = keepRows)
+nanData = is.nan(data)
+
+# Run the KNN imputation once per method under comparison
+imputedA = imputeByKNN(X = data, method = $2)
+imputedB = imputeByKNN(X = data, method = $3)
+
+# Select the records whose second column was missing before imputation
+wasMissing = (nanData[,2] == 1);
+rowsA = removeEmpty(target = imputedA, margin = "rows", select = wasMissing)
+rowsB = removeEmpty(target = imputedB, margin = "rows", select = wasMissing)
+
+# Sum the imputed values of the second column for each method
+sumA = colSums(rowsA[,2])
+sumB = colSums(rowsB[,2])
+
+# Write both sums for comparison by the test driver
+write(sumA, $4)
+write(sumB, $5)