This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new ad89941  [SYSTEMDS-2869] Built-in functions KNN and KNNBF (k-nearest neighbors)
ad89941 is described below

commit ad899416d3d583df9bf4c6c16755e0bd39382385
Author: ywcb00 <[email protected]>
AuthorDate: Sun Feb 21 00:30:46 2021 +0100

    [SYSTEMDS-2869] Built-in functions KNN and KNNBF (k-nearest neighbors)
    
    DIA project WS2020/21.
    Closes #2869.
    
    Co-authored-by: Metka Batič <[email protected]>
    Co-authored-by: Matthias Kargl <[email protected]>
---
 .github/workflows/functionsTests.yml               |   3 +-
 scripts/builtin/knn.dml                            | 639 +++++++++++++++++++++
 scripts/builtin/knnbf.dml                          |  58 ++
 .../java/org/apache/sysds/common/Builtins.java     |   4 +-
 src/test/java/org/apache/sysds/test/TestUtils.java | 409 ++++++-------
 .../test/functions/builtin/BuiltinKNNBFTest.java   | 118 ++++
 .../test/functions/builtin/BuiltinKNNTest.java     | 130 +++++
 src/test/scripts/functions/builtin/knn.R           |  52 ++
 src/test/scripts/functions/builtin/knn.dml         |  35 ++
 src/test/scripts/functions/builtin/knnbf.dml       |  28 +
 .../scripts/functions/builtin/knnbfReference.dml   |  29 +
 src/test/scripts/installDependencies.R             |   2 +
 12 files changed, 1312 insertions(+), 195 deletions(-)
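
For orientation, a minimal sketch of how the two new builtins can be invoked from
DML, based on the signatures added below (the input files and the categorical
target type CL_T=2 are illustrative assumptions, not part of this commit):

    # hypothetical CSV inputs: training features, test features, target column
    Train = read("train.csv", format="csv");
    Test  = read("test.csv", format="csv");
    CL    = read("labels.csv", format="csv");
    # knn returns neighbor indices, predicted target values, and (optional) feature importances
    [NNR, PR, FI] = knn(Train=Train, Test=Test, CL=CL, CL_T=2, k_value=5);
    # knnbf is the brute-force variant and returns only the k nearest neighbor indices
    NNR2 = knnbf(X=Train, T=Test, k_value=5);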

diff --git a/.github/workflows/functionsTests.yml b/.github/workflows/functionsTests.yml
index cf64a2f..5e7466c 100644
--- a/.github/workflows/functionsTests.yml
+++ b/.github/workflows/functionsTests.yml
@@ -46,7 +46,8 @@ jobs:
          "**.functions.builtin.**",
          "**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.jmlc.**,**.functions.lineage.**",
          "**.functions.dnn.**,**.functions.misc.**,**.functions.mlcontext.**,**.functions.paramserv.**",
-          "**.functions.nary.**,**.functions.parfor.**,**.functions.pipelines.**,**.functions.privacy.**,**.functions.quaternary.**,**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**",
+          "**.functions.nary.**,**.functions.quaternary.**",
+          "**.functions.parfor.**,**.functions.pipelines.**,**.functions.privacy.**,**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**",
          "**.functions.reorg.**,**.functions.rewrite.**,**.functions.ternary.**,**.functions.transform.**",
          "**.functions.unary.matrix.**"
          ]
diff --git a/scripts/builtin/knn.dml b/scripts/builtin/knn.dml
new file mode 100644
index 0000000..8e86ba3
--- /dev/null
+++ b/scripts/builtin/knn.dml
@@ -0,0 +1,639 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# THIS SCRIPT IMPLEMENTS THE KNN (k-nearest neighbors) ALGORITHM
+#
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME                  TYPE     DEFAULT  OPTIONAL  MEANING
+# ---------------------------------------------------------------------------------------------
+# Train                 Matrix   ---      N         The input feature matrix
+# Test                  Matrix   ---      N         The input matrix for the nearest neighbor search
+# CL                    Matrix   ---      Y         The input target matrix
+# CL_T                  Integer  0        Y         The target type of matrix CL: whether the
+#                                                   columns in CL are continuous ( =1 ),
+#                                                   categorical ( =2 ), or not specified ( =0 )
+# trans_continuous      Boolean  FALSE    Y         Option flag for transforming continuous features to [-1,1]:
+#                                                   FALSE = do not transform continuous variables;
+#                                                   TRUE = transform continuous variables
+# k_value               int      5        Y         k value for KNN; ignored if select_k is enabled
+# select_k              Boolean  FALSE    Y         Use the k selection algorithm to estimate k
+#                                                   ( TRUE means yes )
+# k_min                 int      1        Y         Min k value ( available if select_k = 1 )
+# k_max                 int      100      Y         Max k value ( available if select_k = 1 )
+# select_feature        Boolean  FALSE    Y         Use the feature selection algorithm to select features
+#                                                   ( TRUE means yes )
+# feature_max           int      10       Y         Max number of features to select
+# interval              int      1000     Y         Interval (block size) for k selection ( available if select_k = 1 )
+# feature_importance    Boolean  FALSE    Y         Use the feature importance algorithm to estimate each feature
+#                                                   ( TRUE means yes )
+# predict_con_tg        int      0        Y         Prediction function for continuous targets: mean (=0) or
+#                                                   median (=1)
+# START_SELECTED        Matrix   ---      Y         Initial value for feature selection
+# ---------------------------------------------------------------------------------------------
+# OUTPUT: Matrix NNR, Matrix PR, Matrix FEATURE_IMPORTANCE_VALUE
+#
+
+m_knn = function(
+    Matrix[Double] Train,
+    Matrix[Double] Test,
+    Matrix[Double] CL,
+    Integer CL_T = 0,
+    Integer trans_continuous = 0,
+    Integer k_value = 5,
+    Integer select_k = 0,
+    Integer k_min = 1,
+    Integer k_max = 100,
+    Integer select_feature = 0,
+    Integer feature_max = 10,
+    Integer interval = 1000,
+    Integer feature_importance = 0,
+    Integer predict_con_tg = 0,
+    Matrix[Double] START_SELECTED = matrix(0, 0, 0)
+)return(
+    Matrix[Double] NNR_matrix,
+    Matrix[Double] CL_matrix,
+    Matrix[Double] m_feature_importance
+){
+
+  m_feature_importance = matrix(0, 0, 0);
+
+  #data prepare
+  if( trans_continuous == 1 ){
+    Train = prepareKNNData( Train);
+    Test  = prepareKNNData( Test);
+  }
+
+  n_records = nrow( Train );
+  n_features = ncol( Train );
+  s_selected_k = 5;
+  m_selected_feature = matrix(1,rows=1,cols=n_records);
+  if( select_k == 1 | select_feature==1 ){
+    #parameter check
+    #parameter re-define
+    if( select_k==1 ){
+      if(  k_max >= n_records  ){
+        k_max = n_records - 1;
+        print( "k_max should be no greater than the number of records, change k_max to " +
+          "( number of records - 1 ) : " + k_max );
+      }
+      if(  k_max >= interval  ){
+        interval = k_max + 1;
+        # k_max should equal interval - 1, because we drop the query point itself in the NN search.
+        print( "interval should be no less than k_max, change interval to : " +
+          interval );
+      }
+      if(  k_max <= 1  )
+        stop( "incorrect k_max value" );
+      if(  k_min >= k_max )
+        stop( "k_min >= k_max" );
+    }
+    if( select_feature == 1 ){
+      if(  k_value >= n_records  ){
+        k_value = n_records - 1;
+        print( "k_value should be no greater than the number of records, change k_value to " +
+          "( number of records - 1 ) : " + k_value );
+      }
+      #Select feature only
+      if( nrow(START_SELECTED) == 0 & ncol(START_SELECTED) == 0 )
+        m_start_selected_feature = matrix( 0,1,n_features );
+      else
+        m_start_selected_feature = START_SELECTED;
+    }
+
+    if( select_k == 1 & select_feature == 1){
+      #Combined k and feature selection
+      print("Start combined k and feature selection ...");
+      [m_selected_feature,s_selected_k] =
+        getSelectedFeatureAndK( Train,CL,CL_T,m_start_selected_feature,
+        feature_max,k_min,k_max,interval );
+    }
+    else if( select_k == 1 ){
+      #Select k only
+      print("Start k select ...");
+      s_selected_k = getSelectedKBase( Train,CL,CL_T,k_min,k_max,interval );
+    }
+    else if( select_feature == 1 ){
+      #Select feature only
+      print("Start feature selection ... ");
+      [m_selected_feature,d_err] =
+        getSelectedFeature( Train,CL,CL_T,m_start_selected_feature,
+        feature_max,k_value,interval );
+    }
+  }
+
+  if( feature_importance == 1){
+    if(  k_value >= n_records  ){
+      k_value = n_records - 1;
+      print( "k_value should be no greater than the number of records, setting k_value to " +
+        "( number of records - 1 ) : " + k_value );
+    }
+    [m_feature_importance,m_norm_feature_importance] =
+      getFeatureImportance(Train,CL,CL_T,k_value);
+  }
+
+  NNR_matrix = naiveKNNsearch(P=Train,Q=Test,K=k_value);
+
+  CL_matrix = matrix( 0,nrow( Test ),1 );
+
+  for(i in 1 : nrow(NNR_matrix))
+  {
+    NNR_tmp_matrix = matrix( 0,k_value,1 );
+    for( j in 1:k_value )
+      NNR_tmp_matrix[j,1] = CL[as.scalar( NNR_matrix[i,j] ),1];
+
+    if(CL_T == 2) {
+      t_cl_value = as.scalar( rowIndexMax( t(NNR_tmp_matrix) ) );
+    }
+    else {
+      if ( predict_con_tg == 0)
+        t_cl_value = mean( NNR_tmp_matrix );
+      else if(predict_con_tg == 1)
+        t_cl_value = median( NNR_tmp_matrix );
+    }
+
+    CL_matrix[i,1] = t_cl_value;
+  }
+}
+
+#naive knn search implement
+naiveKNNsearch = function(
+    Matrix[Double] P,
+    Matrix[Double] Q,
+    Integer K
+)return(
+    Matrix[Double] O
+){
+  num_records = nrow (P);
+  num_features = ncol (P);
+  num_queries = nrow (Q);
+  Qt = t(Q);
+  PQt = P %*% Qt;
+  P2 = rowSums (P ^ 2);
+  D = -2 * PQt + P2;
+  if (K == 1) {
+    Dt = t(D);
+    O = rowIndexMin (Dt);
+  } else {
+    O = matrix (0, rows = num_queries, cols = K);
+    parfor (i in 1:num_queries) {
+      D_sorted=order(target=D[,i], by=1, decreasing=FALSE, index.return=TRUE);
+      O[i,] = t(D_sorted[1:K,1]);
+    }
+  }
+}
+
+#naive knn search for predict value only implement
+#TODO eliminate redundancy
+naiveKNNsearchForPredict = function(
+    matrix[double] P,
+    matrix[double] Q,
+    matrix[double] L,
+    integer K
+)return(
+    matrix[double] OL
+){
+  num_records = nrow (P);
+  num_features = ncol (P);
+  num_queries = nrow (Q);
+  Qt = t(Q);
+  PQt = P %*% Qt;
+  P2 = rowSums (P ^ 2);
+  D = -2 * PQt + P2;
+  if (K == 1) {
+    Dt = t(D);
+    O = rowIndexMin (Dt);
+    OL = matrix (0, rows = num_queries, cols = 1)
+    parfor( i in 1:num_queries){
+      OL[i,] = L[as.scalar(O[i,1]),1]
+    }
+  } else {
+    OL = matrix (0, rows = num_queries, cols = K);
+    parfor (i in 1:num_queries) {
+      D_sorted=order(target=cbind(D[,i],L), by=1, decreasing=FALSE, index.return=FALSE);
+      OL[i,] = t(D_sorted[1:K,2]);
+    }
+  }
+}
+
+getErr_k = function (  matrix[double] in_m_neighbor_value,
+       matrix[double] in_m_cl,
+       integer in_i_cl_type ,
+       integer in_i_k_min  )
+   return (  matrix[double] out_m_err  )
+{
+  i_col = ncol( in_m_neighbor_value  );
+  i_row  = nrow( in_m_neighbor_value  );
+
+  out_m_err = matrix( 0,i_row,i_col - in_i_k_min + 1  );
+  if( in_i_cl_type == 2  ) #category
+       m_correct = in_m_neighbor_value != in_m_cl[1:i_row,];
+  else #continuous
+       m_correct = (in_m_neighbor_value - in_m_cl[1:i_row,])^2; #ppred( in_m_neighbor_value,in_m_cl,"-" );
+  parfor( i in 1:i_col-in_i_k_min+1 ,check = 0 ){
+     out_m_err[,i] =
+       ( rowSums( m_correct[,1:in_i_k_min + i - 1] ) / ( in_i_k_min + i - 1 ) );
+  }
+  #return the error for each record and each k ( in the range 1~max );
+}
+
+eliminateModel = function (  double s_err_mean, double s_err_vars, integer i_row  )
+  return(  boolean out_b_inactived   ){
+  #alpha, beta, gamma, delta
+  d_gamma = 0.001;
+  d_delta = 0.001;
+  tmp_d_delta = cdf(target = (-d_gamma - s_err_mean)/s_err_vars, dist="t",df=i_row-1);
+  out_b_inactived = (tmp_d_delta < d_delta)
+}
+
+getErr = function (  matrix[double] in_m_neighbor_value,
+        matrix[double] in_m_cl,
+        integer in_i_cl_type )
+    return (  matrix[double] out_m_err )
+{
+  i_col = ncol( in_m_neighbor_value );
+  i_row  = nrow( in_m_neighbor_value );
+  if( in_i_cl_type == 2 ) #category
+    m_correct = in_m_neighbor_value != in_m_cl[1:i_row,];
+  else #continuous
+    m_correct = (in_m_neighbor_value - in_m_cl[1:i_row,])^2;
+  out_m_err = ( rowSums( m_correct[,1:i_col] )/( i_col ) );
+}
+
+# getSelectedFeatureAndK:
+#   Combine k and feature selection algorithm.
+#   Refer to ADD part "8.Combined k and feature selection"
+# Argument:
+# in_m_data                     input matrix as features
+# in_m_data_target              input matrix as target value
+# in_i_is_categorical           1 = category , 0 = continuous
+# in_m_init_selected            S.user, initially selected features specified by the user
+# in_i_max_select               J, max number of features to select
+# k_min                         minimum k
+# k_max                         maximum k
+# interval                      block size for the BRACE algorithm
+#
+# Return:
+# out_m_selected_feature        output matrix for feature selection
+# out_i_selected_k              output k value for k selection
+
+getSelectedFeatureAndK = function (
+   matrix[double] in_m_data,
+   matrix[double] in_m_data_target,
+   integer in_i_is_categorical, # 1 = category , 0 = continuous
+   matrix[double] in_m_init_selected,
+   integer in_i_max_select,
+   integer k_min,
+   integer k_max,
+   integer interval )
+return(
+   matrix[double] out_m_selected_feature,
+   integer out_i_selected_k
+    )
+{
+  m_err = matrix( 0,1,k_max-k_min+1 );
+  m_feature = matrix( 0,k_max-k_min+1,ncol( in_m_data ) );
+  #Step 1. For each k in [k_min,k_max] ( k_min has default value 1, k_max has default value 100 )
+  #in parallel select relevant features using FS+BRACE or schemata search described in Section 7.
+  parfor( i in k_min:k_max,check=0 ){
+    [m_selected_feature,d_err] =
+      getSelectedFeature( in_m_data,in_m_data_target,in_i_is_categorical,
+      in_m_init_selected,in_i_max_select,i,interval );
+    m_err[1,i] = d_err;
+    m_feature[i,] = m_selected_feature;
+  }
+  #Step 2. Output the combination of features and k with the smallest LOOCV error.
+  i_min_err_index = as.integer( as.scalar( rowIndexMin( m_err ) ) );
+  out_m_selected_feature = m_feature[i_min_err_index,];
+  out_i_selected_k = i_min_err_index + k_min - 1;
+}
+
+getFeatureImportance = function (
+   matrix[double] in_m_data,
+   matrix[double] in_m_data_target,
+   integer in_i_is_categorical, # 1 = category , 0 = continuous
+   integer k_value)
+return(
+   matrix[double] out_m_feature_importance,
+   matrix[double] out_m_norm_feature_importance
+    )
+{
+  n_feature = ncol(in_m_data)
+  n_record = nrow(in_m_data)
+  if(n_feature <= 1)
+    stop("can't estimate feature importance when ncol = 1")
+
+  m_err = matrix( 0,n_record,n_feature);
+  for(i_feature in 1:n_feature){
+    m_feature_select = matrix(1,1,n_feature)
+    m_feature_select[1,i_feature] = 0;
+    m_in_tmp_data = removeEmpty(target=in_m_data,margin="cols", select=m_feature_select)
+    m_neighbor_value = getKNeighbor( m_in_tmp_data,m_in_tmp_data,in_m_data_target,k_value);
+    m_tmp_err = getErr( m_neighbor_value,in_m_data_target,in_i_is_categorical );
+    m_err[,i_feature] = m_tmp_err
+  }
+  out_m_feature_importance = colSums( m_err );
+  out_m_norm_feature_importance =
+    out_m_feature_importance / as.scalar(rowSums(out_m_feature_importance))
+}
+
+# prepareKNNData:
+#   Prepare the data: transform continuous variables to [-1,1]
+# Argument:
+# * in_m_data                     input matrix as features
+prepareKNNData = function(matrix[double] in_m_data)
+  return(matrix[double] out_m_data)
+{
+  m_colmax = colMaxs(in_m_data);
+  m_colmin = colMins(in_m_data);
+  out_m_data = 2 * (in_m_data - m_colmin ) / ( m_colmax - m_colmin ) - 1;
+}
+
+getKNeighbor = function(matrix[double] in_m_data,
+   matrix[double] in_m_test_data,
+   matrix[double] in_m_cl,
+   integer in_i_k_max)
+ return (matrix[double] out_m_neighbor_value)
+{
+  # to naive
+  m_search_result = naiveKNNsearchForPredict(in_m_data, in_m_test_data, in_m_cl, in_i_k_max + 1)
+  out_m_neighbor_value = m_search_result[ , 2 : in_i_k_max + 1]
+}
+
+# getSelectedKBase:
+#   k selection algorithm with simple KNN algorithm.
+# Argument:
+# * in_m_data                     input matrix as features
+# * in_m_data_target              input matrix as target value
+# * in_i_is_categorical           1 = category , 0 = continuous
+# * k_min                         minimum k
+# * k_max                         maximum k
+# * interval                      block size
+#
+# Return:
+# * k                             output k value for k selection
+getSelectedKBase = function(matrix[double] in_m_data,
+   matrix[double] in_m_data_target,
+   integer in_i_is_categorical, # 1 = category, 0 = continuous
+   integer k_min,
+   integer k_max,
+   integer interval)
+ return (integer k)
+{
+  b_continue_loop = TRUE;
+  i_iter = 1;
+  i_record  = nrow(in_m_data);
+
+  i_active_model_number = k_max - k_min + 1;
+  m_active_flag = matrix(0, 1, i_active_model_number);
+
+  m_iter_err_sum = matrix(0, 1, k_max - k_min + 1);
+  m_iter_err_sum_squared = matrix(0, 1, k_max - k_min + 1);
+  while(b_continue_loop)
+  {
+    # 1.build k-d tree? , or use hash method
+    # 2.search data to get k_max nearest neighbor
+    i_process_item = i_iter * interval;
+    if(i_process_item >= i_record) {
+      i_process_item = i_record;
+      b_continue_loop = FALSE;
+    }
+    i_process_begin_item = ((i_iter - 1) * interval) + 1;
+    i_process_end_item = i_process_item;
+
+    m_neighbor_value = getKNeighbor(in_m_data, in_m_data[i_process_begin_item : i_process_end_item, ], in_m_data_target, k_max);
+    # 3.get matrix of err from k_min to k_max
+    m_err = getErr_k(m_neighbor_value, in_m_data_target[i_process_begin_item : i_process_end_item, ], in_i_is_categorical, k_min);
+
+    # 4.check this matrix to drop unnecessary records
+    m_active_flag_tmp = matrix(0, 1, ncol(m_err));
+
+    s_rows_number = i_process_item;
+
+    m_iter_err_sum = colSums(m_err) + m_iter_err_sum;
+    m_iter_err_sum_squared = colSums(m_err ^ 2) + m_iter_err_sum_squared;
+
+    m_err_mean = - outer(t(m_iter_err_sum), m_iter_err_sum , "-") / s_rows_number;
+    m_err_vars = ( m_err_mean ^2 * s_rows_number -
+      2 * m_err_mean * m_iter_err_sum  + m_iter_err_sum_squared) / (s_rows_number-1);
+    m_err_vars = sqrt(m_err_vars);
+
+    parfor(i in 1 : ncol(m_err), check = 0) {
+      parfor(j in 1 : ncol(m_err), check = 0) {
+        b_execute_block = !(j == i
+          | as.scalar(m_active_flag_tmp[1, i]) == 1 # i has been dropped, ignore this case
+          | as.scalar(m_active_flag_tmp[1, j]) == 1) # j has been dropped, ignore this case
+        if(b_execute_block) {
+          b_flag = eliminateModel(as.scalar(m_err_mean[i, j]), as.scalar(m_err_vars[i, j]), s_rows_number);
+          if(b_flag == TRUE)
+           m_active_flag_tmp[1, i] = 1;
+        }
+      }
+    }
+
+    m_active_flag =  ((m_active_flag + m_active_flag_tmp) >= 1);
+    i_active_model_number = -sum(m_active_flag - 1);
+
+    # 5.break while check
+    if(i_active_model_number <= 1)
+      b_continue_loop = FALSE;
+
+    i_iter = i_iter + 1;
+    print("i_iter: " + i_iter)
+  }
+
+  k = 0;
+  if(i_active_model_number == 0) {
+    print("All k eliminated, use min of range " + k_min);
+    k = k_min;
+  }
+  else if(i_active_model_number == 1) {
+    k = k_min + as.integer(as.scalar(rowIndexMin(m_active_flag))) - 1;
+    print( "Get k, which value is " + k  );
+  }
+  else {
+    m_err_for_order =
+      cbind(t(m_iter_err_sum), matrix(seq(k_min, k_max, 1), k_max - k_min + 1, 1));
+    m_err_for_order = removeEmpty(
+      target = m_err_for_order * t(m_active_flag == 0), margin = "rows");
+    for(i in 1 : nrow(m_err_for_order)) {
+      print("k:" + as.scalar(m_err_for_order[i, 2]) +
+        ", err:" + as.scalar(m_err_for_order[i, 1]));
+    }
+    m_err_order = order(target = m_err_for_order, by = 1, decreasing = FALSE, index.return = FALSE);
+    k = as.integer(as.scalar(m_err_order[1, 2]));
+    print("Get minimum LOOCV error, which value is " + k);
+  }
+}
+
+# getSelectedFeature:
+#   feature selection algorithm.
+#   Refer to ADD part "7.1 FS+BRACE"
+# Argument:
+# in_m_data                     input matrix as features
+# in_m_data_target              input matrix as target value
+# in_i_is_categorical           1 = category , 0 = continuous
+# in_m_init_selected            S.user, initially selected features specified by the user
+# in_i_max_select               J, max feature selected
+# k_value                       k
+# interval                      block size for BRACE algorithm
+#
+# Return:
+# out_m_selected_feature        output matrix for feature selection
+# out_d_min_LOOCV               output err
+
+getSelectedFeature = function (
+   matrix[double] in_m_data,
+   matrix[double] in_m_data_target,
+   integer in_i_is_categorical, # 1 = category , 0 = continuous
+   matrix[double] in_m_init_selected,
+   integer in_i_max_select,
+   integer k_value,
+   integer interval )
+return(
+   matrix[double] out_m_selected_feature,
+   double out_d_min_LOOCV
+    )
+{
+  i_n_record  = nrow( in_m_data );
+  i_n_column = ncol( in_m_data );
+  m_main_selected_flag = in_m_init_selected;
+  b_no_feature_selected = TRUE;
+  if( sum( in_m_init_selected ) >= 1 )
+    b_no_feature_selected = FALSE;
+
+  d_max_err_value = ( max( in_m_data_target ) - min( in_m_data_target ) ) * 100;
+  b_continue_main_loop = TRUE;    #level 1 while loop flag
+  d_min_LOOCV = Inf;
+  while( b_continue_main_loop ){
+    m_feature_selected_flag = m_main_selected_flag;
+    m_this_model_selected_flag = TRUE;
+    i_index_min_LOOCV = -1; # flag for which model wins in the BRACE algorithm
+    b_selected_morethan_one = FALSE;
+    b_continue_loop = TRUE; #level 2 while loop flag
+    i_iter = 1;
+    m_iter_err_sum = matrix( 0,1,i_n_column+1 );
+    m_iter_err_sum_squared = matrix( 0,1,i_n_column+1 );
+    while( b_continue_loop ){
+      i_process_item = i_iter*interval;
+      if(  i_process_item >= i_n_record ){
+        i_process_item = i_n_record;
+        b_continue_loop = FALSE;
+      }
+      i_process_begin_item = (i_iter - 1)*interval + 1
+      i_process_end_item = i_process_item
+      m_err = matrix( d_max_err_value,i_process_end_item - i_process_begin_item + 1,i_n_column+1 );
+      if( b_no_feature_selected == TRUE ){
+        parfor( i in 1:i_n_column ,check=0){
+          if( as.scalar( m_feature_selected_flag[1,i] ) != 1 ){
+            m_tmp_process_data = in_m_data[,i];
+            m_neighbor_value = getKNeighbor(m_tmp_process_data,
+              m_tmp_process_data[i_process_begin_item:i_process_end_item,], in_m_data_target,k_value );
+            m_tmp_err = getErr(m_neighbor_value,
+              in_m_data_target[i_process_begin_item:i_process_end_item,], in_i_is_categorical );
+            m_err[,i] = m_tmp_err;
+          }
+        }
+      }else{
+        #Use m_main_selected_flag but not m_feature_selected_flag,
+        # m_main_selected_flag: which features are initially selected
+        # m_feature_selected_flag: which features are dropped & initially selected
+        m_tmp_data = removeEmpty( target=in_m_data ,margin="cols", select = m_main_selected_flag);
+        if( m_this_model_selected_flag == TRUE ){
+          m_neighbor_value = getKNeighbor(
+            m_tmp_data,m_tmp_data[i_process_begin_item:i_process_end_item,],in_m_data_target, k_value );
+          m_tmp_err = getErr( m_neighbor_value,in_m_data_target[i_process_begin_item:i_process_end_item,],in_i_is_categorical );
+          m_err[,i_n_column+1] = m_tmp_err;
+        }
+        parfor( i in 1:i_n_column ,check=0){
+          if( as.scalar( m_feature_selected_flag[1,i] ) != 1 ){
+            m_tmp_process_data = cbind( m_tmp_data,in_m_data[,i] );
+            m_neighbor_value = getKNeighbor(
+              m_tmp_process_data,m_tmp_process_data[i_process_begin_item:i_process_end_item,],in_m_data_target,k_value );
+            m_tmp_err = getErr(
+              m_neighbor_value,in_m_data_target[i_process_begin_item:i_process_end_item,],in_i_is_categorical );
+            m_err[,i] = m_tmp_err;
+          }
+        }
+      }
+      if( m_this_model_selected_flag == TRUE )
+        m_active_flag_tmp = cbind( m_feature_selected_flag,matrix( 0,1,1 ) );
+      else
+        m_active_flag_tmp = cbind( m_feature_selected_flag,matrix( 1,1,1 ) );
+      s_rows_number = i_process_item
+      m_iter_err_sum = colSums(m_err) + m_iter_err_sum
+      m_iter_err_sum_squared = colSums(m_err ^ 2) + m_iter_err_sum_squared
+      m_err_mean = - outer(t(m_iter_err_sum), m_iter_err_sum , "-") / s_rows_number
+      m_err_vars = ( m_err_mean ^2 * s_rows_number -
+        2 * m_err_mean * m_iter_err_sum  + m_iter_err_sum_squared) / (s_rows_number-1)
+      m_err_vars = sqrt(m_err_vars)
+      parfor( i in 1:ncol( m_err ) ){
+        parfor( j in 1:ncol( m_err ) ,check=0){
+          b_execute_block = TRUE;
+          if( j==i ) b_execute_block = FALSE;
+          if( as.scalar( m_active_flag_tmp[1,i] ) == 1 )  b_execute_block = FALSE;
+          #i has been dropped, ignore this case
+          if( as.scalar( m_active_flag_tmp[1,j] ) == 1 )  b_execute_block = FALSE;
+          #j has been dropped, ignore this case
+          if( b_execute_block ){
+            b_flag = eliminateModel( as.scalar(m_err_mean[i,j]),as.scalar(m_err_vars[i,j]),s_rows_number);
+            if(  b_flag == TRUE )
+              m_active_flag_tmp[1,i] = 1;
+          }
+        }
+      }
+      #We set the bit to 1 for features selected before the current loop,
+      #and also set the bit to 1 for features dropped in the current loop
+      if( sum( m_active_flag_tmp != 1 ) > 1 )
+        b_selected_morethan_one = TRUE;
+      m_col_sums_err = m_iter_err_sum #colSums( m_err );
+      i_index_min_LOOCV = as.scalar( rowIndexMin( m_col_sums_err ) );
+      d_min_LOOCV = as.scalar( m_col_sums_err[1,i_index_min_LOOCV] );
+      i_index_min_LOOCV = i_index_min_LOOCV%% ( i_n_column+1 )
+      if( sum( m_active_flag_tmp != 1 ) <= 1 )
+        b_continue_loop = FALSE;
+      if( as.scalar( m_active_flag_tmp[1,i_n_column+1] ) == 1 )
+        m_this_model_selected_flag = FALSE;           
+      m_feature_selected_flag = m_active_flag_tmp[,1:i_n_column];
+      i_iter = i_iter + 1;
+    }
+    #select current model, jump out.
+    if( i_index_min_LOOCV == 0 ){
+       b_continue_main_loop = FALSE;
+      print( "Select Current model" );
+    }else{
+      print( "select feature " + i_index_min_LOOCV + ", change bit value to 1" );
+      m_main_selected_flag[1,i_index_min_LOOCV] = 1;
+      b_no_feature_selected = FALSE;
+    }
+    if( sum( m_main_selected_flag - in_m_init_selected ) >= in_i_max_select ){
+      #selected more than in_i_max_select features
+      b_continue_main_loop = FALSE;
+    }
+    if( sum( m_main_selected_flag ) == i_n_column ){
+      #all selected
+      b_continue_main_loop = FALSE;
+    }
+  }
+  out_m_selected_feature = m_main_selected_flag;
+  out_d_min_LOOCV = d_min_LOOCV;
+}
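
A note on the search kernel in naiveKNNsearch above: it scores candidates with
D = -2 * P %*% t(Q) + rowSums(P^2), i.e. the squared Euclidean distance
||p - q||^2 = ||p||^2 - 2*p.q + ||q||^2 with the per-query constant ||q||^2
dropped, which does not change the neighbor ranking within a query column.
A small self-contained DML sketch with toy data (hypothetical variable names,
not part of this commit) that checks the two orderings agree:

    P = matrix("1 2 3 4 5 6", rows=3, cols=2);   # three reference points
    Q = matrix("2 2", rows=1, cols=2);           # one query point
    Dfull = rowSums(P^2) - 2 * (P %*% t(Q)) + sum(Q^2);  # full squared distances
    Dpart = rowSums(P^2) - 2 * (P %*% t(Q));             # constant ||q||^2 dropped
    i_full = as.scalar(rowIndexMin(t(Dfull)));
    i_part = as.scalar(rowIndexMin(t(Dpart)));
    print("nearest by full distance: " + i_full + ", by partial distance: " + i_part);
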
diff --git a/scripts/builtin/knnbf.dml b/scripts/builtin/knnbf.dml
new file mode 100644
index 0000000..1146680
--- /dev/null
+++ b/scripts/builtin/knnbf.dml
@@ -0,0 +1,58 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+m_knnbf = function(
+    Matrix[Double] X,
+    Matrix[Double] T,
+    Integer k_value = 5
+  ) return(
+    Matrix[Double] NNR
+  )
+{
+  num_records = nrow(X);
+  num_queries = nrow(T);
+
+  D = matrix(0, rows = num_records, cols = num_queries);
+  NNR = matrix(0, rows = num_queries, cols = k_value);
+
+  parfor(i in 1 : num_queries) {
+    D[ , i] = calculateDistance(X, T[i, ]);
+    NNR[i, ] = sortAndGetK(D[ , i], k_value);
+  }
+}
+
+calculateDistance = function(Matrix[Double] R, Matrix[Double] Q)
+  return(Matrix[Double] distances)
+{
+  NR = rowSums(R ^ 2) %*% matrix(1,1,nrow(Q));
+  NQ = matrix(1,nrow(R),1) %*% t(rowSums(Q ^ 2));
+  distances = NR + NQ - 2.0 * R %*% t(Q);
+}
+
+sortAndGetK = function(Matrix[Double] D, Integer k)
+  return (Matrix[Double] knn_)
+{
+  if(nrow(D) < k)
+    stop("cannot pick "+k+" nearest neighbours from "+nrow(D)+" total instances")
+
+  sort_dist = order(target = D, by = 1, decreasing= FALSE, index.return = TRUE)
+  knn_ = t(sort_dist[1:k,])
+}
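
For reference, calculateDistance above materializes the full matrix of squared
Euclidean distances via the expansion ||r - q||^2 = ||r||^2 + ||q||^2 - 2*r.q.
A minimal DML sketch with toy data (hypothetical names, not part of this commit)
that checks one entry against a direct computation:

    R = matrix("0 0 3 4", rows=2, cols=2);
    Q = matrix("1 1", rows=1, cols=2);
    D = rowSums(R^2) %*% matrix(1,1,nrow(Q)) + matrix(1,nrow(R),1) %*% t(rowSums(Q^2)) - 2 * R %*% t(Q);
    d_direct = sum((R[2,] - Q[1,])^2);   # squared distance between R[2,] and Q[1,]
    print("expanded: " + as.scalar(D[2,1]) + ", direct: " + d_direct);  # both print 13
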
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index ad9f141..9080136 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -33,7 +33,7 @@ import org.apache.sysds.common.Types.ReturnType;
  * builtin functions.
  *
  * To add a new builtin script function, simply add the definition here
- * as well as a dml file in scripts/builtin with a matching name. On 
+ * as well as a dml file in scripts/builtin with a matching name. On
  * building SystemDS, these scripts are packaged into the jar as well.
  */
 public enum Builtins {
@@ -136,6 +136,8 @@ public enum Builtins {
        ISINF("is.infinite", false),
        KMEANS("kmeans", true),
        KMEANSPREDICT("kmeansPredict", true),
+       KNNBF("knnbf", true),
+       KNN("knn", true),
        L2SVM("l2svm", true),
        LASSO("lasso", true),
        LENGTH("length", false),
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index 0e4883c..0d541e7 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -91,7 +91,7 @@ import org.junit.Assert;
  * <li>clean up</li>
  * </ul>
  */
-public class TestUtils 
+public class TestUtils
 {
 
        private static final Log LOG = 
LogFactory.getLog(TestUtils.class.getName());
@@ -112,14 +112,14 @@ public class TestUtils
                try {
                        String lineExpected = null;
                        String lineActual = null;
-                       
+
                        Path compareFile = new Path(expectedFile);
                        FileSystem fs = 
IOUtilFunctions.getFileSystem(compareFile, conf);
                        FSDataInputStream fsin = fs.open(compareFile);
                        try( BufferedReader compareIn = new BufferedReader(new 
InputStreamReader(fsin)) ) {
                                lineExpected = compareIn.readLine();
                        }
-                       
+
                        Path outFile = new Path(actualFile);
                        FSDataInputStream fsout = fs.open(outFile);
                        try( BufferedReader outIn = new BufferedReader(new 
InputStreamReader(fsout)) ) {
@@ -132,7 +132,7 @@ public class TestUtils
                        fail("unable to read file: " + e.getMessage());
                }
        }
-       
+
        /**
         * Compares contents of an expected file with the actual file, where 
rows may be permuted
         * @param expectedFile
@@ -144,7 +144,7 @@ public class TestUtils
        {
                try {
                        HashMap<CellIndex, Double> expectedValues = new 
HashMap<>();
-                       
+
                        Path outDirectory = new Path(actualDir);
                        Path compareFile = new Path(expectedFile);
                        FileSystem fs = 
IOUtilFunctions.getFileSystem(outDirectory, conf);
@@ -166,35 +166,35 @@ public class TestUtils
                                if(expectedValue != 0.0)
                                        e_list.add(expectedValue);
                        }
-                       
+
                        ArrayList<Double> a_list = new ArrayList<>();
                        for (CellIndex index : actualValues.keySet()) {
                                Double actualValue = actualValues.get(index);
                                if(actualValue != 0.0)
                                        a_list.add(actualValue);
                        }
-                       
+
                        Collections.sort(e_list);
                        Collections.sort(a_list);
-                       
+
                        assertTrue("Matrix nzs not equal", e_list.size() == 
a_list.size());
                        for(int i=0; i < e_list.size(); i++)
                        {
                                assertTrue("Matrix values not equals", 
Math.abs(e_list.get(i) - a_list.get(i)) <= epsilon);
                        }
-                       
+
                } catch (IOException e) {
                        fail("unable to read file: " + e.getMessage());
                }
        }
-       
+
        /**
         * <p>
         * Compares the expected values calculated in Java by testcase and 
which are
         * in the normal filesystem, with those calculated by SystemDS located 
in
         * HDFS with Matrix Market format
         * </p>
-        * 
+        *
         * @param expectedFile
         *            file with expected values, which is located in OS 
filesystem
         * @param actualDir
@@ -209,33 +209,33 @@ public class TestUtils
                        Path compareFile = new Path(expectedFile);
                        FileSystem fs = 
IOUtilFunctions.getFileSystem(outDirectory, conf);
                        FSDataInputStream fsin = fs.open(compareFile);
-                       
+
                        HashMap<CellIndex, Double> expectedValues = new 
HashMap<>();
                        String[] expRcn = null;
-                       
+
                        try(BufferedReader compareIn = new BufferedReader(new 
InputStreamReader(fsin)) ) {
                                // skip the header of Matrix Market file
                                String line = compareIn.readLine();
-                               
+
                                // rows, cols and nnz
                                line = compareIn.readLine();
                                expRcn = line.split(" ");
-                               
+
                                readValuesFromFileStreamAndPut(compareIn, 
expectedValues);
                        }
-                       
+
                        HashMap<CellIndex, Double> actualValues = new 
HashMap<>();
 
                        FSDataInputStream fsout = fs.open(outDirectory);
                        try( BufferedReader outIn = new BufferedReader(new 
InputStreamReader(fsout)) ) {
-                                       
+
                                //skip MM header
                                String line = outIn.readLine();
-                               
+
                                //rows, cols and nnz
                                line = outIn.readLine();
                                String[] rcn = line.split(" ");
-                               
+
                                if (Integer.parseInt(expRcn[0]) != 
Integer.parseInt(rcn[0])) {
                                        LOG.warn(" Rows mismatch: expected " + 
Integer.parseInt(expRcn[0]) + ", actual " + Integer.parseInt(rcn[0]));
                                }
@@ -273,12 +273,12 @@ public class TestUtils
                        fail("unable to read file: " + e.getMessage());
                }
        }
-       
+
        /**
-        * Read doubles from the input stream and put them into the given 
hashmap of values. 
+        * Read doubles from the input stream and put them into the given 
hashmap of values.
         * @param inputStream input stream of doubles with related indices
         * @param values hashmap of values (initially empty)
-        * @throws IOException 
+        * @throws IOException
         */
        public static void readValuesFromFileStream(FSDataInputStream 
inputStream, HashMap<CellIndex, Double> values)
                throws IOException
@@ -293,7 +293,7 @@ public class TestUtils
         * @param inReader BufferedReader to read values from
         * @param values hashmap where values are put
        */
-       public static void readValuesFromFileStreamAndPut(BufferedReader 
inReader, HashMap<CellIndex, Double> values) 
+       public static void readValuesFromFileStreamAndPut(BufferedReader 
inReader, HashMap<CellIndex, Double> values)
                throws IOException
        {
                String line = null;
@@ -359,14 +359,14 @@ public class TestUtils
                        fail("unable to read file: " + e.getMessage());
                }
        }
-       
+
        /**
         * <p>
         * Compares the expected values calculated in Java by testcase and 
which are
         * in the normal filesystem, with those calculated by SystemDS located 
in
         * HDFS
         * </p>
-        * 
+        *
         * @param expectedFile
         *            file with expected values, which is located in OS 
filesystem
         * @param actualDir
@@ -402,7 +402,7 @@ public class TestUtils
                }
                assertEquals("for file " + actualDir + " " + countErrors + " 
values are not equal", 0, countErrors);
        }
-       
+
        /**
         * <p>
         * Compares the expected values calculated in Java by testcase and 
which are
@@ -440,7 +440,7 @@ public class TestUtils
                }
                assertEquals("for file " + actualDir + " " + countErrors + " 
values are not equal", 0, countErrors);
        }
-       
+
        public static void compareTensorBlocks(TensorBlock tb1, TensorBlock 
tb2) {
                Assert.assertEquals(tb1.getValueType(), tb2.getValueType());
                Assert.assertArrayEquals(tb1.getSchema(), tb2.getSchema());
@@ -450,12 +450,12 @@ public class TestUtils
                        for (int j = 0; j < tb1.getNumColumns(); j++)
                                Assert.assertEquals(tb1.get(new int[]{i, j}), 
tb2.get(new int[]{i, j}));
        }
-       
+
        public static TensorBlock createBasicTensor(ValueType vt, int rows, int 
cols, double sparsity) {
                return DataConverter.convertToTensorBlock(TestUtils.round(
                        MatrixBlock.randOperations(rows, cols, sparsity, 0, 10, 
"uniform", 7)), vt, true);
        }
-       
+
        public static TensorBlock createDataTensor(ValueType vt, int rows, int 
cols, double sparsity) {
                return DataConverter.convertToTensorBlock(TestUtils.round(
                        MatrixBlock.randOperations(rows, cols, sparsity, 0, 10, 
"uniform", 7)), vt, false);
@@ -470,11 +470,11 @@ public class TestUtils
         * @param filePath Path to the file to be read.
         * @return Matrix values in a hashmap <index,value>
         */
-       public static HashMap<CellIndex, Double> readDMLMatrixFromHDFS(String 
filePath) 
+       public static HashMap<CellIndex, Double> readDMLMatrixFromHDFS(String 
filePath)
        {
                HashMap<CellIndex, Double> expectedValues = new HashMap<>();
-               
-               try 
+
+               try
                {
                        Path outDirectory = new Path(filePath);
                        FileSystem fs = 
IOUtilFunctions.getFileSystem(outDirectory, conf);
@@ -484,7 +484,7 @@ public class TestUtils
                                FSDataInputStream outIn = 
fs.open(file.getPath());
                                readValuesFromFileStream(outIn, expectedValues);
                        }
-               } 
+               }
                catch (IOException e) {
                        assertTrue("could not read from file " + filePath+": 
"+e.getMessage(), false);
                }
@@ -501,26 +501,26 @@ public class TestUtils
         * @param filePath Path to the file to be read.
         * @return Matrix values in a hashmap <index,value>
         */
-       public static HashMap<CellIndex, Double> readRMatrixFromFS(String 
filePath) 
+       public static HashMap<CellIndex, Double> readRMatrixFromFS(String 
filePath)
        {
                HashMap<CellIndex, Double> expectedValues = new HashMap<>();
-               
-               try(BufferedReader reader = new BufferedReader(new 
FileReader(filePath))) 
+
+               try(BufferedReader reader = new BufferedReader(new 
FileReader(filePath)))
                {
                        // skip both R header lines
                        String line = reader.readLine();
-                       
+
                        int matrixType = -1;
                        if ( line.endsWith(" general") )
                                matrixType = 1;
                        if ( line.endsWith(" symmetric") )
                                matrixType = 2;
-                       
+
                        if ( matrixType == -1 )
                                throw new RuntimeException("unknown matrix type 
while reading R matrix: " + line);
-                       
+
                        line = reader.readLine(); // header line with dimension 
and nnz information
-                       
+
                        while ((line = reader.readLine()) != null) {
                                StringTokenizer st = new StringTokenizer(line, 
" ");
                                int i = Integer.parseInt(st.nextToken());
@@ -538,14 +538,14 @@ public class TestUtils
                                                expectedValues.put(new 
CellIndex(j, i), 1.0);
                                }
                        }
-               } 
+               }
                catch (IOException e) {
                        assertTrue("could not read from file " + filePath, 
false);
                }
-               
+
                return expectedValues;
        }
-       
+
        /**
         * Reads a scalar value in DML format from HDFS
         */
@@ -598,7 +598,7 @@ public class TestUtils
                }
                return _AssertOccured;
        }
-       
+
        public static String readDMLString(String filePath) {
                try {
                        StringBuilder sb =  new StringBuilder();
@@ -617,8 +617,8 @@ public class TestUtils
                }
                return null;
        }
-               
-       
+
+
        /**
         * Reads a scalar value in R format from OS's FS
         */
@@ -627,7 +627,7 @@ public class TestUtils
                expectedValues.put(new CellIndex(1,1), readRScalar(filePath));
                return expectedValues;
        }
-       
+
        public static Double readRScalar(String filePath) {
                try {
                        double d = Double.NaN;
@@ -643,12 +643,12 @@ public class TestUtils
                }
                return Double.NaN;
        }
-       
+
        public static String processMultiPartCSVForR(String csvFile) throws 
IOException {
                File csv = new File(csvFile);
                if (csv.isDirectory()) {
                        File[] parts = csv.listFiles();
-                       
+
                        int count=0;
                        int index = -1;
                        for(int i=0; i < parts.length; i++ ) {
@@ -659,7 +659,7 @@ public class TestUtils
                                count++;
                                index = i;
                        }
-                       
+
                        if ( count == 1) {
                                csvFile = parts[index].toString();
                        }
@@ -686,7 +686,7 @@ public class TestUtils
                                                out.append(fileContents);
                                        }
                                }
-                               
+
                                csvFile = tmp.getCanonicalPath();
                        }
                        else {
@@ -699,7 +699,7 @@ public class TestUtils
        /**
         * Compares two double values regarding tolerance t. If one or both of 
them
         * is null it is converted to 0.0.
-        * 
+        *
         * @param v1
         * @param v2
         * @param t Tolerance
@@ -722,13 +722,13 @@ public class TestUtils
 
                return Math.abs(v1 - v2) <= t;
        }
-       
+
        public static void compareMatrices(double[] expectedMatrix, double[] 
actualMatrix, double epsilon) {
-               compareMatrices(new double[][]{expectedMatrix}, 
+               compareMatrices(new double[][]{expectedMatrix},
                        new double[][]{actualMatrix}, 1, expectedMatrix.length, 
epsilon);
        }
-       
-       
+
+
        public static void compareMatrices(double[][] expectedMatrix, 
double[][] actualMatrix, int rows, int cols,
                double epsilon) {
                compareMatrices(expectedMatrix, actualMatrix, 
expectedMatrix.length, expectedMatrix[0].length, epsilon, "");
@@ -760,7 +760,7 @@ public class TestUtils
                assertEqualColsAndRows(expectedMatrix,actualMatrix);
                compareMatrices(expectedMatrix, actualMatrix, 
expectedMatrix.length, expectedMatrix[0].length, epsilon, message);
        }
-       
+
        public static void compareFrames(String[][] expectedFrame, String[][] 
actualFrame, int rows, int cols ) {
                int countErrors = 0;
                for (int i = 0; i < rows; i++) {
@@ -774,9 +774,9 @@ public class TestUtils
                }
                assertTrue("" + countErrors + " values are not in equal", 
countErrors == 0);
        }
-       
+
        public static void compareScalars(double d1, double d2, double tol) {
-               assertTrue("Given scalars do not match: " + d1 + " != " + d2 , 
compareCellValue(d1, d2, tol, false));   
+               assertTrue("Given scalars do not match: " + d1 + " != " + d2 , 
compareCellValue(d1, d2, tol, false));
        }
 
        public static void compareMatricesBit(double[][] expectedMatrix, 
double[][] actualMatrix, int rows, int cols,
@@ -796,7 +796,7 @@ public class TestUtils
        public static void compareMatricesBitAvgDistance(double[][] 
expectedMatrix, double[][] actualMatrix,
                        long maxUnitsOfLeastPrecision, long maxAvgDistance, 
String message){
                assertEqualColsAndRows(expectedMatrix,actualMatrix);
-               compareMatricesBitAvgDistance(expectedMatrix, actualMatrix, 
expectedMatrix.length, actualMatrix[0].length, 
+               compareMatricesBitAvgDistance(expectedMatrix, actualMatrix, 
expectedMatrix.length, actualMatrix[0].length,
                        maxUnitsOfLeastPrecision, maxAvgDistance, message);
        }
 
@@ -853,24 +853,24 @@ public class TestUtils
        }
 
        private static void assertEqualColsAndRows(double[][] expectedMatrix, 
double[][] actualMatrix){
-               assertTrue("The number of columns in the matrixes should be 
equal :" 
+               assertTrue("The number of columns in the matrixes should be 
equal :"
                        + expectedMatrix.length  + "  "
-                       + actualMatrix.length, 
+                       + actualMatrix.length,
                        expectedMatrix.length == actualMatrix.length);
-               assertTrue("The number of rows in the matrixes should be equal" 
-                       + expectedMatrix[0].length  + "  " 
-                       + actualMatrix[0].length, 
+               assertTrue("The number of rows in the matrixes should be equal"
+                       + expectedMatrix[0].length  + "  "
+                       + actualMatrix[0].length,
                        expectedMatrix[0].length == actualMatrix[0].length);
        }
 
-       public static void compareMatricesPercentageDistance(double[][] 
expectedMatrix, double[][] actualMatrix,  
+       public static void compareMatricesPercentageDistance(double[][] 
expectedMatrix, double[][] actualMatrix,
                        double percentDistanceAllowed, double 
maxAveragePercentDistance,  String message){
                assertEqualColsAndRows(expectedMatrix,actualMatrix);
                compareMatricesPercentageDistance(expectedMatrix, actualMatrix, 
expectedMatrix.length, expectedMatrix[0].length,
                        percentDistanceAllowed, maxAveragePercentDistance, 
message, false);
        }
 
-       public static void compareMatricesPercentageDistance(double[][] 
expectedMatrix, double[][] actualMatrix,  
+       public static void compareMatricesPercentageDistance(double[][] 
expectedMatrix, double[][] actualMatrix,
                        double percentDistanceAllowed, double 
maxAveragePercentDistance,  String message, boolean ignoreZero){
                assertEqualColsAndRows(expectedMatrix,actualMatrix);
                compareMatricesPercentageDistance(expectedMatrix, actualMatrix, 
expectedMatrix.length, expectedMatrix[0].length,
@@ -907,6 +907,29 @@ public class TestUtils
                        }
        }
 
+       public static void compareMatricesAvgRowDistance(double[][] expectedMatrix, double[][] actualMatrix, int rows,
+               int cols, double averageDistanceAllowed){
+                       String message = "";
+                       int countErrors = 0;
+
+                       for (int i = 0; i < rows && countErrors < 20; i++) {
+                               double distanceSum = 0;
+                               for (int j = 0; j < cols && countErrors < 20; j++) {
+                                       distanceSum += expectedMatrix[i][j] - actualMatrix[i][j];
+                               }
+                               if(distanceSum / cols > averageDistanceAllowed){
+                                       message += ("Average distance for row " + i + ":" + (distanceSum / cols) + "\n");
+                                       countErrors++;
+                               }
+                       }
+                       if(countErrors == 20){
+                               assertTrue(message + "\n At least 20 values are not in equal", countErrors == 0);
+                       }
+                       else{
+                               assertTrue(message + "\n" + countErrors + " values are not in equal of total: " + (rows), countErrors == 0);
+                       }
+       }
+
        public static void compareMatricesBitAvgDistance(double[][] 
expectedMatrix, double[][] actualMatrix, int rows,
                int cols, long maxUnitsOfLeastPrecision, long maxAvgDistance) {
                        compareMatricesBitAvgDistance(expectedMatrix, 
actualMatrix, rows, cols, maxUnitsOfLeastPrecision, maxAvgDistance, "");
@@ -914,24 +937,24 @@ public class TestUtils
 
        /**
         * Compare two double precision floats for equality within a margin of 
error.
-        *  
+        *
         * This can be used to compensate for inequality caused by accumulated
         * floating point math errors.
-        * 
+        *
         * The error margin is specified in ULPs (units of least precision).
         * A one-ULP difference means there are no representable floats in 
between.
         * E.g. 0f and 1.4e-45f are one ULP apart. So are -6.1340704f and 
-6.13407f.
         * Depending on the number of calculations involved, typically a margin 
of
         * 1-5 ULPs should be enough.
-        * 
+        *
         * @param d1 The expected value.
         * @param d2 The actual value.
         * @return Whether distance in bits
         */
        public static long compareScalarBits(double d1, double d2) {
-               
+
                // assertTrue("Both values should be positive or negative",(d1 
>= 0 && d2 >= 0) || (d2 <= 0 && d1 <= 0));
-               
+
                long expectedBits = Double.doubleToLongBits(d1) < 0 ? 
0x8000000000000000L - Double.doubleToLongBits(d1) : Double.doubleToLongBits(d1);
                long actualBits = Double.doubleToLongBits(d2) < 0 ? 
0x8000000000000000L - Double.doubleToLongBits(d2) : Double.doubleToLongBits(d2);
                long difference = expectedBits > actualBits ? expectedBits - 
actualBits : actualBits - expectedBits;
@@ -954,29 +977,29 @@ public class TestUtils
                long distance = compareScalarBits(d1,d2);
                assertTrue("Given scalars do not match: " + d1 + " != " + d2 + 
" with bitDistance: " + distance ,distance <= maxUnitsOfLeastPrecision);
        }
-       
+
        public static void compareScalars(String expected, String actual) {
                        assertEquals(expected, actual);
        }
 
        public static boolean compareMatrices(HashMap<CellIndex, Double> m1, 
HashMap<CellIndex, Double> m2,
-                       double tolerance, String name1, String name2) 
+                       double tolerance, String name1, String name2)
        {
                return compareMatrices(m1, m2, tolerance, name1, name2, false);
        }
-       
+
        public static void compareMatrices(HashMap<CellIndex, Double> m1, 
MatrixBlock m2, double tolerance) {
                double[][] ret1 = convertHashMapToDoubleArray(m1);
                double[][] ret2 = DataConverter.convertToDoubleMatrix(m2);
                compareMatrices(ret1, ret2, m2.getNumRows(), 
m2.getNumColumns(), tolerance);
        }
-       
+
        public static void compareMatrices(MatrixBlock m1, MatrixBlock m2, 
double tolerance) {
                double[][] ret1 = DataConverter.convertToDoubleMatrix(m1);
                double[][] ret2 = DataConverter.convertToDoubleMatrix(m2);
                compareMatrices(ret1, ret2, m2.getNumRows(), 
m2.getNumColumns(), tolerance);
        }
-       
+
        /**
         * Compares two matrices given as HashMaps. The matrix containing more 
nnz
         * is iterated and each cell value compared against the corresponding 
cell
@@ -984,7 +1007,7 @@ public class TestUtils
         * This method does not assert. Instead statistics are added to
         * AssertionBuffer, at the end of the test you should call
         * {@link TestUtils#displayAssertionBuffer()}.
-        * 
+        *
         * @param m1
         * @param m2
         * @param tolerance
@@ -997,7 +1020,7 @@ public class TestUtils
                String namefirst = name2;
                String namesecond = name1;
                boolean flag = true;
-               
+
                // to ensure that always the matrix with more nnz is iterated
                if (m1.size() > m2.size()) {
                        first = m1;
@@ -1024,7 +1047,7 @@ public class TestUtils
                                        countErrorWithinTolerance++;
                                        if(!flag)
                                                
System.out.println(e.getKey()+": "+v1+" <--> "+v2);
-                                       else 
+                                       else
                                                
System.out.println(e.getKey()+": "+v2+" <--> "+v1);
                                }
                        } else {
@@ -1049,35 +1072,35 @@ public class TestUtils
                _AssertOccured = true;
                return false;
        }
-       
-       
+
+
        /**
-        * 
+        *
         * @param vt
         * @param in1
         * @param in2
         * @param tolerance
-        * 
+        *
         * @return
         */
        public static int compareTo(ValueType vt, Object in1, Object in2, 
double tolerance) {
                if(in1 == null && in2 == null) return 0;
                else if(in1 == null) return -1;
                else if(in2 == null) return 1;
- 
+
                switch( vt ) {
                        case STRING:  return 
((String)in1).compareTo((String)in2);
                        case BOOLEAN: return 
((Boolean)in1).compareTo((Boolean)in2);
                        case INT64:     return ((Long)in1).compareTo((Long)in2);
-                       case FP64:  
+                       case FP64:
                                return (Math.abs((Double)in1-(Double)in2) < 
tolerance)?0:
                                        ((Double)in1).compareTo((Double)in2);
                        default: throw new RuntimeException("Unsupported value 
type: "+vt);
                }
        }
-       
+
        /**
-        * 
+        *
         * @param vt
         * @param in1
         * @param inR
@@ -1087,32 +1110,32 @@ public class TestUtils
                if(in1 == null && (inR == null || 
(inR.toString().compareTo("NA")==0))) return 0;
                else if(in1 == null && vt == ValueType.STRING) return -1;
                else if(inR == null) return 1;
- 
+
                switch( vt ) {
                        case STRING:  return 
((String)in1).compareTo((String)inR);
-                       case BOOLEAN: 
+                       case BOOLEAN:
                                if(in1 == null)
                                        return 
Boolean.FALSE.compareTo(((Boolean)inR).booleanValue());
                                else
                                        return 
((Boolean)in1).compareTo((Boolean)inR);
-                       case INT64:     
+                       case INT64:
                                if(in1 == null)
                                        return new 
Long(0).compareTo(((Long)inR));
                                else
                                        return ((Long)in1).compareTo((Long)inR);
-                       case FP64:  
+                       case FP64:
                                if(in1 == null)
                                        return (new 
Double(0)).compareTo((Double)inR);
                                else
-                                       return 
(Math.abs((Double)in1-(Double)inR) < tolerance)?0:       
+                                       return 
(Math.abs((Double)in1-(Double)inR) < tolerance)?0:
                                                
((Double)in1).compareTo((Double)inR);
                        default: throw new RuntimeException("Unsupported value 
type: "+vt);
                }
        }
-       
+
        /**
         * Converts a 2D array into a sparse hashmap matrix.
-        * 
+        *
         * @param matrix
         * @return
         */
@@ -1127,7 +1150,7 @@ public class TestUtils
 
                return hmMatrix;
        }
-       
+
        /**
         * Method to convert a hashmap of matrix entries into a double array
         * @param matrix
@@ -1147,37 +1170,37 @@ public class TestUtils
                                max_cols = ci.column;
                        }
                }
-               
+
                double [][] ret_arr = new double[max_rows][max_cols];
-               
+
                for(CellIndex ci:matrix.keySet())
                {
                        int i = ci.row-1;
                        int j = ci.column-1;
                        ret_arr[i][j] = matrix.get(ci);
                }
-               
+
                return ret_arr;
-               
+
        }
-       
+
        public static double[][] convertHashMapToDoubleArray(HashMap 
<CellIndex, Double> matrix, int rows, int cols)
        {
                double [][] ret_arr = new double[rows][cols];
-               
+
                for(CellIndex ci:matrix.keySet()) {
                        int i = ci.row-1;
                        int j = ci.column-1;
                        ret_arr[i][j] = matrix.get(ci);
                }
-               
+
                return ret_arr;
-               
+
        }
 
        /**
         * Converts a 2D double array into a 1D double array.
-        * 
+        *
         * @param array
         * @return
         */
@@ -1195,7 +1218,7 @@ public class TestUtils
 
        /**
         * Converts a 1D double array into a 2D double array.
-        * 
+        *
         * @param array
         * @return
         */
@@ -1228,7 +1251,7 @@ public class TestUtils
         * Compares a dml matrix file in HDFS with a file in normal file system
         * generated by R
         * </p>
-        * 
+        *
         * @param rFile
         *            file with values calculated by R
         * @param hdfsDir
@@ -1248,7 +1271,7 @@ public class TestUtils
                                compareIn.readLine();
                                readValuesFromFileStreamAndPut(compareIn, 
expectedValues);
                        }
-                       
+
                        FileStatus[] outFiles = fs.listStatus(outDirectory);
 
                        for (FileStatus file : outFiles) {
@@ -1282,7 +1305,7 @@ public class TestUtils
         * <p>
         * Checks a matrix against a number of specifications.
         * </p>
-        * 
+        *
         * @param data
         *            matrix data
         * @param mc
@@ -1312,7 +1335,7 @@ public class TestUtils
         * Checks a matrix read from a file in text format against a number of
         * specifications.
         * </p>
-        * 
+        *
         * @param outDir
         *            directory containing the matrix
         * @param rows
@@ -1329,7 +1352,7 @@ public class TestUtils
                        Path outDirectory = new Path(outDir);
                        FileSystem fs = 
IOUtilFunctions.getFileSystem(outDirectory, conf);
                        assertTrue(outDir + " does not exist", 
fs.exists(outDirectory));
-                       
+
                        if( fs.getFileStatus(outDirectory).isDirectory() )
                        {
                                FileStatus[] outFiles = 
fs.listStatus(outDirectory);
@@ -1374,7 +1397,7 @@ public class TestUtils
         * <p>
         * Checks for matrix in directory existence.
         * </p>
-        * 
+        *
         * @param outDir
         *            directory
         */
@@ -1401,7 +1424,7 @@ public class TestUtils
         * <p>
         * Removes all the directories specified in the array in HDFS
         * </p>
-        * 
+        *
         * @param directories
         *            directories array
         */
@@ -1422,7 +1445,7 @@ public class TestUtils
         * <p>
         * Removes all the directories specified in the array in OS filesystem
         * </p>
-        * 
+        *
         * @param directories
         *            directories array
         */
@@ -1451,7 +1474,7 @@ public class TestUtils
         * <p>
         * Removes all the files specified in the array in HDFS
         * </p>
-        * 
+        *
         * @param files
         *            files array
         */
@@ -1472,7 +1495,7 @@ public class TestUtils
         * <p>
         * Removes all the files specified in the array in OS filesystem
         * </p>
-        * 
+        *
         * @param files
         *            files array
         */
@@ -1490,7 +1513,7 @@ public class TestUtils
         * <p>
         * Clears a complete directory.
         * </p>
-        * 
+        *
         * @param directory
         *            directory
         */
@@ -1514,7 +1537,7 @@ public class TestUtils
         * <p>
         * Set seed to -1 to use the current time as seed.
         * </p>
-        * 
+        *
         * @param rows
         *            number of rows
         * @param cols
@@ -1547,7 +1570,7 @@ public class TestUtils
         * Generates a test matrix with the specified parameters as a two
         * dimensional array.
         * Set seed to -1 to use the current time as seed.
-        * 
+        *
         * @param rows number of rows
         * @param cols number of columns
         * @param min minimum value
@@ -1573,9 +1596,9 @@ public class TestUtils
        }
 
        /**
-        * 
+        *
         * Generates a test matrix, but only containing real numbers, in the 
range specified.
-        * 
+        *
         * @param rows number of rows
         * @param cols number of columns
         * @param min minimum value whole number
@@ -1616,7 +1639,7 @@ public class TestUtils
         * <p>
         * Set seed to -1 to use the current time as seed.
         * </p>
-        * 
+        *
         * @param rows
         *            number of rows
         * @param cols
@@ -1653,7 +1676,7 @@ public class TestUtils
         * <p>
         * Set seed to -1 to use the current time as seed.
         * </p>
-        * 
+        *
         * @param file
         *            output file
         * @param rows
@@ -1677,7 +1700,7 @@ public class TestUtils
                        DataOutputStream out = fs.create(inFile);
                        try( PrintWriter pw = new PrintWriter(out) ) {
                                Random random = (seed == -1) ? TestUtils.random 
: new Random(seed);
-                               
+
                                for (int i = 1; i <= rows; i++) {
                                        for (int j = 1; j <= cols; j++) {
                                                if (random.nextDouble() > 
sparsity)
@@ -1880,7 +1903,7 @@ public class TestUtils
 
        /**
         * Counts the number of NNZ values in a matrix
-        * 
+        *
         * @param matrix
         * @return
         */
@@ -1895,9 +1918,9 @@ public class TestUtils
                return n;
        }
 
-       public static void writeCSVTestMatrix(String file, double[][] matrix) 
+       public static void writeCSVTestMatrix(String file, double[][] matrix)
        {
-               try 
+               try
                {
                        //create outputstream to HDFS / FS and writer
                        Path path = new Path(file);
@@ -1912,7 +1935,7 @@ public class TestUtils
                                                sb.append(matrix[i][0]);
                                        for (int j = 1; j < matrix[i].length; 
j++) {
                                                sb.append(",");
-                                               if ( matrix[i][j] == 0 ) 
+                                               if ( matrix[i][j] == 0 )
                                                        continue;
                                                sb.append(matrix[i][j]);
                                        }
@@ -1920,8 +1943,8 @@ public class TestUtils
                                        pw.append(sb.toString());
                                }
                        }
-               } 
-               catch (IOException e) 
+               }
+               catch (IOException e)
                {
                        fail("unable to write (csv) test matrix (" + file + "): 
" + e.getMessage());
                }
@@ -1931,18 +1954,18 @@ public class TestUtils
         * <p>
         * Writes a matrix to a file using the text format.
         * </p>
-        * 
+        *
         * @param file
         *            file name
         * @param matrix
         *            matrix
         * @param isR
         *            when true, writes a R matrix to disk
-        * 
+        *
         */
-       public static void writeTestMatrix(String file, double[][] matrix, 
boolean isR) 
+       public static void writeTestMatrix(String file, double[][] matrix, 
boolean isR)
        {
-               try 
+               try
                {
                        //create outputstream to HDFS / FS and writer
                        DataOutputStream out = null;
@@ -1950,26 +1973,26 @@ public class TestUtils
                                Path path = new Path(file);
                                FileSystem fs = 
IOUtilFunctions.getFileSystem(path, conf);
                                out = fs.create(path, true);
-                       } 
+                       }
                        else {
                                out = new DataOutputStream(new 
FileOutputStream(file));
                        }
-                       
+
                        try( BufferedWriter pw = new BufferedWriter(new 
OutputStreamWriter(out))) {
-                               
+
                                //write header
                                if( isR ) {
                                        /** add R header */
                                        pw.append("%%MatrixMarket matrix 
coordinate real general\n");
                                        pw.append("" + matrix.length + " " + 
matrix[0].length + " " + matrix.length*matrix[0].length+"\n");
                                }
-                               
+
                                //write actual matrix
                                StringBuilder sb = new StringBuilder();
                                boolean emptyOutput = true;
                                for (int i = 0; i < matrix.length; i++) {
                                        for (int j = 0; j < matrix[i].length; 
j++) {
-                                               if ( matrix[i][j] == 0 ) 
+                                               if ( matrix[i][j] == 0 )
                                                        continue;
                                                sb.append(i + 1);
                                                sb.append(' ');
@@ -1982,13 +2005,13 @@ public class TestUtils
                                                emptyOutput = false;
                                        }
                                }
-                               
+
                                //write dummy entry if empty
                                if( emptyOutput )
                                        pw.append("1 1 " + matrix[0][0]);
                        }
-               } 
-               catch (IOException e) 
+               }
+               catch (IOException e)
                {
                        fail("unable to write test matrix (" + file + "): " + 
e.getMessage());
                }
@@ -1998,7 +2021,7 @@ public class TestUtils
         * <p>
         * Writes a matrix to a file using the text format.
         * </p>
-        * 
+        *
         * @param file
         *            file name
         * @param matrix
@@ -2008,18 +2031,18 @@ public class TestUtils
                writeTestMatrix(file, matrix, false);
        }
 
-       
+
        /**
         * <p>
         * Writes a frame to a file using the text format.
         * </p>
-        * 
+        *
         * @param file
         *            file name
         * @param data
         *            frame data
         * @param isR
-        * @throws IOException 
+        * @throws IOException
         */
        public static void writeTestFrame(String file, double[][] data, 
ValueType[] schema, FileFormat fmt, boolean isR) throws IOException {
                FrameWriter writer = FrameWriterFactory.createFrameWriter(fmt);
@@ -2027,17 +2050,17 @@ public class TestUtils
                initFrameData(frame, data, schema, data.length);
                writer.writeFrameToHDFS(frame, file, data.length, 
schema.length);
        }
-       
+
        /**
         * <p>
         * Writes a frame to a file using the text format.
         * </p>
-        * 
+        *
         * @param file
         *            file name
         * @param data
         *            frame data
-        * @throws IOException 
+        * @throws IOException
         */
        public static void writeTestFrame(String file, double[][] data, 
ValueType[] schema, FileFormat fmt) throws IOException {
                writeTestFrame(file, data, schema, fmt, false);
@@ -2047,7 +2070,7 @@ public class TestUtils
                Object[] row1 = new Object[lschema.length];
                for( int i=0; i<rows; i++ ) {
                        for( int j=0; j<lschema.length; j++ ) {
-                               data[i][j] = 
UtilFunctions.objectToDouble(lschema[j], 
+                               data[i][j] = 
UtilFunctions.objectToDouble(lschema[j],
                                                row1[j] = 
UtilFunctions.doubleToObject(lschema[j], data[i][j]));
                                if(row1[j] != null && lschema[j] == 
ValueType.STRING)
                                        row1[j] = "Str" + row1[j];
@@ -2056,7 +2079,7 @@ public class TestUtils
                }
        }
 
-       
+
        /* Write a scalar value to a file */
        public static void writeTestScalar(String file, double value) {
                try {
@@ -2079,12 +2102,12 @@ public class TestUtils
                        fail("unable to write test scalar (" + file + "): " + 
e.getMessage());
                }
        }
-       
+
        /**
         * <p>
         * Writes a matrix to a file using the binary cells format.
         * </p>
-        * 
+        *
         * @param file
         *            file name
         * @param matrix
@@ -2125,7 +2148,7 @@ public class TestUtils
         * <p>
         * Writes a matrix to a file using the binary blocks format.
         * </p>
-        * 
+        *
         * @param file
         *            file name
         * @param matrix
@@ -2140,7 +2163,7 @@ public class TestUtils
        public static void writeBinaryTestMatrixBlocks(String file, double[][] 
matrix, int rowsInBlock, int colsInBlock,
                        boolean sparseFormat) {
                SequenceFile.Writer writer = null;
-                       
+
                try {
                        Path path = new Path(file);
                        Writer.Option filePath = Writer.file(path);
@@ -2164,7 +2187,7 @@ public class TestUtils
                                        writer.append(index, value);
                                }
                        }
-               } 
+               }
                catch (IOException e) {
                        e.printStackTrace();
                        fail("unable to write test matrix: " + e.getMessage());
@@ -2178,7 +2201,7 @@ public class TestUtils
         * <p>
         * Prints out a DML script.
         * </p>
-        * 
+        *
         * @param dmlScriptFile
         *            filename of DML script
         */
@@ -2202,7 +2225,7 @@ public class TestUtils
         * <p>
         * Prints out a PYDML script.
         * </p>
-        * 
+        *
         * @param pydmlScriptFile
         *            filename of PYDML script
         */
@@ -2214,19 +2237,19 @@ public class TestUtils
                        while ((content = in.readLine()) != null) {
                                System.out.println(content);
                        }
-               } 
+               }
                catch (IOException e) {
                        e.printStackTrace();
                        fail("unable to print pydml script: " + e.getMessage());
                }
                
System.out.println("**************************************************\n\n");
        }
-       
+
        /**
         * <p>
         * Prints out an R script.
         * </p>
-        * 
+        *
         * @param dmlScriptFile
         *            filename of R script
         */
@@ -2250,7 +2273,7 @@ public class TestUtils
         * <p>
         * Renames a temporary DML script file back to its original name.
         * </p>
-        * 
+        *
         * @param dmlScriptFile
         *            temporary script file
         */
@@ -2288,7 +2311,7 @@ public class TestUtils
         * Checks if any temporary files or directories exist in the current 
working
         * directory.
         * </p>
-        * 
+        *
         * @return true if temporary files or directories are available
         */
        @SuppressWarnings("resource")
@@ -2316,7 +2339,7 @@ public class TestUtils
         * Returns the path to a file in a directory if it is the only file in 
the
         * directory.
         * </p>
-        * 
+        *
         * @param directory
         *            directory containing the file
         * @return path of the file
@@ -2342,7 +2365,7 @@ public class TestUtils
         * <p>
         * Creates an empty file.
         * </p>
-        * 
+        *
         * @param filename
         *            filename
         */
@@ -2356,7 +2379,7 @@ public class TestUtils
         * <p>
         * Performs transpose onto a matrix and returns the result.
         * </p>
-        * 
+        *
         * @param a
         *            matrix
         * @return transposed matrix
@@ -2379,7 +2402,7 @@ public class TestUtils
         * <p>
         * Performs matrix multiplication onto two matrices and returns the 
result.
         * </p>
-        * 
+        *
         * @param a
         *            left matrix
         * @param b
@@ -2409,7 +2432,7 @@ public class TestUtils
         * <p>
         * Returns a random integer value.
         * </p>
-        * 
+        *
         * @return random integer value
         */
        public static int getRandomInt() {
@@ -2422,7 +2445,7 @@ public class TestUtils
         * <p>
         * Returns a positive random integer value.
         * </p>
-        * 
+        *
         * @return positive random integer value
         */
        public static int getPositiveRandomInt() {
@@ -2436,7 +2459,7 @@ public class TestUtils
         * <p>
         * Returns a negative random integer value.
         * </p>
-        * 
+        *
         * @return negative random integer value
         */
        public static int getNegativeRandomInt() {
@@ -2450,7 +2473,7 @@ public class TestUtils
         * <p>
         * Returns a random double value.
         * </p>
-        * 
+        *
         * @return random double value
         */
        public static double getRandomDouble() {
@@ -2463,7 +2486,7 @@ public class TestUtils
         * <p>
         * Returns a positive random double value.
         * </p>
-        * 
+        *
         * @return positive random double value
         */
        public static double getPositiveRandomDouble() {
@@ -2477,7 +2500,7 @@ public class TestUtils
         * <p>
         * Returns a negative random double value.
         * </p>
-        * 
+        *
         * @return negative random double value
         */
        public static double getNegativeRandomDouble() {
@@ -2492,7 +2515,7 @@ public class TestUtils
         * Returns the string representation of a double value which can be 
used in
         * a DML script.
         * </p>
-        * 
+        *
         * @param value
         *            double value
         * @return string representation
@@ -2504,7 +2527,7 @@ public class TestUtils
                nf.setMaximumFractionDigits(20);
                return nf.format(value);
        }
-       
+
        public static void replaceRandom( double[][] A, int rows, int cols, 
double replacement, int len ) {
                Random rand = new Random();
                for( int i=0; i<len; i++ )
@@ -2523,7 +2546,7 @@ public class TestUtils
         * <p>
         * Generates a matrix containing easy to debug values in its cells.
         * </p>
-        * 
+        *
         * @param rows
         * @param cols
         * @param bContainsZeros
@@ -2543,45 +2566,45 @@ public class TestUtils
                }
                return matrix;
        }
-       
+
        public static double[][] round(double[][] data) {
                for(int i=0; i<data.length; i++)
                        for(int j=0; j<data[i].length; j++)
                                data[i][j]=Math.round(data[i][j]);
                return data;
        }
-       
+
        public static double[][] round(double[][] data, int col) {
                for(int i=0; i<data.length; i++)
                        data[i][col]=Math.round(data[i][col]);
                return data;
        }
-       
+
        public static MatrixBlock round(MatrixBlock data) {
                return DataConverter.convertToMatrixBlock(
                        round(DataConverter.convertToDoubleMatrix(data)));
        }
-       
+
        public static double[][] floor(double[][] data) {
                for(int i=0; i<data.length; i++)
                        for(int j=0; j<data[i].length; j++)
                                data[i][j]=Math.floor(data[i][j]);
                return data;
        }
-       
+
        public static double[][] ceil(double[][] data) {
                for(int i=0; i<data.length; i++)
                        for(int j=0; j<data[i].length; j++)
                                data[i][j]=Math.ceil(data[i][j]);
                return data;
        }
-       
+
        public static double[][] floor(double[][] data, int col) {
                for(int i=0; i<data.length; i++)
                        data[i][col]=Math.floor(data[i][col]);
                return data;
        }
-       
+
        public static double sum(double[][] data, int rows, int cols) {
                double sum = 0;
                for (int i = 0; i< rows; i++){
@@ -2591,14 +2614,14 @@ public class TestUtils
                }
                return sum;
        }
-       
+
        public static long computeNNZ(double[][] data) {
                long nnz = 0;
                for(int i=0; i<data.length; i++)
                        nnz += UtilFunctions.computeNnz(data[i], 0, 
data[i].length);
                return nnz;
        }
-       
+
        public static double[][] seq(int from, int to, int incr) {
                int len = (int)UtilFunctions.getSeqLength(from, to, incr);
                double[][] ret = new double[len][1];
@@ -2606,7 +2629,7 @@ public class TestUtils
                        ret[i][0] = val;
                return ret;
        }
-       
+
        public static void shutdownThreads(Thread... ts) {
                for( Thread t : ts )
                        shutdownThread(t);
@@ -2616,7 +2639,7 @@ public class TestUtils
                for( Process t : ts )
                        shutdownThread(t);
        }
-       
+
        public static void shutdownThread(Thread t) {
                // kill the worker
                if( t != null ) {
@@ -2642,11 +2665,11 @@ public class TestUtils
                        }
                }
        }
-       
+
        public static String federatedAddress(int port, String input) {
                return federatedAddress("localhost", port, input);
        }
-       
+
        public static String federatedAddress(String host, int port, String 
input) {
                return host + ':' + port + '/' + input;
        }
@@ -2988,7 +3011,7 @@ public class TestUtils
                        return output;
                }
        }
-       
+
        public static double[][] generateUnbalancedGLMInputDataX(int rows, int 
cols, double logFeatureVarianceDisbalance) {
                double[][] X = generateTestMatrix(rows, cols, -1.0, 1.0, 1.0, 
34567);
                double shift_X = 1.0;
@@ -3000,14 +3023,14 @@ public class TestUtils
                }
                return X;
        }
-       
+
        public static double[] generateUnbalancedGLMInputDataB(double[][] X, 
int cols, double intercept, double avgLinearForm, double stdevLinearForm, 
Random r) {
                double[] beta_unscaled = new double[cols];
                for (int j = 0; j < cols; j++)
                        beta_unscaled[j] = r.nextGaussian();
                return scaleWeights(beta_unscaled, X, intercept, avgLinearForm, 
stdevLinearForm);
        }
-       
+
        public static double[][] generateUnbalancedGLMInputDataY(double[][] X, 
double[] beta, int rows, int cols, GLMDist glmdist, double intercept, double 
dispersion, Random r) {
                double[][] y = null;
                if (glmdist.is_binom_n_needed())
@@ -3030,7 +3053,7 @@ public class TestUtils
                                y[i][0] = glmdist.nextGLM(r, eta);
                        }
                }
-               
+
                return y;
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNBFTest.java 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNBFTest.java
new file mode 100644
index 0000000..e62ea9f
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNBFTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import java.util.HashMap;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+
+@RunWith(value = Parameterized.class)
+public class BuiltinKNNBFTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "knnbf";
+       private final static String TEST_DIR = "functions/builtin/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
BuiltinKNNBFTest.class.getSimpleName() + "/";
+
+       private final static String OUTPUT_NAME = "B";
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public int query_rows;
+       @Parameterized.Parameter(3)
+       public int query_cols;
+       @Parameterized.Parameter(4)
+       public boolean continuous;
+       @Parameterized.Parameter(5)
+       public int k_value;
+       @Parameterized.Parameter(6)
+       public double sparsity;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {OUTPUT_NAME}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data()
+       {
+               return Arrays.asList(new Object[][] {
+                       // {rows, cols, query_rows, query_cols, continuous, 
k_value, sparsity}
+                       {150, 80, 15, 80, true, 21, 0.9}
+               });
+       }
+
+       @Test
+       public void testKNN() {
+               runKNNTest(ExecMode.SINGLE_NODE);
+       }
+
+       private void runKNNTest(ExecMode exec_mode)
+       {
+               ExecMode platform_old = setExecMode(exec_mode);
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               double[][] X = getRandomMatrix(rows, cols, 0, 1, sparsity, 255);
+               double[][] T = getRandomMatrix(query_rows, query_cols, 0, 1, 1, 
65);
+
+               double[][] CL = new double[rows][1];
+               for(int counter = 0; counter < rows; counter++)
+                       CL[counter][0] = counter + 1;
+
+               writeInputMatrixWithMTD("X", X, true);
+               writeInputMatrixWithMTD("T", T, true);
+               writeInputMatrixWithMTD("CL", CL, true);
+
+               // execute reference test
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X=" + input("X"), "in_T=" + input("T"), "in_CL=" + 
input("CL"), "in_continuous=" + (continuous ? "1" : "0"), "in_k=" + 
Integer.toString(k_value),
+                       "out_B=" + expected(OUTPUT_NAME)};
+               runTest(true, false, null, -1);
+
+               // execute actual test
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X=" + input("X"), "in_T=" + input("T"), 
"in_continuous=" + (continuous ? "1" : "0"), "in_k=" + 
Integer.toString(k_value),
+                       "out_B=" + output(OUTPUT_NAME)};
+               runTest(true, false, null, -1);
+
+               HashMap<CellIndex, Double> refResults   = 
readDMLMatrixFromExpectedDir("B");
+               HashMap<CellIndex, Double> results = 
readDMLMatrixFromOutputDir("B");
+
+               TestUtils.compareMatrices(results, refResults, 0, "Res", "Ref");
+
+               // restore execution mode
+               setExecMode(platform_old);
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNTest.java 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNTest.java
new file mode 100644
index 0000000..e2f10a3
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import java.util.HashMap;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+
+@RunWith(value = Parameterized.class)
+public class BuiltinKNNTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "knn";
+       private final static String TEST_DIR = "functions/builtin/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
BuiltinKNNTest.class.getSimpleName() + "/";
+
+       private final static String OUTPUT_NAME_NNR = "NNR";
+       private final static String OUTPUT_NAME_PR = "PR";
+
+       private final static double TEST_TOLERANCE = 0.15;
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public int query_rows;
+       @Parameterized.Parameter(3)
+       public int query_cols;
+       @Parameterized.Parameter(4)
+       public boolean continuous;
+       @Parameterized.Parameter(5)
+       public int k_value;
+       @Parameterized.Parameter(6)
+       public double sparsity;
+
+       @Override
+       public void setUp()
+       {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {OUTPUT_NAME_NNR, 
OUTPUT_NAME_PR}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data()
+       {
+               return Arrays.asList(new Object[][] {
+                       // {rows, cols, query_rows, query_cols, continuous, 
k_value, sparsity}
+                       {100, 20, 3, 20, true, 3, 1}
+               });
+       }
+
+       @Test
+       @Ignore //TODO add libraries to docker image
+       public void testKNN() {
+               runKNNTest(ExecMode.SINGLE_NODE);
+       }
+
+       private void runKNNTest(ExecMode exec_mode)
+       {
+               ExecMode platform_old = setExecMode(exec_mode);
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // create Train and Test data
+               double[][] X = getRandomMatrix(rows, cols, 0, 1, sparsity, 75);
+               double[][] T = getRandomMatrix(query_rows, query_cols, 0, 1, 1, 
65);
+
+               double[][] CL = new double[rows][1];
+               for(int counter = 0; counter < rows; counter++)
+                       CL[counter][0] = counter + 1;
+
+               writeInputMatrixWithMTD("X", X, true);
+               writeInputMatrixWithMTD("T", T, true);
+               writeInputMatrixWithMTD("CL", CL, true);
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X=" + input("X"), "in_T=" + input("T"), "in_CL=" + 
input("CL"), "in_continuous=" + (continuous ? "1" : "0"), "in_k=" + 
Integer.toString(k_value),
+                       "out_NNR=" + output(OUTPUT_NAME_NNR), "out_PR=" + 
output(OUTPUT_NAME_PR)};
+
+               fullRScriptName = HOME + TEST_NAME + ".R";
+               rCmd = getRCmd(inputDir(), (continuous ? "1" : "0"), 
Integer.toString(k_value),
+                       expectedDir());
+
+               // execute tests
+               runTest(true, false, null, -1);
+               runRScript(true);
+
+               // compare test results of RScript with dml script via files
+               HashMap<CellIndex, Double> refNNR = 
readRMatrixFromExpectedDir("NNR");
+               HashMap<CellIndex, Double> resNNR = 
readDMLMatrixFromOutputDir("NNR");
+
+               TestUtils.compareMatrices(resNNR, refNNR, 0, "ResNNR", 
"RefNNR");
+
+               double[][] refPR = 
TestUtils.convertHashMapToDoubleArray(readRMatrixFromExpectedDir("PR"));
+               double[][] resPR = 
TestUtils.convertHashMapToDoubleArray(readDMLMatrixFromOutputDir("PR"));
+
+               TestUtils.compareMatricesAvgRowDistance(refPR, resPR, 
query_rows, query_cols, TEST_TOLERANCE);
+
+               // restore execution mode
+               setExecMode(platform_old);
+       }
+}
diff --git a/src/test/scripts/functions/builtin/knn.R 
b/src/test/scripts/functions/builtin/knn.R
new file mode 100644
index 0000000..45ba7c3
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knn.R
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# TODO arguments and order
+args <- commandArgs(TRUE)
+library("Matrix")
+
+# read test data
+data_train            <- as.matrix(readMM(paste(args[1], "/X.mtx", sep="")))
+data_test             <- as.matrix(readMM(paste(args[1], "/T.mtx", sep="")))
+CL                    <- as.matrix(readMM(paste(args[1], "/CL.mtx", sep="")))
+
+is_continuous <- as.integer(args[2])
+K <- as.integer(args[3])
+
+library(FNN);
+set.seed(10);
+tmp_data = rbind(data_train, data_test);
+knn_neighbors <- get.knn(tmp_data, k=K);
+knn_neighbors <- (tail(knn_neighbors$nn.index, NROW(data_test)));
+writeMM(as(knn_neighbors, "CsparseMatrix"), paste(args[4], "NNR", sep=""));
+
+
+# ------ training -------
+library(class)
+
+set.seed(10);
+test_pred <- knn(train=data_train, test=data_test, cl=CL, k=K);
+print("test_pred:")
+print(test_pred)
+PR_val <- matrix( , nrow=0, ncol=NCOL(data_test));
+for(i in 1:NROW(data_test)) {
+  PR_val <- rbind(PR_val, data_train[test_pred[i] , ])
+}
+writeMM(as(PR_val, "CsparseMatrix"), paste(args[4], "PR", sep=""));
diff --git a/src/test/scripts/functions/builtin/knn.dml 
b/src/test/scripts/functions/builtin/knn.dml
new file mode 100644
index 0000000..8ea5a7e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knn.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($in_X)
+T = read($in_T)
+CL = read($in_CL)
+k = $in_k
+
+[NNR, PR, FI] = knn(Train=X,  Test=T, CL=CL, k_value=k, predict_con_tg=1);
+
+PR_val = matrix(0, 0, ncol(T));
+for(i in 1:nrow(T)) {
+  PR_val = rbind(PR_val, X[as.scalar(PR[i]), ]);
+}
+
+write(NNR, $out_NNR);
+write(PR_val, $out_PR);
diff --git a/src/test/scripts/functions/builtin/knnbf.dml 
b/src/test/scripts/functions/builtin/knnbf.dml
new file mode 100644
index 0000000..e5ae2de
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knnbf.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($in_X)
+T = read($in_T)
+k = $in_k
+
+NNR = knnbf(X=X,  T=T, k_value = k)
+
+write(NNR, $out_B)
diff --git a/src/test/scripts/functions/builtin/knnbfReference.dml 
b/src/test/scripts/functions/builtin/knnbfReference.dml
new file mode 100644
index 0000000..994f466
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knnbfReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($in_X)
+T = read($in_T)
+CL = read($in_CL)
+k = $in_k
+
+[NNR, PR, FI] = knn(Train=X,  Test=T, CL=CL, k_value=k);
+
+write(NNR, $out_B)
diff --git a/src/test/scripts/installDependencies.R 
b/src/test/scripts/installDependencies.R
index b8b2e66..7ae4159 100644
--- a/src/test/scripts/installDependencies.R
+++ b/src/test/scripts/installDependencies.R
@@ -58,6 +58,8 @@ custom_install("mice");
 custom_install("mclust");
 custom_install("dbscan");
 custom_install("imputeTS");
+custom_install("FNN");
+custom_install("class");
 
 print("Installation Done")
 

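For readers who want to try the new builtins, a minimal DML sketch follows. The toy data, seeds, and k are made up for illustration; the call signatures (knn with Train/Test/CL/k_value, knnbf with X/T/k_value) are taken from the test scripts above rather than from a full API reference.

  X  = rand(rows=100, cols=20, min=0, max=1, seed=7);  # training data
  T  = rand(rows=3, cols=20, min=0, max=1, seed=8);    # query rows
  CL = seq(1, 100);                                    # one label per training row

  # k-nearest neighbors: neighbor indices NNR, predictions PR,
  # and a third output FI (see scripts/builtin/knn.dml)
  [NNR, PR, FI] = knn(Train=X, Test=T, CL=CL, k_value=3);

  # brute-force variant: neighbor indices only
  NNR2 = knnbf(X=X, T=T, k_value=3);

  print(toString(NNR));

NNR holds, per query row, the indices of its k nearest training rows; the tests above compare it against R's FNN::get.knn on the same data.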