[FLINK-1745] [ml] Adjust k-nearest-neighbor-join code formatting

This closes #1220.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/035f6296
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/035f6296
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/035f6296

Branch: refs/heads/master
Commit: 035f62969523b3998b16aba9474ec4678f83b41f
Parents: 4a5af42
Author: Chiwan Park <[email protected]>
Authored: Mon May 30 17:41:32 2016 +0900
Committer: Chiwan Park <[email protected]>
Committed: Mon May 30 20:11:33 2016 +0900

----------------------------------------------------------------------
 docs/apis/batch/libs/ml/index.md                |   4 +
 docs/apis/batch/libs/ml/knn.md                  | 149 ++++++++
 docs/libs/ml/knn.md                             | 145 --------
 .../main/scala/org/apache/flink/ml/nn/KNN.scala | 134 +++----
 .../scala/org/apache/flink/ml/nn/QuadTree.scala | 311 ++++++++---------
 .../org/apache/flink/ml/nn/KNNITSuite.scala     |   3 +-
 .../org/apache/flink/ml/nn/QuadTreeSuite.scala  |  38 +-
 .../main/scala/org/apache/flink/ml/nn/KNN.scala | 345 -------------------
 .../scala/org/apache/flink/ml/nn/QuadTree.scala | 344 ------------------
 .../org/apache/flink/ml/nn/KNNITSuite.scala     |  69 ----
 .../org/apache/flink/ml/nn/QuadTreeSuite.scala  | 106 ------
 11 files changed, 378 insertions(+), 1270 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/docs/apis/batch/libs/ml/index.md
----------------------------------------------------------------------
diff --git a/docs/apis/batch/libs/ml/index.md b/docs/apis/batch/libs/ml/index.md
index c3b6316..b956287 100644
--- a/docs/apis/batch/libs/ml/index.md
+++ b/docs/apis/batch/libs/ml/index.md
@@ -49,6 +49,10 @@ FlinkML currently supports the following algorithms:
 * [Multiple linear regression](multiple_linear_regression.html)
 * [Optimization Framework](optimization.html)
 
+### Unsupervised Learning
+
+* [k-Nearest neighbors join](knn.html)
+
 ### Data Preprocessing
 
 * [Polynomial Features](polynomial_features.html)

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/docs/apis/batch/libs/ml/knn.md
----------------------------------------------------------------------
diff --git a/docs/apis/batch/libs/ml/knn.md b/docs/apis/batch/libs/ml/knn.md
new file mode 100644
index 0000000..294d333
--- /dev/null
+++ b/docs/apis/batch/libs/ml/knn.md
@@ -0,0 +1,149 @@
+---
+mathjax: include
+htmlTitle: FlinkML - k-Nearest neighbors join
+title: <a href="../ml">FlinkML</a> - k-Nearest neighbors join
+
+# Sub navigation
+sub-nav-group: batch
+sub-nav-parent: flinkml
+sub-nav-title: k-Nearest neighbors join
+---
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+* This will be replaced by the TOC
+{:toc}
+
+## Description
+Implements an exact k-nearest neighbors join algorithm.  Given a training set $A$ and a testing set $B$, the algorithm returns
+
+$$
+KNNJ(A, B, k) = \{ \left( b, KNN(b, A, k) \right) \text{ where } b \in B \text{ and } KNN(b, A, k) \text{ are the k-nearest points to }b\text{ in }A \}
+$$
+
+The brute-force approach is to compute the distance between every training and testing point. To ease the brute-force computation of computing the distance between every training point a quadtree is used. The quadtree scales well in the number of training points, though poorly in the spatial dimension. The algorithm will automatically choose whether or not to use the quadtree, though the user can override that decision by setting a parameter to force use or not use a quadtree. 
+
+## Operations
+
+`KNN` is a `Predictor`. 
+As such, it supports the `fit` and `predict` operation.
+
+### Fit
+
+KNN is trained by a given set of `Vector`:
+
+* `fit[T <: Vector]: DataSet[T] => Unit`
+
+### Predict
+
+KNN predicts for all subtypes of FlinkML's `Vector` the corresponding class label:
+
+* `predict[T <: Vector]: DataSet[T] => DataSet[(T, Array[Vector])]`, where the `(T, Array[Vector])` tuple
+  corresponds to (test point, K-nearest training points)
+
+## Parameters
+
+The KNN implementation can be controlled by the following parameters:
+
+   <table class="table table-bordered">
+    <thead>
+      <tr>
+        <th class="text-left" style="width: 20%">Parameters</th>
+        <th class="text-center">Description</th>
+      </tr>
+    </thead>
+
+    <tbody>
+      <tr>
+        <td><strong>K</strong></td>
+        <td>
+          <p>
+            Defines the number of nearest-neighbors to search for. That is, for each test point, the algorithm finds the K-nearest neighbors in the training set
+            (Default value: <strong>5</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>DistanceMetric</strong></td>
+        <td>
+          <p>
+            Sets the distance metric we use to calculate the distance between two points. If no metric is specified, then [[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] is used.
+            (Default value: <strong>EuclideanDistanceMetric</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Blocks</strong></td>
+        <td>
+          <p>
+            Sets the number of blocks into which the input data will be split. This number should be set
+            at least to the degree of parallelism. If no value is specified, then the parallelism of the
+            input [[DataSet]] is used as the number of blocks.
+            (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>UseQuadTree</strong></td>
+        <td>
+          <p>
+            A boolean variable that whether or not to use a quadtree to partition the training set to potentially simplify the KNN search. If no value is specified, the code will automatically decide whether or not to use a quadtree. Use of a quadtree scales well with the number of training and testing points, though poorly with the dimension.
+            (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>SizeHint</strong></td>
+        <td>
+          <p>Specifies whether the training set or test set is small to optimize the cross product operation needed for the KNN search. If the training set is small this should be `CrossHint.FIRST_IS_SMALL` and set to `CrossHint.SECOND_IS_SMALL` if the test set is small.
+             (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+    </tbody>
+  </table>
+
+## Examples
+
+{% highlight scala %}
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.nn.KNN
+import org.apache.flink.ml.math.Vector
+import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
+
+val env = ExecutionEnvironment.getExecutionEnvironment
+
+// prepare data
+val trainingSet: DataSet[Vector] = ...
+val testingSet: DataSet[Vector] = ...
+
+val knn = KNN()
+  .setK(3)
+  .setBlocks(10)
+  .setDistanceMetric(SquaredEuclideanDistanceMetric())
+  .setUseQuadTree(false)
+  .setSizeHint(CrossHint.SECOND_IS_SMALL)
+
+// run knn join
+knn.fit(trainingSet)
+val result = knn.predict(testingSet).collect()
+{% endhighlight %}
+
+For more details on the computing KNN with and without and quadtree, here is a presentation: [http://danielblazevski.github.io/](http://danielblazevski.github.io/)
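
Reviewer note on the `KNNJ(A, B, k)` definition in the new `knn.md`: the brute-force join it describes can be sketched in plain Scala collections. This is a minimal illustration only, not the FlinkML implementation; `KnnJoinSketch`, `Point`, `squaredDistance` and `knnJoin` are hypothetical names, and the squared Euclidean distance is assumed as the metric.

```scala
// Plain-Scala sketch of KNNJ(A, B, k); hypothetical names, not FlinkML code.
object KnnJoinSketch {
  type Point = Vector[Double]

  // Squared Euclidean distance; ranking by it gives the same order as
  // ranking by the Euclidean distance.
  def squaredDistance(a: Point, b: Point): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // For every test point b, pair b with the k training points closest to it.
  // Every training/testing pair is compared, i.e. the brute-force approach.
  def knnJoin(training: Seq[Point], testing: Seq[Point], k: Int): Seq[(Point, Seq[Point])] =
    testing.map { b =>
      (b, training.sortBy(a => squaredDistance(a, b)).take(k))
    }
}
```

Sorting the whole training set per test point keeps the sketch short; the patch itself keeps a bounded priority queue per test point instead.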

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/docs/libs/ml/knn.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/knn.md b/docs/libs/ml/knn.md
deleted file mode 100644
index c9a7e03..0000000
--- a/docs/libs/ml/knn.md
+++ /dev/null
@@ -1,145 +0,0 @@
----
-mathjax: include
-htmlTitle: FlinkML - k-nearest neighbors
-title: <a href="../ml">FlinkML</a> - knn
----
-<!--
-Licensed to the Apache Software Foundation (ASF) under one
-or more contributor license agreements.  See the NOTICE file
-distributed with this work for additional information
-regarding copyright ownership.  The ASF licenses this file
-to you under the Apache License, Version 2.0 (the
-"License"); you may not use this file except in compliance
-with the License.  You may obtain a copy of the License at
-
-  http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing,
-software distributed under the License is distributed on an
-"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-KIND, either express or implied.  See the License for the
-specific language governing permissions and limitations
-under the License.
--->
-
-* This will be replaced by the TOC
-{:toc}
-
-## Description
-Implements an exact k-nearest neighbors algorithm.  Given a training set $A$ and a testing set $B$, the algorithm returns
-
-$$
-KNN(A,B, k) = \{ \left( b, KNN(b,A, k) \right) where b \in B and KNN(b, A, k) are the k-nearest points to b in A \}
-$$
-
-The brute-force approach is to compute the distance between every training and testing point.  To ease the brute-force computation of computing the distance between every traning point a quadtree is used.  The quadtree scales well in the number of training points, though poorly in the spatial dimension.  The algorithm will automatically choose whether or not to use the quadtree, though the user can override that decision by setting a parameter to force use or not use a quadtree. 
-
-##Operations
-
-`KNN` is a `Predictor`. 
-As such, it supports the `fit` and `predict` operation.
-
-### Fit
-
-KNN is trained given a set of `LabeledVector`:
-
-* `fit: DataSet[LabeledVector] => Unit`
-
-### Predict
-
-KNN predicts for all subtypes of FlinkML's `Vector` the corresponding class label:
-
-* `predict[T <: Vector]: DataSet[T] => DataSet[(T, Array[Vector])]`, where the `(T, Array[Vector])` tuple
-  corresponds to (testPoint, K-nearest training points)
-
-## Paremeters
-The KNN implementation can be controlled by the following parameters:
-
-   <table class="table table-bordered">
-    <thead>
-      <tr>
-        <th class="text-left" style="width: 20%">Parameters</th>
-        <th class="text-center">Description</th>
-      </tr>
-    </thead>
-
-    <tbody>
-      <tr>
-        <td><strong>K</strong></td>
-        <td>
-          <p>
-            Defines the number of nearest-neighbors to search for.  That is, for each test point, the algorithm finds the K-nearest neighbors in the training set
-            (Default value: <strong>5</strong>)
-          </p>
-        </td>
-      </tr>
-      <tr>
-        <td><strong>DistanceMetric</strong></td>
-        <td>
-          <p>
-            Sets the distance metric we use to calculate the distance between two points. If no metric is specified, then [[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] is used.
-            (Default value: <strong>EuclideanDistanceMetric</strong>)
-          </p>
-        </td>
-      </tr>
-      <tr>
-        <td><strong>Blocks</strong></td>
-        <td>
-          <p>
-            Sets the number of blocks into which the input data will be split. This number should be set
-            at least to the degree of parallelism. If no value is specified, then the parallelism of the
-            input [[DataSet]] is used as the number of blocks.
-            (Default value: <strong>None</strong>)
-          </p>
-        </td>
-      </tr>
-      <tr>
-        <td><strong>UseQuadTreeParam</strong></td>
-        <td>
-          <p>
-             A boolean variable that whether or not to use a Quadtree to partition the training set to potentially simplify the KNN search.  If no value is specified, the code will automatically decide whether or not to use a Quadtree.  Use of a Quadtree scales well with the number of training and testing points, though poorly with the dimension.
-            (Default value: <strong>None</strong>)
-          </p>
-        </td>
-      </tr>
-      <tr>
-        <td><strong>SizeHint</strong></td>
-        <td>
-          <p>Specifies whether the training set or test set is small to optimize the cross product operation needed for the KNN search.  If the training set is small this should be `CrossHint.FIRST_IS_SMALL` and set to `CrossHint.SECOND_IS_SMALL` if the test set is small.
-             (Default value: <strong>None</strong>)
-          </p>
-        </td>
-      </tr>
-    </tbody>
-  </table>
-
-## Examples
-
-{% highlight scala %}
-import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
-import org.apache.flink.api.scala._
-import org.apache.flink.ml.classification.Classification
-import org.apache.flink.ml.math.DenseVector
-import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
-
-  val env = ExecutionEnvironment.getExecutionEnvironment
-
-  // prepare data
-  val trainingSet = env.fromCollection(Classification.trainingData).map(_.vector)
-  val testingSet = env.fromElements(DenseVector(0.0, 0.0))
-
- val knn = KNN()
-    .setK(3)
-    .setBlocks(10)
-    .setDistanceMetric(SquaredEuclideanDistanceMetric())
-    .setUseQuadTree(false)
-    .setSizeHint(CrossHint.SECOND_IS_SMALL)
-
-  // run knn join
-  knn.fit(trainingSet)
-  val result = knn.predict(testingSet).collect()
-
-{% endhighlight %}
-
-For more details on the computing KNN with and without and quadtree, here is a presentation:
-http://danielblazevski.github.io/

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
index 82f4b88..d15fdaf 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
@@ -19,39 +19,38 @@
 package org.apache.flink.ml.nn
 
 import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.scala.utils._
 import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.utils._
 import org.apache.flink.ml.common._
-import org.apache.flink.ml.math.{Vector => FlinkVector, DenseVector}
-import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric, DistanceMetric,
-EuclideanDistanceMetric}
+import org.apache.flink.ml.math.{DenseVector, Vector => FlinkVector}
+import org.apache.flink.ml.metrics.distances._
 import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, Predictor}
 import org.apache.flink.util.Collector
-import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
 
 import scala.collection.immutable.Vector
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
-/** Implements a k-nearest neighbor join.
+/** Implements a `k`-nearest neighbor join.
   *
   * Calculates the `k`-nearest neighbor points in the training set for each point in the test set.
   *
   * @example
   * {{{
-  *         val trainingDS: DataSet[Vector] = ...
-  *         val testingDS: DataSet[Vector] = ...
+  * val trainingDS: DataSet[Vector] = ...
+  * val testingDS: DataSet[Vector] = ...
   *
-  *         val knn = KNN()
-  *           .setK(10)
-  *           .setBlocks(5)
-  *           .setDistanceMetric(EuclideanDistanceMetric())
+  * val knn = KNN()
+  *   .setK(10)
+  *   .setBlocks(5)
+  *   .setDistanceMetric(EuclideanDistanceMetric())
   *
-  *         knn.fit(trainingDS)
+  *   knn.fit(trainingDS)
   *
-  *         val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS)
+  * val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS)
   * }}}
   *
   * =Parameters=
@@ -69,19 +68,19 @@ import scala.reflect.ClassTag
   * at least to the degree of parallelism. If no value is specified, then the parallelism of the
   * input [[DataSet]] is used as the number of blocks. (Default value: '''None''')
   *
-  * - [[org.apache.flink.ml.nn.KNN.UseQuadTreeParam]]
-  * A boolean variable that whether or not to use a Quadtree to partition the training set
+  * - [[org.apache.flink.ml.nn.KNN.UseQuadTree]]
+  * A boolean variable that whether or not to use a quadtree to partition the training set
   * to potentially simplify the KNN search.  If no value is specified, the code will
-  * automatically decide whether or not to use a Quadtree.  Use of a Quadtree scales well
+  * automatically decide whether or not to use a quadtree.  Use of a quadtree scales well
   * with the number of training and testing points, though poorly with the dimension.
-  * (Default value:  ```None```)
+  * (Default value: '''None''')
   *
   * - [[org.apache.flink.ml.nn.KNN.SizeHint]]
   * Specifies whether the training set or test set is small to optimize the cross
   * product operation needed for the KNN search.  If the training set is small
   * this should be `CrossHint.FIRST_IS_SMALL` and set to `CrossHint.SECOND_IS_SMALL`
   * if the test set is small.
-  * (Default value:  ```None```)
+  * (Default value: '''None''')
   *
   */
 
@@ -92,6 +91,7 @@ class KNN extends Predictor[KNN] {
   var trainingSet: Option[DataSet[Block[FlinkVector]]] = None
 
   /** Sets K
+    *
     * @param k the number of selected points as neighbors
     */
   def setK(k: Int): KNN = {
@@ -101,6 +101,7 @@ class KNN extends Predictor[KNN] {
   }
 
   /** Sets the distance metric
+    *
     * @param metric the distance metric to calculate distance between two points
     */
   def setDistanceMetric(metric: DistanceMetric): KNN = {
@@ -109,6 +110,7 @@ class KNN extends Predictor[KNN] {
   }
 
   /** Sets the number of data blocks/partitions
+    *
     * @param n the number of data blocks
     */
   def setBlocks(n: Int): KNN = {
@@ -117,22 +119,19 @@ class KNN extends Predictor[KNN] {
     this
   }
 
-  /**
-    * Sets the Boolean variable that decides whether to use the QuadTree or not
-    */
+  /** Sets the Boolean variable that decides whether to use the QuadTree or not */
   def setUseQuadTree(useQuadTree: Boolean): KNN = {
     if (useQuadTree) {
       require(parameters(DistanceMetric).isInstanceOf[SquaredEuclideanDistanceMetric] ||
         parameters(DistanceMetric).isInstanceOf[EuclideanDistanceMetric])
     }
-    parameters.add(UseQuadTreeParam, useQuadTree)
+    parameters.add(UseQuadTree, useQuadTree)
     this
   }
 
-  /**
-    * Parameter a user can specify if one of the training or test sets are small
+  /** Parameter a user can specify if one of the training or test sets are small
+    *
     * @param sizeHint
-    * @return
     */
   def setSizeHint(sizeHint: CrossHint): KNN = {
     parameters.add(SizeHint, sizeHint)
@@ -155,7 +154,7 @@ object KNN {
     val defaultValue: Option[Int] = None
   }
 
-  case object UseQuadTreeParam extends Parameter[Boolean] {
+  case object UseQuadTree extends Parameter[Boolean] {
     val defaultValue: Option[Boolean] = None
   }
 
@@ -168,13 +167,15 @@ object KNN {
   }
 
   /** [[FitOperation]] which trains a KNN based on the given training data set.
+    *
     * @tparam T Subtype of [[org.apache.flink.ml.math.Vector]]
     */
   implicit def fitKNN[T <: FlinkVector : TypeInformation] = new FitOperation[KNN, T] {
     override def fit(
-                      instance: KNN,
-                      fitParameters: ParameterMap,
-                      input: DataSet[T]): Unit = {
+      instance: KNN,
+      fitParameters: ParameterMap,
+      input: DataSet[T]
+    ): Unit = {
       val resultParameters = instance.parameters ++ fitParameters
 
       require(resultParameters.get(K).isDefined, "K is needed for calculation")
@@ -189,16 +190,17 @@ object KNN {
 
   /** [[PredictDataSetOperation]] which calculates k-nearest neighbors of the given testing data
     * set.
+    *
     * @tparam T Subtype of [[Vector]]
     * @return The given testing data set with k-nearest neighbors
     */
   implicit def predictValues[T <: FlinkVector : ClassTag : TypeInformation] = {
     new PredictDataSetOperation[KNN, T, (FlinkVector, Array[FlinkVector])] {
       override def predictDataSet(
-                                   instance: KNN,
-                                   predictParameters: ParameterMap,
-                                   input: DataSet[T]): DataSet[(FlinkVector,
-        Array[FlinkVector])] = {
+        instance: KNN,
+        predictParameters: ParameterMap,
+        input: DataSet[T]
+      ): DataSet[(FlinkVector, Array[FlinkVector])] = {
         val resultParameters = instance.parameters ++ predictParameters
 
         instance.trainingSet match {
@@ -227,13 +229,17 @@ object KNN {
             val crossed = crossTuned.mapPartition {
               (iter, out: Collector[(FlinkVector, FlinkVector, Long, Double)]) => {
                 for ((training, testing) <- iter) {
-                  // use a quadtree if (4^dim)Ntest*log(Ntrain)
-                  // < Ntest*Ntrain, and distance is Euclidean
-                  val useQuadTree = resultParameters.get(UseQuadTreeParam).getOrElse(
-                    math.log(4.0) * training.values.head.size + math.log(math.log(training.values.length))
-                      < math.log(training.values.length) &&
-                      (metric.isInstanceOf[EuclideanDistanceMetric] ||
-                        metric.isInstanceOf[SquaredEuclideanDistanceMetric]))
+                  // use a quadtree if (4 ^ dim) * Ntest * log(Ntrain)
+                  // < Ntest * Ntrain, and distance is Euclidean
+                  val checkSize = math.log(4.0) * training.values.head.size +
+                    math.log(math.log(training.values.length)) < math.log(training.values.length)
+                  val checkMetric = metric match {
+                    case _: EuclideanDistanceMetric => true
+                    case _: SquaredEuclideanDistanceMetric => true
+                    case _ => false
+                  }
+                  val useQuadTree = resultParameters.get(UseQuadTree)
+                    .getOrElse(checkSize && checkMetric)
 
                   if (useQuadTree) {
                     knnQueryWithQuadTree(training.values, testing.values, k, metric, out)
@@ -265,19 +271,19 @@ object KNN {
             result
           case None => throw new RuntimeException("The KNN model has not been trained." +
             "Call first fit before calling the predict operation.")
-
         }
       }
     }
   }
 
-  def knnQueryWithQuadTree[T <: FlinkVector](
-                                              training: Vector[T],
-                                              testing: Vector[(Long, T)],
-                                              k: Int, metric: DistanceMetric,
-                                              out: Collector[(FlinkVector,
-                                                FlinkVector, Long, Double)]) {
-    /// find a bounding box
+  private def knnQueryWithQuadTree[T <: FlinkVector](
+    training: Vector[T],
+    testing: Vector[(Long, T)],
+    k: Int,
+    metric: DistanceMetric,
+    out: Collector[(FlinkVector, FlinkVector, Long, Double)]
+  ): Unit = {
+    // find a bounding box
     val MinArr = Array.tabulate(training.head.size)(x => x)
     val MaxArr = Array.tabulate(training.head.size)(x => x)
 
@@ -289,7 +295,7 @@ object KNN {
     val MinVec = DenseVector(MinArr.map(i => math.min(minVecTrain(i), minVecTest(i))))
     val MaxVec = DenseVector(MinArr.map(i => math.max(maxVecTrain(i), maxVecTest(i))))
 
-    //default value of max elements/box is set to max(20,k)
+    // default value of max elements/box is set to max(20,k)
     val maxPerBox = math.max(k, 20)
     val trainingQuadTree = new QuadTree(MinVec, MaxVec, metric, maxPerBox)
 
@@ -301,15 +307,12 @@ object KNN {
     }
 
     for ((id, vector) <- testing) {
-      //  Find siblings' objects and do local kNN there
-      val siblingObjects =
-        trainingQuadTree.searchNeighborsSiblingQueue(vector)
+      // Find siblings' objects and do local kNN there
+      val siblingObjects = trainingQuadTree.searchNeighborsSiblingQueue(vector)
 
-      // do KNN query on siblingObjects and get max distance of kNN
-      // then rad is good choice for a neighborhood to do a refined
-      // local kNN search
-      val knnSiblings = siblingObjects.map(v => metric.distance(vector, v)
-      ).sortWith(_ < _).take(k)
+      // do KNN query on siblingObjects and get max distance of kNN then rad is good choice
+      // for a neighborhood to do a refined local kNN search
+      val knnSiblings = siblingObjects.map(v => metric.distance(vector, v)).sortWith(_ < _).take(k)
 
       val rad = knnSiblings.last
       val trainingFiltered = trainingQuadTree.searchNeighbors(vector, rad)
@@ -321,18 +324,20 @@ object KNN {
           queue.dequeue()
         }
       }
+
       for (v <- queue) {
         out.collect(v)
       }
     }
   }
 
-  def knnQueryBasic[T <: FlinkVector](
-                                       training: Vector[T],
-                                       testing: Vector[(Long, T)],
-                                       k: Int, metric: DistanceMetric,
-                                       out: Collector[(FlinkVector, FlinkVector, Long, Double)]) {
-
+  private def knnQueryBasic[T <: FlinkVector](
+    training: Vector[T],
+    testing: Vector[(Long, T)],
+    k: Int,
+    metric: DistanceMetric,
+    out: Collector[(FlinkVector, FlinkVector, Long, Double)]
+  ): Unit = {
     val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, Double)]()(
       Ordering.by(_._4))
     
@@ -344,6 +349,7 @@ object KNN {
           queue.dequeue()
         }
       }
+
       for (v <- queue) {
         out.collect(v)
       }
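
Reviewer note on the `checkSize` expression in the `predictDataSet` hunk: the comment states the cost comparison `(4 ^ dim) * Ntest * log(Ntrain) < Ntest * Ntrain`. Dividing both sides by `Ntest` and taking logarithms gives `dim * log(4) + log(log(Ntrain)) < log(Ntrain)`, which is the form the code evaluates. The sketch below isolates that size check for experimentation; `QuadTreeHeuristicSketch` and `sizeFavorsQuadTree` are hypothetical names, and the metric check (`checkMetric` in the patch) is deliberately left out.

```scala
// Stand-alone sketch of the size check only; hypothetical names, not FlinkML code.
object QuadTreeHeuristicSketch {
  // True when 4^dim * log(nTrain) < nTrain, evaluated in log space so that
  // 4^dim cannot overflow for large dimensions.
  def sizeFavorsQuadTree(dim: Int, nTrain: Int): Boolean =
    math.log(4.0) * dim + math.log(math.log(nTrain)) < math.log(nTrain)
}
```

Under this check, 1000 two-dimensional training points favor the quadtree, while 1000 ten-dimensional points do not, which matches the documented claim that the quadtree scales poorly with the dimension.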

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
index d08dcdd..95a1771 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
@@ -18,30 +18,27 @@
 
 package org.apache.flink.ml.nn
 
-import org.apache.flink.ml.math.{Breeze, Vector}
-import Breeze._
+import org.apache.flink.ml.math.Breeze._
+import org.apache.flink.ml.math.Vector
+import org.apache.flink.ml.metrics.distances._
 
-import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
-EuclideanDistanceMetric, DistanceMetric}
+import scala.annotation.tailrec
+import scala.collection.mutable
 
-import scala.collection.mutable.ListBuffer
-import scala.collection.mutable.PriorityQueue
-
-/**
- * n-dimensional QuadTree data structure; partitions
- * spatial data for faster queries (e.g. KNN query)
- * The skeleton of the data structure was initially
- * based off of the 2D Quadtree found here:
- * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
- *
- * Many additional methods were added to the class both for
- * efficient KNN queries and generalizing to n-dim.
- *
- * @param minVec vector of the corner of the bounding box with smallest coordinates
- * @param maxVec vector of the corner of the bounding box with smallest coordinates
- * @param distMetric metric, must be Euclidean or squareEuclidean
- * @param maxPerBox threshold for number of points in each box before slitting a box
- */
+/** n-dimensional QuadTree data structure; partitions
+  * spatial data for faster queries (e.g. KNN query)
+  * The skeleton of the data structure was initially
+  * based off of the 2D Quadtree found here:
+  * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+  *
+  * Many additional methods were added to the class both for
+  * efficient KNN queries and generalizing to n-dim.
+  *
+  * @param minVec     vector of the corner of the bounding box with smallest coordinates
+  * @param maxVec     vector of the corner of the bounding box with smallest coordinates
+  * @param distMetric metric, must be Euclidean or squareEuclidean
+  * @param maxPerBox  threshold for number of points in each box before slitting a box
+  */
 class QuadTree(
   minVec: Vector,
   maxVec: Vector,
@@ -53,60 +50,46 @@ class QuadTree(
     width: Vector,
     var children: Seq[Node]) {
 
-    val nodeElements = new ListBuffer[Vector]
+    val nodeElements = new mutable.ListBuffer[Vector]
 
     /** for testing purposes only; used in QuadTreeSuite.scala
       *
       * @return center and width of the box
       */
-    def getCenterWidth(): (Vector, Vector) = {
-      (center, width)
-    }
+    def getCenterWidth(): (Vector, Vector) = (center, width)
 
     /** Tests whether the queryPoint is in the node, or a child of that node
       *
-      * @param queryPoint
-      * @return
+      * @param queryPoint a point to test
+      * @return whether the given point is in the node, or a child of this node
       */
-    def contains(queryPoint: Vector): Boolean = {
-      overlap(queryPoint, 0.0)
-    }
+    def contains(queryPoint: Vector): Boolean = overlap(queryPoint, 0.0)
 
     /** Tests if queryPoint is within a radius of the node
       *
-      * @param queryPoint
-      * @param radius
-      * @return
+      * @param queryPoint a point to test
+      * @param radius     radius of test area
+      * @return whether the given point is in the area
       */
-    def overlap(
-      queryPoint: Vector,
-      radius: Double): Boolean = {
-      (0 until queryPoint.size).forall{ i =>
-          (queryPoint(i) - radius < center(i) + width(i) / 2) &&
-            (queryPoint(i) + radius > center(i) - width(i) / 2)
+    def overlap(queryPoint: Vector, radius: Double): Boolean = {
+      (0 until queryPoint.size).forall { i =>
+        (queryPoint(i) - radius < center(i) + width(i) / 2) &&
+          (queryPoint(i) + radius > center(i) - width(i) / 2)
       }
     }
 
     /** Tests if queryPoint is near a node
       *
-      * @param queryPoint
-      * @param radius
-      * @return
+      * @param queryPoint a point to test
+      * @param radius     radius of covered area
       */
-    def isNear(
-      queryPoint: Vector,
-      radius: Double): Boolean = {
-      minDist(queryPoint) < radius
-    }
+    def isNear(queryPoint: Vector, radius: Double): Boolean = minDist(queryPoint) < radius
 
-    /**
-     * minDist is defined so that every point in the box
-     * has distance to queryPoint greater than minDist
-     * (minDist adopted from "Nearest Neighbors Queries" by N. Roussopoulos et al.)
-     *
-     * @param queryPoint
-     * @return
-     */
+    /** minDist is defined so that every point in the box has distance to queryPoint greater
+      * than minDist (minDist adopted from "Nearest Neighbors Queries" by N. Roussopoulos et al.)
+      *
+      * @param queryPoint
+      */
     def minDist(queryPoint: Vector): Double = {
       val minDist = (0 until queryPoint.size).map { i =>
         if (queryPoint(i) < center(i) - width(i) / 2) {
@@ -126,12 +109,12 @@ class QuadTree(
       }
     }
 
-    /**
-     * Finds which child queryPoint lies in.  node.children is a Seq[Node], and
-     * whichChild finds the appropriate index of that Seq.
-     * @param queryPoint
-     * @return
-     */
+    /** Finds which child queryPoint lies in. node.children is a Seq[Node], and
+      * [[whichChild]] finds the appropriate index of that Seq.
+      *
+      * @param queryPoint
+      * @return
+      */
     def whichChild(queryPoint: Vector): Int = {
       (0 until queryPoint.size).map { i =>
         if (queryPoint(i) > center(i)) {
@@ -152,32 +135,27 @@ class QuadTree(
       children = cPart.map(p => new Node(p, mappedWidth.fromBreeze, null))
     }
 
-    /**
-     * Recursive function that partitions a n-dim box by taking the (n-1) dimensional
-     * plane through the center of the box keeping the n-th coordinate fixed,
-     * then shifting it in the n-th direction up and down
-     * and recursively applying partitionBox to the two shifted (n-1) dimensional planes.
-     *
-     * @param center the center of the box
-     * @param width a vector of lengths of each dimension of the box
-     * @return
-     */
-    def partitionBox(
-      center: Vector,
-      width: Vector): Seq[Vector] = {
-      def partitionHelper(
-        box: Seq[Vector],
-        dim: Int): Seq[Vector] = {
+    /** Recursive function that partitions a n-dim box by taking the (n-1) dimensional
+      * plane through the center of the box keeping the n-th coordinate fixed,
+      * then shifting it in the n-th direction up and down
+      * and recursively applying partitionBox to the two shifted (n-1) dimensional planes.
+      *
+      * @param center the center of the box
+      * @param width  a vector of lengths of each dimension of the box
+      * @return
+      */
+    def partitionBox(center: Vector, width: Vector): Seq[Vector] = {
+      @tailrec
+      def partitionHelper(box: Seq[Vector], dim: Int): Seq[Vector] = {
         if (dim >= width.size) {
           box
         } else {
-          val newBox = box.flatMap {
-            vector =>
-              val (up, down) = (vector.copy, vector)
-              up.update(dim, up(dim) - width(dim) / 4)
-              down.update(dim, down(dim) + width(dim) / 4)
+          val newBox = box.flatMap { vector =>
+            val (up, down) = (vector.copy, vector)
+            up.update(dim, up(dim) - width(dim) / 4)
+            down.update(dim, down(dim) + width(dim) / 4)
 
-              Seq(up, down)
+            Seq(up, down)
           }
           partitionHelper(newBox, dim + 1)
         }
@@ -190,72 +168,66 @@ class QuadTree(
   val root = new Node(((minVec.asBreeze + maxVec.asBreeze) * 0.5).fromBreeze,
     (maxVec.asBreeze - minVec.asBreeze).fromBreeze, null)
 
-  /**
-   * simple printing of tree for testing/debugging
-   */
+  /** Prints tree for testing/debugging */
   def printTree(): Unit = {
-    printTreeRecur(root)
-  }
-
-  def printTreeRecur(node: Node) {
-    if (node.children != null) {
-      for (c <- node.children) {
-        printTreeRecur(c)
+    def printTreeRecur(node: Node) {
+      if (node.children != null) {
+        for (c <- node.children) {
+          printTreeRecur(c)
+        }
+      } else {
+        println("printing tree: n.nodeElements " + node.nodeElements)
       }
-    } else {
-      println("printing tree: n.nodeElements " + node.nodeElements)
     }
-  }
 
-  /**
-   * Recursively adds an object to the tree
-   * @param queryPoint
-   */
-  def insert(queryPoint: Vector) {
-    insertRecur(queryPoint, root)
+    printTreeRecur(root)
   }
 
-  private def insertRecur(
-    queryPoint: Vector,
-    node: Node) {
-    if (node.children == null) {
-      if (node.nodeElements.length < maxPerBox) {
-        node.nodeElements += queryPoint
-      } else {
-        node.makeChildren()
-        for (o <- node.nodeElements) {
-          insertRecur(o, node.children(node.whichChild(o)))
+  /** Recursively adds an object to the tree
+    *
+    * @param queryPoint an object which is added
+    */
+  def insert(queryPoint: Vector) = {
+    def insertRecur(queryPoint: Vector, node: Node): Unit = {
+      if (node.children == null) {
+        if (node.nodeElements.length < maxPerBox) {
+          node.nodeElements += queryPoint
+        } else {
+          node.makeChildren()
+          for (o <- node.nodeElements) {
+            insertRecur(o, node.children(node.whichChild(o)))
+          }
+          node.nodeElements.clear()
+          insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
         }
-        node.nodeElements.clear()
+      } else {
         insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
       }
-    } else {
-      insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
     }
+
+    insertRecur(queryPoint, root)
   }
 
-  /**
-   * Used to zoom in on a region near a test point for a fast KNN query.
-   * This capability is used in the KNN query to find k "near" neighbors n_1,...,n_k, from
-   * which one computes the max distance D_s to queryPoint.  D_s is then used during the
-   * kNN query to find all points within a radius D_s of queryPoint using searchNeighbors.
-   * To find the "near" neighbors, a min-heap is defined on the leaf nodes of the leaf
-   * nodes of the minimal bounding box of the queryPoint. The priority of a leaf node
-   * is an appropriate notion of the distance between the test point and the node,
-   * which is defined by minDist(queryPoint),
-   *
-   * @param queryPoint a test point for which the method finds the minimal bounding
-   *                   box that queryPoint lies in and returns elements in that boxes
-   *                   siblings' leaf nodes
-   * @return
-   */
-  def searchNeighborsSiblingQueue(queryPoint: Vector): ListBuffer[Vector] = {
-    val ret = new ListBuffer[Vector]
+  /** Used to zoom in on a region near a test point for a fast KNN query.
+    * This capability is used in the KNN query to find k "near" neighbors n_1,...,n_k, from
+    * which one computes the max distance D_s to queryPoint.  D_s is then used during the
+    * kNN query to find all points within a radius D_s of queryPoint using searchNeighbors.
+    * To find the "near" neighbors, a min-heap is defined on the leaf nodes of the leaf
+    * nodes of the minimal bounding box of the queryPoint. The priority of a leaf node
+    * is an appropriate notion of the distance between the test point and the node,
+    * which is defined by minDist(queryPoint),
+    *
+    * @param queryPoint a test point for which the method finds the minimal bounding
+    *                   box that queryPoint lies in and returns elements in that boxes
+    *                   siblings' leaf nodes
+    */
+  def searchNeighborsSiblingQueue(queryPoint: Vector): mutable.ListBuffer[Vector] = {
+    val ret = new mutable.ListBuffer[Vector]
     // edge case when the main box has not been partitioned at all
     if (root.children == null) {
       root.nodeElements.clone()
     } else {
-      val nodeQueue = new PriorityQueue[(Double, Node)]()(Ordering.by(x => x._1))
+      val nodeQueue = new mutable.PriorityQueue[(Double, Node)]()(Ordering.by(x => x._1))
       searchRecurSiblingQueue(queryPoint, root, nodeQueue)
 
       var count = 0
@@ -271,16 +243,17 @@ class QuadTree(
   }
 
   /**
-   *
-   * @param queryPoint point under consideration
-   * @param node node that queryPoint lies in
-   * @param nodeQueue defined in searchSiblingQueue, this stores nodes based on their
-   *                  distance to node as defined by minDist
-   */
+    *
+    * @param queryPoint point under consideration
+    * @param node       node that queryPoint lies in
+    * @param nodeQueue  defined in searchSiblingQueue, this stores nodes based on their
+    *                   distance to node as defined by minDist
+    */
   private def searchRecurSiblingQueue(
     queryPoint: Vector,
     node: Node,
-    nodeQueue: PriorityQueue[(Double, Node)]) {
+    nodeQueue: mutable.PriorityQueue[(Double, Node)]
+  ): Unit = {
     if (node.children != null) {
       for (child <- node.children; if child.contains(queryPoint)) {
         if (child.children == null) {
@@ -294,17 +267,18 @@ class QuadTree(
     }
   }
 
-  /**
-   * Goes down to minimal bounding box of queryPoint, and add elements to nodeQueue
-   *
-   * @param queryPoint point under consideration
-   * @param node node that queryPoint lies in
-   * @param nodeQueue PriorityQueue that stores all points in minimal bounding box of queryPoint
-   */
+  /** Goes down to minimal bounding box of queryPoint, and add elements to nodeQueue
+    *
+    * @param queryPoint point under consideration
+    * @param node       node that queryPoint lies in
+    * @param nodeQueue  [[mutable.PriorityQueue]] that stores all points in minimal bounding box
+    *                   of queryPoint
+    */
   private def minNodes(
     queryPoint: Vector,
     node: Node,
-    nodeQueue: PriorityQueue[(Double, Node)]) {
+    nodeQueue: mutable.PriorityQueue[(Double, Node)]
+  ): Unit = {
     if (node.children == null) {
       nodeQueue += ((-node.minDist(queryPoint), node))
     } else {
@@ -322,29 +296,28 @@ class QuadTree(
     * all nearby boxes. The radius is determined from searchNeighborsSiblingQueue
     * by defining a min-heap on the leaf nodes
     *
-    * @param queryPoint
-    * @param radius
+    * @param queryPoint a point which is center
+    * @param radius     radius of scope
     * @return all points within queryPoint with given radius
     */
-  def searchNeighbors(
-    queryPoint: Vector,
-    radius: Double): ListBuffer[Vector] = {
-    val ret = new ListBuffer[Vector]
-    searchRecur(queryPoint, radius, root, ret)
-    ret
-  }
-
-  private def searchRecur(
-    queryPoint: Vector,
-    radius: Double,
-    node: Node,
-    ret: ListBuffer[Vector]) {
-    if (node.children == null) {
-      ret ++= node.nodeElements
-    } else {
-      for (child <- node.children; if child.isNear(queryPoint, radius)) {
-        searchRecur(queryPoint, radius, child, ret)
+  def searchNeighbors(queryPoint: Vector, radius: Double): mutable.ListBuffer[Vector] = {
+    def searchRecur(
+      queryPoint: Vector,
+      radius: Double,
+      node: Node,
+      ret: mutable.ListBuffer[Vector]
+    ): Unit = {
+      if (node.children == null) {
+        ret ++= node.nodeElements
+      } else {
+        for (child <- node.children; if child.isNear(queryPoint, radius)) {
+          searchRecur(queryPoint, radius, child, ret)
+        }
       }
     }
+
+    val ret = new mutable.ListBuffer[Vector]
+    searchRecur(queryPoint, radius, root, ret)
+    ret
   }
 }
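
Reviewer's note: the box tests in `Node` (`contains`, `overlap`, `isNear`, `minDist`) can be tried out without Flink. The following is a minimal sketch, not the committed code. `Box` is a hypothetical stand-in for `Node`, and plain `Array[Double]` replaces FlinkML's `Vector`. The final reduction of `minDist` is not visible in this hunk, so the sketch assumes a plain Euclidean distance (square root of the summed per-dimension squares).

```scala
// Hypothetical stand-in for QuadTree.Node: an axis-aligned box given by center and width.
final case class Box(center: Array[Double], width: Array[Double]) {

  // True if, in every dimension, the interval [q - radius, q + radius]
  // intersects the open extent of the box.
  def overlap(q: Array[Double], radius: Double): Boolean =
    q.indices.forall { i =>
      (q(i) - radius < center(i) + width(i) / 2) &&
        (q(i) + radius > center(i) - width(i) / 2)
    }

  // A point is contained when it overlaps the box with radius zero.
  def contains(q: Array[Double]): Boolean = overlap(q, 0.0)

  // Distance from q to the nearest point of the box (0.0 if q lies inside).
  // Assumption: Euclidean metric; the committed code may branch on the metric.
  def minDist(q: Array[Double]): Double = {
    val squared = q.indices.map { i =>
      val lo = center(i) - width(i) / 2
      val hi = center(i) + width(i) / 2
      if (q(i) < lo) math.pow(q(i) - lo, 2)
      else if (q(i) > hi) math.pow(q(i) - hi, 2)
      else 0.0
    }.sum
    math.sqrt(squared)
  }

  // The box is near q when its closest point lies strictly within radius.
  def isNear(q: Array[Double], radius: Double): Boolean = minDist(q) < radius
}
```

For the root box used in QuadTreeSuite (center (0, 0), width (2, 1)), a query at (2, 0) lies one unit outside the right face, so `isNear` holds for any radius greater than 1.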

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
index 63e412a..ac30c3f 100644
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
@@ -99,8 +99,7 @@ class KNNITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
       // run knn join
       knn.fit(trainingSet)
-      val result = knn.predict(testingSet).collect()
-
+      knn.predict(testingSet).collect()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
index 8be5c6e..a3a415d 100644
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
@@ -24,17 +24,18 @@ import org.apache.flink.ml.math.{Vector, DenseVector}
 
 import org.scalatest.{Matchers, FlatSpec}
 
-/** Test of Quadtree class
-  * Constructor for the Quadtree class:
-  * class QuadTree(minVec:ListBuffer[Double], maxVec:ListBuffer[Double])
+/** Tests of [[QuadTree]] class
   *
+  * Constructor for the [[QuadTree]] class:
+  * {{{
+  * class QuadTree(minVec: ListBuffer[Double], maxVec: ListBuffer[Double])
+  * }}}
   */
 
 class QuadTreeSuite extends FlatSpec with Matchers with FlinkTestBase {
-  behavior of "The QuadTree Class"
+  behavior of "QuadTree Class"
 
   it should "partition into equal size sub-boxes and search for nearby objects properly" in {
-
     val minVec = DenseVector(-1.0, -0.5)
     val maxVec = DenseVector(1.0, 0.5)
 
@@ -44,26 +45,18 @@ class QuadTreeSuite extends FlatSpec with Matchers with FlinkTestBase {
     myTree.insert(DenseVector(-0.20, 0.31).asInstanceOf[Vector])
     myTree.insert(DenseVector(-0.21, 0.29).asInstanceOf[Vector])
 
-    var a = myTree.root.getCenterWidth()
-
-    /** Tree will partition once the 4th point is added
-      */
-
+    /* Tree will partition once the 4th point is added */
     myTree.insert(DenseVector(0.2, 0.27).asInstanceOf[Vector])
     myTree.insert(DenseVector(0.2, 0.26).asInstanceOf[Vector])
-
     myTree.insert(DenseVector(-0.21, 0.289).asInstanceOf[Vector])
     myTree.insert(DenseVector(-0.1, 0.289).asInstanceOf[Vector])
-
     myTree.insert(DenseVector(0.7, 0.45).asInstanceOf[Vector])
 
-    /**
-     * Exact values of (centers,dimensions) of root + children nodes, to test
+    /* Exact values of (centers,dimensions) of root + children nodes, to test
      * partitionBox and makeChildren methods; exact values are given to avoid
      * essentially copying and pasting the code to automatically generate them
      * from minVec/maxVec
      */
-
     val knownCentersLengths = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)),
       (DenseVector(-0.5, -0.25), DenseVector(1.0, 0.5)),
       (DenseVector(-0.5, 0.25), DenseVector(1.0, 0.5)),
@@ -71,30 +64,23 @@ class QuadTreeSuite extends FlatSpec with Matchers with FlinkTestBase {
       (DenseVector(0.5, 0.25), DenseVector(1.0, 0.5))
     )
 
-    /**
-     * (centers,dimensions) computed from QuadTree.makeChildren
-     */
-
+    /* (centers,dimensions) computed from QuadTree.makeChildren */
     var computedCentersLength = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)))
     for (child <- myTree.root.children) {
       computedCentersLength += child.getCenterWidth().asInstanceOf[(DenseVector, DenseVector)]
     }
 
-
-    /**
-     * Tests search for nearby neighbors, make sure the right object is contained in neighbor
-      * search the neighbor search will contain more points
+    /* Tests search for nearby neighbors, make sure the right object is contained in neighbor
+     * search the neighbor search will contain more points
+     * search the neighbor search will contain more points
      */
     val neighborsComputed = myTree.searchNeighbors(DenseVector(0.7001, 0.45001), 0.001)
     val isNeighborInSearch = neighborsComputed.contains(DenseVector(0.7, 0.45))
 
-    /**
-     * Test ability to get all objects in minimal bounding box + objects in siblings' block method
+    /* Test ability to get all objects in minimal bounding box + objects in siblings' block method
      * In this case, drawing a picture of the QuadTree shows that
      * (-0.2, 0.31), (-0.21, 0.29), (-0.21, 0.289)
      * are objects near (-0.2001, 0.31001)
      */
-
     val siblingsObjectsComputed = myTree.searchNeighborsSiblingQueue(DenseVector(-0.2001, 0.31001))
     val isSiblingsInSearch = siblingsObjectsComputed.contains(DenseVector(-0.2, 0.31)) &&
       siblingsObjectsComputed.contains(DenseVector(-0.21, 0.29)) &&
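
Reviewer's note: the child centers listed in `knownCentersLengths` follow from the `partitionBox` scheme of shifting the center by a quarter of the width in each dimension. The sketch below reproduces that idea under assumptions: `PartitionSketch` and `childCenters` are hypothetical names, and `Array[Double]` replaces FlinkML's `Vector`. It is an illustration of the technique, not the code under test.

```scala
import scala.annotation.tailrec

object PartitionSketch {
  // Returns the 2^n centers of the equal-size sub-boxes of an n-dimensional box.
  def childCenters(center: Array[Double], width: Array[Double]): Seq[Array[Double]] = {
    @tailrec
    def helper(box: Seq[Array[Double]], dim: Int): Seq[Array[Double]] =
      if (dim >= width.length) {
        box
      } else {
        // Split every center found so far into two copies shifted along `dim`.
        val next = box.flatMap { c =>
          val down = c.clone()
          down(dim) -= width(dim) / 4
          val up = c.clone()
          up(dim) += width(dim) / 4
          Seq(down, up)
        }
        helper(next, dim + 1)
      }

    helper(Seq(center), 0)
  }
}
```

With center (0, 0) and width (2, 1), this yields the four centers (±0.5, ±0.25), the same values that the test hard-codes next to the child width (1.0, 0.5).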

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
deleted file mode 100644
index 6d563e9..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
+++ /dev/null
@@ -1,345 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.nn
-
-import org.apache.flink.api.common.operators.Order
-import org.apache.flink.api.common.typeinfo.TypeInformation
-//import org.apache.flink.api.scala.DataSetUtils._
-import org.apache.flink.api.scala.utils._
-import org.apache.flink.api.scala._
-import org.apache.flink.ml.common._
-import org.apache.flink.ml.math.{Vector => FlinkVector, DenseVector}
-import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
-DistanceMetric, EuclideanDistanceMetric}
-import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, Predictor}
-import org.apache.flink.util.Collector
-import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
-
-import org.apache.flink.ml.nn.util.QuadTree
-
-import scala.collection.immutable.Vector
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
-/** Implements a k-nearest neighbor join.
-  *
-  * Calculates the `k` nearest neighbor points in the training set for each point in the test set.
-  *
-  * @example
-  * {{{
-  *       val trainingDS: DataSet[Vector] = ...
-  *       val testingDS: DataSet[Vector] = ...
-  *
-  *       val knn = KNN()
-  *         .setK(10)
-  *         .setBlocks(5)
-  *         .setDistanceMetric(EuclideanDistanceMetric())
-  *
-  *       knn.fit(trainingDS)
-  *
-  *       val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS)
-  * }}}
-  *
-  * =Parameters=
-  *
-  * - [[org.apache.flink.ml.nn.KNN.K]]
-  * Sets the K which is the number of selected points as neighbors. (Default value: '''5''')
-  *
-  * - [[org.apache.flink.ml.nn.KNN.Blocks]]
-  * Sets the number of blocks into which the input data will be split. This number should be set
-  * at least to the degree of parallelism. If no value is specified, then the parallelism of the
-  * input [[DataSet]] is used as the number of blocks. (Default value: '''None''')
-  *
-  * - [[org.apache.flink.ml.nn.KNN.DistanceMetric]]
-  * Sets the distance metric we use to calculate the distance between two points. If no metric is
-  * specified, then [[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] is used.
-  * (Default value: '''EuclideanDistanceMetric()''')
-  *
-  */
-
-class KNN extends Predictor[KNN] {
-
-  import KNN._
-
-  var trainingSet: Option[DataSet[Block[FlinkVector]]] = None
-
-  /** Sets K
-    * @param k the number of selected points as neighbors
-    */
-  def setK(k: Int): KNN = {
-    require(k > 0, "K must be positive.")
-    parameters.add(K, k)
-    this
-  }
-
-  /** Sets the distance metric
-    * @param metric the distance metric to calculate distance between two points
-    */
-  def setDistanceMetric(metric: DistanceMetric): KNN = {
-    parameters.add(DistanceMetric, metric)
-    this
-  }
-
-  /** Sets the number of data blocks/partitions
-    * @param n the number of data blocks
-    */
-  def setBlocks(n: Int): KNN = {
-    require(n > 0, "Number of blocks must be positive.")
-    parameters.add(Blocks, n)
-    this
-  }
-
-  /**
-   * Sets the Boolean variable that decides whether to use the QuadTree or not
-   */
-  def setUseQuadTree(UseQuadTree: Boolean): KNN = {
-    parameters.add(UseQuadTreeParam, UseQuadTree)
-    this
-  }
-
-  /**
-   * Parameter a user can specify if one of the training or test sets are small
-   * @param sizeHint
-   * @return
-   */
-  def setSizeHint(sizeHint: CrossHint): KNN = {
-    parameters.add(SizeHint, sizeHint)
-    this
-  }
-
-}
-
-object KNN {
-
-  case object K extends Parameter[Int] {
-    val defaultValue: Option[Int] = Some(5)
-  }
-
-  case object DistanceMetric extends Parameter[DistanceMetric] {
-    val defaultValue: Option[DistanceMetric] = Some(EuclideanDistanceMetric())
-  }
-
-  case object Blocks extends Parameter[Int] {
-    val defaultValue: Option[Int] = None
-  }
-
-  case object UseQuadTreeParam extends Parameter[Boolean] {
-    val defaultValue: Option[Boolean] = None
-  }
-
-  case object SizeHint extends Parameter[CrossHint] {
-    val defaultValue: Option[CrossHint] = None
-  }
-
-  def apply(): KNN = {
-    new KNN()
-  }
-
-  /** [[FitOperation]] which trains a KNN based on the given training data set.
-    * @tparam T Subtype of [[org.apache.flink.ml.math.Vector]]
-    */
-  implicit def fitKNN[T <: FlinkVector : TypeInformation] = new FitOperation[KNN, T] {
-    override def fit(
-      instance: KNN,
-      fitParameters: ParameterMap,
-      input: DataSet[T]): Unit = {
-      val resultParameters = instance.parameters ++ fitParameters
-
-      require(resultParameters.get(K).isDefined, "K is needed for calculation")
-
-      val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
-      val partitioner = FlinkMLTools.ModuloKeyPartitioner
-      val inputAsVector = input.asInstanceOf[DataSet[FlinkVector]]
-
-      instance.trainingSet = Some(FlinkMLTools.block(inputAsVector, blocks, Some(partitioner)))
-    }
-  }
-
-  /** [[PredictDataSetOperation]] which calculates k-nearest neighbors of the given testing data
-    * set.
-    * @tparam T Subtype of [[Vector]]
-    * @return The given testing data set with k-nearest neighbors
-    */
-  implicit def predictValues[T <: FlinkVector : ClassTag : TypeInformation] = {
-    new PredictDataSetOperation[KNN, T, (FlinkVector, Array[FlinkVector])] {
-      override def predictDataSet(
-        instance: KNN,
-        predictParameters: ParameterMap,
-        input: DataSet[T]): DataSet[(FlinkVector,
-        Array[FlinkVector])] = {
-        val resultParameters = instance.parameters ++ predictParameters
-
-        instance.trainingSet match {
-          case Some(trainingSet) =>
-            val k = resultParameters.get(K).get
-            val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
-            val metric = resultParameters.get(DistanceMetric).get
-            val partitioner = FlinkMLTools.ModuloKeyPartitioner
-
-            // attach unique id for each data
-            val inputWithId: DataSet[(Long, T)] = input.zipWithUniqueId
-
-            // split data into multiple blocks
-            val inputSplit = FlinkMLTools.block(inputWithId, blocks, Some(partitioner))
-
-            val sizeHint = resultParameters.get(SizeHint)
-            val crossTuned = sizeHint match {
-              case Some(hint) if hint == CrossHint.FIRST_IS_SMALL =>
-                trainingSet.crossWithHuge(inputSplit)
-              case Some(hint) if hint == CrossHint.SECOND_IS_SMALL =>
-                trainingSet.crossWithTiny(inputSplit)
-              case _ => trainingSet.cross(inputSplit)
-            }
-
-            // join input and training set
-            val crossed = crossTuned.mapPartition {
-              (iter, out: Collector[(FlinkVector, FlinkVector, Long, Double)]) => {
-                for ((training, testing) <- iter) {
-                  val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, Double)]()(
-                    Ordering.by(_._4))
-
-                  // use a quadtree if (4^dim)Ntest*log(Ntrain)
-                  // < Ntest*Ntrain, and distance is Euclidean
-                  val useQuadTree = resultParameters.get(UseQuadTreeParam).getOrElse(
-                    training.values.head.size + math.log(math.log(training.values.length) /
-                      math.log(4.0)) < math.log(training.values.length) / math.log(4.0) &&
-                      (metric.isInstanceOf[EuclideanDistanceMetric] ||
-                        metric.isInstanceOf[SquaredEuclideanDistanceMetric]))
-
-                  if (useQuadTree) {
-                    if (metric.isInstanceOf[EuclideanDistanceMetric] ||
-                      metric.isInstanceOf[SquaredEuclideanDistanceMetric]){
-                      knnQueryWithQuadTree(training.values, testing.values, k, metric, queue, out)
-                    } else {
-                      throw new IllegalArgumentException(s" Error: metric must be" +
-                        s" Euclidean or SquaredEuclidean!")
-                    }
-                  } else {
-                    knnQueryBasic(training.values, testing.values, k, metric, queue, out)
-                  }
-                }
-              }
-            }
-
-            // group by input vector id and pick k nearest neighbor for each group
-            val result = crossed.groupBy(2).sortGroup(3, Order.ASCENDING).reduceGroup {
-              (iter, out: Collector[(FlinkVector, Array[FlinkVector])]) => {
-                if (iter.hasNext) {
-                  val head = iter.next()
-                  val key = head._2
-                  val neighbors: ArrayBuffer[FlinkVector] = ArrayBuffer(head._1)
-
-                  for ((vector, _, _, _) <- iter.take(k - 1)) {
-                    // we already took a first element
-                    neighbors += vector
-                  }
-
-                  out.collect(key, neighbors.toArray)
-                }
-              }
-            }
-
-            result
-          case None => throw new RuntimeException("The KNN model has not been trained." +
-            "Call first fit before calling the predict operation.")
-
-        }
-      }
-    }
-  }
-
-  def knnQueryWithQuadTree[T <: FlinkVector](
-    training: Vector[T],
-    testing: Vector[(Long, T)],
-    k: Int, metric: DistanceMetric,
-    queue: mutable.PriorityQueue[(FlinkVector,
-      FlinkVector, Long, Double)],
-    out: Collector[(FlinkVector,
-      FlinkVector, Long, Double)]) {
-    /// find a bounding box
-    val MinArr = Array.tabulate(training.head.size)(x => x)
-    val MaxArr = Array.tabulate(training.head.size)(x => x)
-
-    val minVecTrain = MinArr.map(i => training.map(x => x(i)).min - 0.01)
-    val minVecTest = MinArr.map(i => testing.map(x => x._2(i)).min - 0.01)
-    val maxVecTrain = MaxArr.map(i => training.map(x => x(i)).min + 0.01)
-    val maxVecTest = MaxArr.map(i => testing.map(x => x._2(i)).min + 0.01)
-
-    val MinVec = DenseVector(MinArr.map(i => Array(minVecTrain(i), minVecTest(i)).min))
-    val MaxVec = DenseVector(MinArr.map(i => Array(maxVecTrain(i), maxVecTest(i)).max))
-
-    //default value of max elements/box is set to max(20,k)
-    val maxPerBox = Array(k, 20).max
-    val trainingQuadTree = new QuadTree(MinVec, MaxVec, metric, maxPerBox)
-
-    for (v <- training) {
-      trainingQuadTree.insert(v)
-    }
-
-    for ((id, vector) <- testing) {
-      //  Find siblings' objects and do local kNN there
-      val siblingObjects =
-        trainingQuadTree.searchNeighborsSiblingQueue(vector)
-
-      // do KNN query on siblingObjects and get max distance of kNN
-      // then rad is good choice for a neighborhood to do a refined
-      // local kNN search
-      val knnSiblings = siblingObjects.map(v => metric.distance(vector, v)
-      ).sortWith(_ < _).take(k)
-
-      val rad = knnSiblings.last
-      val trainingFiltered = trainingQuadTree.searchNeighbors(vector, rad)
-
-      for (b <- trainingFiltered) {
-        // (training vector, input vector, input key, distance)
-        queue.enqueue((b, vector, id, metric.distance(b, vector)))
-        if (queue.size > k) {
-          queue.dequeue()
-        }
-      }
-      for (v <- queue) {
-        out.collect(v)
-      }
-    }
-  }
-
-  def knnQueryBasic[T <: FlinkVector](
-    training: Vector[T],
-    testing: Vector[(Long, T)],
-    k: Int, metric: DistanceMetric,
-    queue: mutable.PriorityQueue[(FlinkVector,
-      FlinkVector, Long, Double)],
-    out: Collector[(FlinkVector, FlinkVector, Long, Double)]) {
-
-    for ((id, vector) <- testing) {
-      for (b <- training) {
-        // (training vector, input vector, input key, distance)
-        queue.enqueue((b, vector, id, metric.distance(b, vector)))
-        if (queue.size > k) {
-          queue.dequeue()
-        }
-      }
-      for (v <- queue) {
-        out.collect(v)
-      }
-    }
-  }
-
-}
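
Reviewer's note: the removed `knnQueryBasic` keeps a priority queue ordered by distance and dequeues whenever the size exceeds `k`, so only the `k` closest candidates survive. The sketch below shows that bounded-heap idea for a single query point. It is an illustration under assumptions: `KnnSketch`, `euclidean` and `kNearest` are hypothetical names, `Array[Double]` replaces FlinkML's `Vector`, and a fresh queue is created per query point (the removed code shares one queue across the test points of a block).

```scala
import scala.collection.mutable

object KnnSketch {
  // Euclidean distance between two points of equal dimension.
  def euclidean(a: Array[Double], b: Array[Double]): Double =
    math.sqrt(a.indices.map(i => math.pow(a(i) - b(i), 2)).sum)

  // Returns up to k training points closest to `query`, nearest first.
  def kNearest(
      training: Seq[Array[Double]],
      query: Array[Double],
      k: Int): Seq[Array[Double]] = {
    require(k > 0, "K must be positive.")

    // Max-heap on distance: the head is the farthest candidate kept so far.
    val byDistance = Ordering.fromLessThan[(Array[Double], Double)](_._2 < _._2)
    val queue = mutable.PriorityQueue.empty[(Array[Double], Double)](byDistance)

    for (p <- training) {
      queue.enqueue((p, euclidean(p, query)))
      if (queue.size > k) {
        queue.dequeue() // drop the farthest candidate
      }
    }

    queue.toSeq.sortWith(_._2 < _._2).map(_._1)
  }
}
```

Bounding the heap at `k` elements keeps memory per query at O(k) and each insertion at O(log k), which is why the brute-force path stays usable for small blocks.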

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
deleted file mode 100644
index 0b37313..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
+++ /dev/null
@@ -1,344 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.nn.util
-
-import org.apache.flink.ml.math.{Breeze, Vector}
-import Breeze._
-
-import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
-EuclideanDistanceMetric, DistanceMetric}
-
-import scala.collection.mutable.ListBuffer
-import scala.collection.mutable.PriorityQueue
-
-/**
- * n-dimensional QuadTree data structure; partitions
- * spatial data for faster queries (e.g. KNN query)
- * The skeleton of the data structure was initially
- * based off of the 2D Quadtree found here:
- * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
- *
- * Many additional methods were added to the class both for
- * efficient KNN queries and generalizing to n-dim.
- *
- * @param minVec vector of the corner of the bounding box with smallest coordinates
- * @param maxVec vector of the corner of the bounding box with smallest coordinates
- * @param distMetric metric, must be Euclidean or squareEuclidean
- * @param maxPerBox threshold for number of points in each box before slitting a box
- */
-class QuadTree(
-  minVec: Vector,
-  maxVec: Vector,
-  distMetric: DistanceMetric,
-  maxPerBox: Int) {
-
-  class Node(
-    center: Vector,
-    width: Vector,
-    var children: Seq[Node]) {
-
-    val nodeElements = new ListBuffer[Vector]
-
-    /** for testing purposes only; used in QuadTreeSuite.scala
-      *
-      * @return center and width of the box
-      */
-    def getCenterWidth(): (Vector, Vector) = {
-      (center, width)
-    }
-
-    def contains(queryPoint: Vector): Boolean = {
-      overlap(queryPoint, 0.0)
-    }
-
-    /** Tests if queryPoint is within a radius of the node
-      *
-      * @param queryPoint
-      * @param radius
-      * @return
-      */
-    def overlap(
-      queryPoint: Vector,
-      radius: Double): Boolean = {
-      val count = (0 until queryPoint.size).filter { i =>
-        (queryPoint(i) - radius < center(i) + width(i) / 2) &&
-          (queryPoint(i) + radius > center(i) - width(i) / 2)
-      }.size
-
-      count == queryPoint.size
-    }
-
-    /** Tests if queryPoint is near a node
-      *
-      * @param queryPoint
-      * @param radius
-      * @return
-      */
-    def isNear(
-      queryPoint: Vector,
-      radius: Double): Boolean = {
-      minDist(queryPoint) < radius
-    }
-
-    /**
-     * minDist is defined so that every point in the box
-     * has distance to queryPoint greater than minDist
-     * (minDist adopted from "Nearest Neighbors Queries" by N. Roussopoulos et al.)
-     *
-     * @param queryPoint
-     * @return
-     */
-    def minDist(queryPoint: Vector): Double = {
-      val minDist = (0 until queryPoint.size).map { i =>
-        if (queryPoint(i) < center(i) - width(i) / 2) {
-          math.pow(queryPoint(i) - center(i) + width(i) / 2, 2)
-        } else if (queryPoint(i) > center(i) + width(i) / 2) {
-          math.pow(queryPoint(i) - center(i) - width(i) / 2, 2)
-        } else {
-          0
-        }
-      }.sum
-
-      distMetric match {
-        case _: SquaredEuclideanDistanceMetric => minDist
-        case _: EuclideanDistanceMetric => math.sqrt(minDist)
-        case _ => throw new IllegalArgumentException(s" Error: metric must be" +
-          s" Euclidean or SquaredEuclidean!")
-      }
-    }
-
-    /**
-     * Finds which child queryPoint lies in.  node.children is a Seq[Node], and
-     * whichChild finds the appropriate index of that Seq.
-     * @param queryPoint
-     * @return
-     */
-    def whichChild(queryPoint: Vector): Int = {
-      (0 until queryPoint.size).map { i =>
-        if (queryPoint(i) > center(i)) {
-          Math.pow(2, queryPoint.size - 1 - i).toInt
-        } else {
-          0
-        }
-      }.sum
-    }
-
-    def makeChildren() {
-      val centerClone = center.copy
-      val cPart = partitionBox(centerClone, width)
-      val mappedWidth = 0.5 * width.asBreeze
-      children = cPart.map(p => new Node(p, mappedWidth.fromBreeze, null))
-    }
-
-    /**
-     * Recursive function that partitions a n-dim box by taking the (n-1) dimensional
-     * plane through the center of the box keeping the n-th coordinate fixed,
-     * then shifting it in the n-th direction up and down
-     * and recursively applying partitionBox to the two shifted (n-1) dimensional planes.
-     *
-     * @param center the center of the box
-     * @param width a vector of lengths of each dimension of the box
-     * @return
-     */
-    def partitionBox(
-      center: Vector,
-      width: Vector): Seq[Vector] = {
-      def partitionHelper(
-        box: Seq[Vector],
-        dim: Int): Seq[Vector] = {
-        if (dim >= width.size) {
-          box
-        } else {
-          val newBox = box.flatMap {
-            vector =>
-              val (up, down) = (vector.copy, vector)
-              up.update(dim, up(dim) - width(dim) / 4)
-              down.update(dim, down(dim) + width(dim) / 4)
-
-              Seq(up, down)
-          }
-          partitionHelper(newBox, dim + 1)
-        }
-      }
-      partitionHelper(Seq(center), 0)
-    }
-  }
-
-
-  val root = new Node(((minVec.asBreeze + maxVec.asBreeze) * 0.5).fromBreeze,
-    (maxVec.asBreeze - minVec.asBreeze).fromBreeze, null)
-
-  /**
-   * simple printing of tree for testing/debugging
-   */
-  def printTree(): Unit = {
-    printTreeRecur(root)
-  }
-
-  def printTreeRecur(node: Node) {
-    if (node.children != null) {
-      for (c <- node.children) {
-        printTreeRecur(c)
-      }
-    } else {
-      println("printing tree: n.nodeElements " + node.nodeElements)
-    }
-  }
-
-  /**
-   * Recursively adds an object to the tree
-   * @param queryPoint
-   */
-  def insert(queryPoint: Vector) {
-    insertRecur(queryPoint, root)
-  }
-
-  private def insertRecur(
-    queryPoint: Vector,
-    node: Node) {
-    if (node.children == null) {
-      if (node.nodeElements.length < maxPerBox) {
-        node.nodeElements += queryPoint
-      } else {
-        node.makeChildren()
-        for (o <- node.nodeElements) {
-          insertRecur(o, node.children(node.whichChild(o)))
-        }
-        node.nodeElements.clear()
-        insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
-      }
-    } else {
-      insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
-    }
-  }
-
-  /**
-   * Used to zoom in on a region near a test point for a fast KNN query.
-   * This capability is used in the KNN query to find k "near" neighbors n_1,...,n_k, from
-   * which one computes the max distance D_s to queryPoint.  D_s is then used during the
-   * kNN query to find all points within a radius D_s of queryPoint using searchNeighbors.
-   * To find the "near" neighbors, a min-heap is defined on the leaf nodes of the
-   * minimal bounding box of the queryPoint. The priority of a leaf node
-   * is an appropriate notion of the distance between the test point and the node,
-   * which is defined by minDist(queryPoint).
-   *
-   * @param queryPoint a test point for which the method finds the minimal bounding
-   *                   box that queryPoint lies in and returns elements in that box's
-   *                   siblings' leaf nodes
-   * @return
-   */
-  def searchNeighborsSiblingQueue(queryPoint: Vector): ListBuffer[Vector] = {
-    val ret = new ListBuffer[Vector]
-    // edge case when the main box has not been partitioned at all
-    if (root.children == null) {
-      root.nodeElements.clone()
-    } else {
-      val nodeQueue = new PriorityQueue[(Double, Node)]()(Ordering.by(x => x._1))
-      searchRecurSiblingQueue(queryPoint, root, nodeQueue)
-
-      var count = 0
-      while (count < maxPerBox) {
-        val dq = nodeQueue.dequeue()
-        if (dq._2.nodeElements.nonEmpty) {
-          ret ++= dq._2.nodeElements
-          count += dq._2.nodeElements.length
-        }
-      }
-      ret
-    }
-  }
-
-  /**
-   *
-   * @param queryPoint point under consideration
-   * @param node node that queryPoint lies in
-   * @param nodeQueue defined in searchNeighborsSiblingQueue, this stores nodes based on their
-   *                  distance to node as defined by minDist
-   */
-  private def searchRecurSiblingQueue(
-    queryPoint: Vector,
-    node: Node,
-    nodeQueue: PriorityQueue[(Double, Node)]) {
-    if (node.children != null) {
-      for (child <- node.children; if child.contains(queryPoint)) {
-        if (child.children == null) {
-          for (c <- node.children) {
-            minNodes(queryPoint, c, nodeQueue)
-          }
-        } else {
-          searchRecurSiblingQueue(queryPoint, child, nodeQueue)
-        }
-      }
-    }
-  }
-
-  /**
-   * Goes down to minimal bounding box of queryPoint, and adds elements to nodeQueue
-   *
-   * @param queryPoint point under consideration
-   * @param node node that queryPoint lies in
-   * @param nodeQueue PriorityQueue that stores all points in minimal bounding box of queryPoint
-   */
-  private def minNodes(
-    queryPoint: Vector,
-    node: Node,
-    nodeQueue: PriorityQueue[(Double, Node)]) {
-    if (node.children == null) {
-      nodeQueue += ((-node.minDist(queryPoint), node))
-    } else {
-      for (c <- node.children) {
-        minNodes(queryPoint, c, nodeQueue)
-      }
-    }
-  }
-
-  /** Finds all objects within a neighborhood of queryPoint of a specified radius
-    * scope is modified from original 2D version in:
-    * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
-    *
-    * original version only looks in minimal box; for the KNN Query, we look at
-    * all nearby boxes. The radius is determined from searchNeighborsSiblingQueue
-    * by defining a min-heap on the leaf nodes
-    *
-    * @param queryPoint
-    * @param radius
-    * @return all points within queryPoint with given radius
-    */
-  def searchNeighbors(
-    queryPoint: Vector,
-    radius: Double): ListBuffer[Vector] = {
-    val ret = new ListBuffer[Vector]
-    searchRecur(queryPoint, radius, root, ret)
-    ret
-  }
-
-  private def searchRecur(
-    queryPoint: Vector,
-    radius: Double,
-    node: Node,
-    ret: ListBuffer[Vector]) {
-    if (node.children == null) {
-      ret ++= node.nodeElements
-    } else {
-      for (child <- node.children; if child.isNear(queryPoint, radius)) {
-        searchRecur(queryPoint, radius, child, ret)
-      }
-    }
-  }
-}
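
Note on the removed minDist method: the following is a minimal standalone sketch of the point-to-box distance it computes, assuming plain Array[Double] values in place of Flink's Vector type. The names MinDistSketch and minDistSquared are hypothetical and are not part of FlinkML; only the squared Euclidean case is shown, and the Euclidean case would take the square root of the result.

```scala
object MinDistSketch {
  /** Squared distance from `query` to the closest point of an axis-aligned box.
    *
    * The box is described the same way as in the removed Node class:
    * by its center and by its full width along each dimension.
    * (Array[Double] is an assumption made to keep the sketch dependency-free.)
    */
  def minDistSquared(
      query: Array[Double],
      center: Array[Double],
      width: Array[Double]): Double = {
    require(query.length == center.length && center.length == width.length,
      "query, center and width must have the same dimension")

    query.indices.map { i =>
      val lo = center(i) - width(i) / 2 // lower face of the box in dimension i
      val hi = center(i) + width(i) / 2 // upper face of the box in dimension i
      // Per-dimension gap: zero when the coordinate falls inside [lo, hi].
      val gap =
        if (query(i) < lo) lo - query(i)
        else if (query(i) > hi) query(i) - hi
        else 0.0
      gap * gap
    }.sum
  }
}
```

A point inside the box yields 0, so a radius test of the form minDist < radius (as in isNear) accepts every box that contains the query point whenever the radius is positive.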

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
deleted file mode 100644
index 350af95..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.nn
-
-import org.apache.flink.api.scala._
-import org.apache.flink.ml.classification.Classification
-import org.apache.flink.ml.math.DenseVector
-import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
-import org.apache.flink.test.util.FlinkTestBase
-import org.scalatest.{FlatSpec, Matchers}
-
-class KNNITSuite extends FlatSpec with Matchers with FlinkTestBase {
-  behavior of "The KNN Join Implementation"
-
-  it should "throw an exception when the given K is not valid" in {
-    intercept[IllegalArgumentException] {
-      KNN().setK(0)
-    }
-  }
-
-  it should "throw an exception when the given count of blocks is not valid" in {
-    intercept[IllegalArgumentException] {
-      KNN().setBlocks(0)
-    }
-  }
-
-  it should "calculate kNN join correctly" in {
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    // prepare data
-    val trainingSet = env.fromCollection(Classification.trainingData).map(_.vector)
-    val testingSet = env.fromElements(DenseVector(0.0, 0.0))
-
-    // calculate answer
-    val answer = Classification.trainingData.map {
-      v => (v.vector, SquaredEuclideanDistanceMetric().distance(DenseVector(0.0, 0.0), v.vector))
-    }.sortBy(_._2).take(3).map(_._1).toArray
-
-    val knn = KNN()
-      .setK(3)
-      .setBlocks(10)
-      .setDistanceMetric(SquaredEuclideanDistanceMetric())
-      .setUseQuadTree(true)
-
-    // run knn join
-    knn.fit(trainingSet)
-    val result = knn.predict(testingSet).collect()
-
-    result.size should be(1)
-    result.head._1 should be(DenseVector(0.0, 0.0))
-    result.head._2 should be(answer)
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/035f6296/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
deleted file mode 100644
index 9b84a80..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric
-import org.apache.flink.ml.nn.util.QuadTree
-import org.apache.flink.test.util.FlinkTestBase
-import org.apache.flink.ml.math.{Breeze, Vector, DenseVector}
-
-import org.scalatest.{Matchers, FlatSpec}
-
-/** Test of Quadtree class
-  * Constructor for the Quadtree class:
-  * class QuadTree(minVec:ListBuffer[Double], maxVec:ListBuffer[Double])
-  *
-  */
-
-class QuadTreeSuite extends FlatSpec with Matchers with FlinkTestBase {
-  behavior of "The QuadTree Class"
-
-  it should "partition into equal size sub-boxes and search for nearby objects properly" in {
-
-    val minVec = DenseVector(-1.0, -0.5)
-    val maxVec = DenseVector(1.0, 0.5)
-
-    val myTree = new QuadTree(minVec, maxVec, EuclideanDistanceMetric(), 3)
-
-    myTree.insert(DenseVector(-0.25, 0.3).asInstanceOf[Vector])
-    myTree.insert(DenseVector(-0.20, 0.31).asInstanceOf[Vector])
-    myTree.insert(DenseVector(-0.21, 0.29).asInstanceOf[Vector])
-
-    var a = myTree.root.getCenterWidth()
-
-    /** Tree will partition once the 4th point is added
-      */
-
-    myTree.insert(DenseVector(0.2, 0.27).asInstanceOf[Vector])
-    myTree.insert(DenseVector(0.2, 0.26).asInstanceOf[Vector])
-
-    myTree.insert(DenseVector(-0.21, 0.289).asInstanceOf[Vector])
-    myTree.insert(DenseVector(-0.1, 0.289).asInstanceOf[Vector])
-
-    myTree.insert(DenseVector(0.7, 0.45).asInstanceOf[Vector])
-
-    /**
-     * Exact values of (centers,dimensions) of root + children nodes, to test
-     * partitionBox and makeChildren methods; exact values are given to avoid
-     * essentially copying and pasting the code to automatically generate them
-     * from minVec/maxVec
-     */
-
-    val knownCentersLengths = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)),
-      (DenseVector(-0.5, -0.25), DenseVector(1.0, 0.5)),
-      (DenseVector(-0.5, 0.25), DenseVector(1.0, 0.5)),
-      (DenseVector(0.5, -0.25), DenseVector(1.0, 0.5)),
-      (DenseVector(0.5, 0.25), DenseVector(1.0, 0.5))
-    )
-
-    /**
-     * (centers,dimensions) computed from QuadTree.makeChildren
-     */
-
-    var computedCentersLength = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)))
-    for (child <- myTree.root.children) {
-      computedCentersLength += child.getCenterWidth().asInstanceOf[(DenseVector, DenseVector)]
-    }
-
-
-    /**
-     * Tests search for nearby neighbors, make sure the right object is contained in neighbor search
-     * the neighbor search will contain more points
-     */
-    val neighborsComputed = myTree.searchNeighbors(DenseVector(0.7001, 0.45001), 0.001)
-    val isNeighborInSearch = neighborsComputed.contains(DenseVector(0.7, 0.45))
-
-    /**
-     * Test ability to get all objects in minimal bounding box + objects in siblings' block method
-     * In this case, drawing a picture of the QuadTree shows that
-     * (-0.2, 0.31), (-0.21, 0.29), (-0.21, 0.289)
-     * are objects near (-0.2001, 0.31001)
-     */
-
-    val siblingsObjectsComputed = myTree.searchNeighborsSiblingQueue(DenseVector(-0.2001, 0.31001))
-    val isSiblingsInSearch = siblingsObjectsComputed.contains(DenseVector(-0.2, 0.31)) &&
-      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.29)) &&
-      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.289))
-
-    computedCentersLength should be(knownCentersLengths)
-    isNeighborInSearch should be(true)
-    isSiblingsInSearch should be(true)
-  }
-}
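
Note on the child centers asserted in the removed suite: the sketch below reproduces, under stated assumptions, how a box is split into 2^n children and how a point maps to a child index. It uses plain Array[Double] instead of Flink's Vector, and the names ChildIndexSketch, childCenters and whichChild are illustrative only. It mirrors the ordering used by partitionBox and whichChild in the removed QuadTree.scala (dimension 0 is the most significant bit, and a coordinate above the center sets the bit), but it is not the FlinkML implementation.

```scala
object ChildIndexSketch {
  /** Centers of the 2^n children of a box, ordered so that index bits follow
    * the dimensions: dimension 0 contributes the most significant bit.
    */
  def childCenters(center: Array[Double], width: Array[Double]): Seq[Array[Double]] =
    center.indices.foldLeft(Seq(center.clone())) { (boxes, dim) =>
      boxes.flatMap { c =>
        // Shift a copy of the current center by a quarter width in each direction.
        val lower = c.clone(); lower(dim) -= width(dim) / 4
        val upper = c.clone(); upper(dim) += width(dim) / 4
        Seq(lower, upper)
      }
    }

  /** Index into childCenters of the child whose box contains `query`. */
  def whichChild(query: Array[Double], center: Array[Double]): Int =
    query.indices.map { i =>
      if (query(i) > center(i)) 1 << (query.length - 1 - i) else 0
    }.sum
}
```

For the root box used in the suite (center (0.0, 0.0), width (2.0, 1.0)), the point (0.7, 0.45) maps to index 3, whose center is (0.5, 0.25), and the point (-0.25, 0.3) maps to index 1, whose center is (-0.5, 0.25).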
