Repository: flink
Updated Branches:
  refs/heads/master 1212b6d3f -> 035f62969


[FLINK-1745] [ml] Use QuadTree to speed up exact k-nearest-neighbor join


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/4a5af42c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/4a5af42c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/4a5af42c

Branch: refs/heads/master
Commit: 4a5af42c678a0437aa5614741280e8c5465b8cec
Parents: 858ca14
Author: danielblazevski <[email protected]>
Authored: Tue Sep 15 17:49:05 2015 -0400
Committer: Chiwan Park <[email protected]>
Committed: Mon May 30 19:32:26 2016 +0900

----------------------------------------------------------------------
 docs/libs/ml/knn.md                             | 145 ++++++++
 .../main/scala/org/apache/flink/ml/nn/KNN.scala | 353 +++++++++++++++++++
 .../scala/org/apache/flink/ml/nn/QuadTree.scala | 350 ++++++++++++++++++
 .../org/apache/flink/ml/nn/KNNITSuite.scala     | 108 ++++++
 .../org/apache/flink/ml/nn/QuadTreeSuite.scala  | 107 ++++++
 .../main/scala/org/apache/flink/ml/nn/KNN.scala | 207 +++++++++--
 .../scala/org/apache/flink/ml/nn/QuadTree.scala | 344 ++++++++++++++++++
 .../org/apache/flink/ml/nn/KNNITSuite.scala     |   7 +-
 .../org/apache/flink/ml/nn/QuadTreeSuite.scala  | 106 ++++++
 9 files changed, 1686 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/docs/libs/ml/knn.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/knn.md b/docs/libs/ml/knn.md
new file mode 100644
index 0000000..c9a7e03
--- /dev/null
+++ b/docs/libs/ml/knn.md
@@ -0,0 +1,145 @@
+---
+mathjax: include
+htmlTitle: FlinkML - k-nearest neighbors
+title: <a href="../ml">FlinkML</a> - knn
+---
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+* This will be replaced by the TOC
+{:toc}
+
+## Description
+Implements an exact k-nearest neighbors algorithm.  Given a training set $A$ and a testing set $B$, the algorithm returns
+
+$$
+KNN(A, B, k) = \{ \left( b, KNN(b, A, k) \right) \text{ where } b \in B \text{ and } KNN(b, A, k) \text{ are the k-nearest points to } b \text{ in } A \}
+$$
+
+The brute-force approach computes the distance between every training point and every testing point.  To avoid computing all of these distances, a quadtree is used to partition the training set.  The quadtree scales well in the number of training points, though poorly in the spatial dimension.  The algorithm automatically chooses whether or not to use the quadtree, though the user can override that decision by setting a parameter that forces the quadtree on or off.
+
+## Operations
+
+`KNN` is a `Predictor`. 
+As such, it supports the `fit` and `predict` operations.
+
+### Fit
+
+KNN is trained given a set of vectors; any subtype of FlinkML's `Vector` is accepted:
+
+* `fit[T <: Vector]: DataSet[T] => Unit`
+
+### Predict
+
+For all subtypes of FlinkML's `Vector`, KNN returns the k-nearest training points of each test point:
+
+* `predict[T <: Vector]: DataSet[T] => DataSet[(T, Array[Vector])]`, where the `(T, Array[Vector])` tuple
+  corresponds to (testPoint, K-nearest training points)
+
+## Parameters
+The KNN implementation can be controlled by the following parameters:
+
+   <table class="table table-bordered">
+    <thead>
+      <tr>
+        <th class="text-left" style="width: 20%">Parameters</th>
+        <th class="text-center">Description</th>
+      </tr>
+    </thead>
+
+    <tbody>
+      <tr>
+        <td><strong>K</strong></td>
+        <td>
+          <p>
+            Defines the number of nearest-neighbors to search for.  That is, for each test point, the algorithm finds the K-nearest neighbors in the training set.
+            (Default value: <strong>5</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>DistanceMetric</strong></td>
+        <td>
+          <p>
+            Sets the distance metric we use to calculate the distance between two points. If no metric is specified, then `org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric` is used.
+            (Default value: <strong>EuclideanDistanceMetric</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Blocks</strong></td>
+        <td>
+          <p>
+            Sets the number of blocks into which the input data will be split. This number should be set
+            at least to the degree of parallelism. If no value is specified, then the parallelism of the
+            input `DataSet` is used as the number of blocks.
+            (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>UseQuadTreeParam</strong></td>
+        <td>
+          <p>
+             A boolean variable that determines whether or not to use a Quadtree to partition the training set to potentially simplify the KNN search.  If no value is specified, the code will automatically decide whether or not to use a Quadtree.  Use of a Quadtree scales well with the number of training and testing points, though poorly with the dimension.
+            (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>SizeHint</strong></td>
+        <td>
+          <p>Specifies whether the training set or test set is small to optimize the cross product operation needed for the KNN search.  Set this to `CrossHint.FIRST_IS_SMALL` if the training set is small, or to `CrossHint.SECOND_IS_SMALL` if the test set is small.
+             (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+    </tbody>
+  </table>
+
+## Examples
+
+{% highlight scala %}
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.classification.Classification
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
+import org.apache.flink.ml.nn.KNN
+
+  val env = ExecutionEnvironment.getExecutionEnvironment
+
+  // prepare data
+  val trainingSet = env.fromCollection(Classification.trainingData).map(_.vector)
+  val testingSet = env.fromElements(DenseVector(0.0, 0.0))
+
+  val knn = KNN()
+    .setK(3)
+    .setBlocks(10)
+    .setDistanceMetric(SquaredEuclideanDistanceMetric())
+    .setUseQuadTree(false)
+    .setSizeHint(CrossHint.SECOND_IS_SMALL)
+
+  // run knn join
+  knn.fit(trainingSet)
+  val result = knn.predict(testingSet).collect()
+
+{% endhighlight %}
+
+For more details on computing KNN with and without a quadtree, see this presentation:
+http://danielblazevski.github.io/
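
Note on the definition of KNN(A, B, k) in the documentation above: the following is a minimal sketch in plain Scala, without Flink, of what the join computes by brute force. The object name `BruteForceKnn` and its helper names are illustrative assumptions and are not part of this commit; the commit itself works on blocked `DataSet`s.

```scala
object BruteForceKnn {
  // A point is a plain vector of coordinates (assumption: all points share one dimension).
  type Point = Vector[Double]

  // Squared Euclidean distance; ranking by it gives the same order as the Euclidean distance.
  def squaredDistance(a: Point, b: Point): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // KNN(b, A, k): the k points of the training set A that are closest to b.
  def knn(b: Point, training: Seq[Point], k: Int): Seq[Point] =
    training.sortBy(a => squaredDistance(a, b)).take(k)

  // KNN(A, B, k): every test point b paired with its k nearest training points.
  def knnJoin(training: Seq[Point], testing: Seq[Point], k: Int): Seq[(Point, Seq[Point])] =
    testing.map(b => (b, knn(b, training, k)))
}
```

Sorting every training point for every test point is the cost the quadtree is meant to reduce; the sketch is only a reference for what the result should contain.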

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala 
b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
new file mode 100644
index 0000000..82f4b88
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala.utils._
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.common._
+import org.apache.flink.ml.math.{Vector => FlinkVector, DenseVector}
+import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric, 
DistanceMetric,
+EuclideanDistanceMetric}
+import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, 
Predictor}
+import org.apache.flink.util.Collector
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
+
+import scala.collection.immutable.Vector
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+/** Implements a k-nearest neighbor join.
+  *
+  * Calculates the `k`-nearest neighbor points in the training set for each 
point in the test set.
+  *
+  * @example
+  * {{{
+  *         val trainingDS: DataSet[Vector] = ...
+  *         val testingDS: DataSet[Vector] = ...
+  *
+  *         val knn = KNN()
+  *           .setK(10)
+  *           .setBlocks(5)
+  *           .setDistanceMetric(EuclideanDistanceMetric())
+  *
+  *         knn.fit(trainingDS)
+  *
+  *         val predictionDS: DataSet[(Vector, Array[Vector])] = 
knn.predict(testingDS)
+  * }}}
+  *
+  * =Parameters=
+  *
+  * - [[org.apache.flink.ml.nn.KNN.K]]
+  * Sets the K which is the number of selected points as neighbors. (Default 
value: '''5''')
+  *
+  * - [[org.apache.flink.ml.nn.KNN.DistanceMetric]]
+  * Sets the distance metric we use to calculate the distance between two 
points. If no metric is
+  * specified, then 
[[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] is used.
+  * (Default value: '''EuclideanDistanceMetric()''')
+  *
+  * - [[org.apache.flink.ml.nn.KNN.Blocks]]
+  * Sets the number of blocks into which the input data will be split. This 
number should be set
+  * at least to the degree of parallelism. If no value is specified, then the 
parallelism of the
+  * input [[DataSet]] is used as the number of blocks. (Default value: 
'''None''')
+  *
+  * - [[org.apache.flink.ml.nn.KNN.UseQuadTreeParam]]
+  * A boolean variable that determines whether or not to use a Quadtree to partition the
+  * training set to potentially simplify the KNN search.  If no value is specified, the code
+  * will automatically decide whether or not to use a Quadtree.  Use of a Quadtree scales well
+  * with the number of training and testing points, though poorly with the dimension.
+  * (Default value: '''None''')
+  *
+  * - [[org.apache.flink.ml.nn.KNN.SizeHint]]
+  * Specifies whether the training set or test set is small to optimize the cross
+  * product operation needed for the KNN search.  Set this to `CrossHint.FIRST_IS_SMALL`
+  * if the training set is small, or to `CrossHint.SECOND_IS_SMALL` if the test set is small.
+  * (Default value: '''None''')
+  *
+  */
+
+class KNN extends Predictor[KNN] {
+
+  import KNN._
+
+  var trainingSet: Option[DataSet[Block[FlinkVector]]] = None
+
+  /** Sets K
+    * @param k the number of selected points as neighbors
+    */
+  def setK(k: Int): KNN = {
+    require(k > 0, "K must be positive.")
+    parameters.add(K, k)
+    this
+  }
+
+  /** Sets the distance metric
+    * @param metric the distance metric to calculate distance between two 
points
+    */
+  def setDistanceMetric(metric: DistanceMetric): KNN = {
+    parameters.add(DistanceMetric, metric)
+    this
+  }
+
+  /** Sets the number of data blocks/partitions
+    * @param n the number of data blocks
+    */
+  def setBlocks(n: Int): KNN = {
+    require(n > 0, "Number of blocks must be positive.")
+    parameters.add(Blocks, n)
+    this
+  }
+
+  /**
+    * Sets the Boolean variable that decides whether to use the QuadTree or not
+    */
+  def setUseQuadTree(useQuadTree: Boolean): KNN = {
+    if (useQuadTree) {
+      require(parameters(DistanceMetric).isInstanceOf[SquaredEuclideanDistanceMetric] ||
+        parameters(DistanceMetric).isInstanceOf[EuclideanDistanceMetric])
+    }
+    parameters.add(UseQuadTreeParam, useQuadTree)
+    this
+  }
+
+  /**
+    * Parameter a user can specify if one of the training or test sets is small
+    * @param sizeHint hint telling the cross operation which of the two sets is small
+    * @return this KNN instance
+    */
+  def setSizeHint(sizeHint: CrossHint): KNN = {
+    parameters.add(SizeHint, sizeHint)
+    this
+  }
+
+}
+
+object KNN {
+
+  case object K extends Parameter[Int] {
+    val defaultValue: Option[Int] = Some(5)
+  }
+
+  case object DistanceMetric extends Parameter[DistanceMetric] {
+    val defaultValue: Option[DistanceMetric] = Some(EuclideanDistanceMetric())
+  }
+
+  case object Blocks extends Parameter[Int] {
+    val defaultValue: Option[Int] = None
+  }
+
+  case object UseQuadTreeParam extends Parameter[Boolean] {
+    val defaultValue: Option[Boolean] = None
+  }
+
+  case object SizeHint extends Parameter[CrossHint] {
+    val defaultValue: Option[CrossHint] = None
+  }
+
+  def apply(): KNN = {
+    new KNN()
+  }
+
+  /** [[FitOperation]] which trains a KNN based on the given training data set.
+    * @tparam T Subtype of [[org.apache.flink.ml.math.Vector]]
+    */
+  implicit def fitKNN[T <: FlinkVector : TypeInformation] = new 
FitOperation[KNN, T] {
+    override def fit(
+                      instance: KNN,
+                      fitParameters: ParameterMap,
+                      input: DataSet[T]): Unit = {
+      val resultParameters = instance.parameters ++ fitParameters
+
+      require(resultParameters.get(K).isDefined, "K is needed for calculation")
+
+      val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
+      val partitioner = FlinkMLTools.ModuloKeyPartitioner
+      val inputAsVector = input.asInstanceOf[DataSet[FlinkVector]]
+
+      instance.trainingSet = Some(FlinkMLTools.block(inputAsVector, blocks, 
Some(partitioner)))
+    }
+  }
+
+  /** [[PredictDataSetOperation]] which calculates k-nearest neighbors of the given testing data
+    * set.
+    * @tparam T Subtype of [[org.apache.flink.ml.math.Vector]]
+    * @return The given testing data set with k-nearest neighbors
+    */
+  implicit def predictValues[T <: FlinkVector : ClassTag : TypeInformation] = {
+    new PredictDataSetOperation[KNN, T, (FlinkVector, Array[FlinkVector])] {
+      override def predictDataSet(
+                                   instance: KNN,
+                                   predictParameters: ParameterMap,
+                                   input: DataSet[T]): DataSet[(FlinkVector,
+        Array[FlinkVector])] = {
+        val resultParameters = instance.parameters ++ predictParameters
+
+        instance.trainingSet match {
+          case Some(trainingSet) =>
+            val k = resultParameters.get(K).get
+            val blocks = 
resultParameters.get(Blocks).getOrElse(input.getParallelism)
+            val metric = resultParameters.get(DistanceMetric).get
+            val partitioner = FlinkMLTools.ModuloKeyPartitioner
+
+            // attach unique id for each data
+            val inputWithId: DataSet[(Long, T)] = input.zipWithUniqueId
+
+            // split data into multiple blocks
+            val inputSplit = FlinkMLTools.block(inputWithId, blocks, 
Some(partitioner))
+
+            val sizeHint = resultParameters.get(SizeHint)
+            val crossTuned = sizeHint match {
+              case Some(hint) if hint == CrossHint.FIRST_IS_SMALL =>
+                trainingSet.crossWithHuge(inputSplit)
+              case Some(hint) if hint == CrossHint.SECOND_IS_SMALL =>
+                trainingSet.crossWithTiny(inputSplit)
+              case _ => trainingSet.cross(inputSplit)
+            }
+
+            // join input and training set
+            val crossed = crossTuned.mapPartition {
+              (iter, out: Collector[(FlinkVector, FlinkVector, Long, Double)]) 
=> {
+                for ((training, testing) <- iter) {
+                  // use a quadtree if (4^dim)Ntest*log(Ntrain)
+                  // < Ntest*Ntrain, and distance is Euclidean
+                  val useQuadTree = 
resultParameters.get(UseQuadTreeParam).getOrElse(
+                    math.log(4.0) * training.values.head.size + 
math.log(math.log(training.values.length))
+                      < math.log(training.values.length) &&
+                      (metric.isInstanceOf[EuclideanDistanceMetric] ||
+                        metric.isInstanceOf[SquaredEuclideanDistanceMetric]))
+
+                  if (useQuadTree) {
+                    knnQueryWithQuadTree(training.values, testing.values, k, 
metric, out)
+                  } else {
+                    knnQueryBasic(training.values, testing.values, k, metric, 
out)
+                  }
+                }
+              }
+            }
+
+            // group by input vector id and pick k nearest neighbor for each group
+            val result = crossed.groupBy(2).sortGroup(3, Order.ASCENDING).reduceGroup {
+              (iter, out: Collector[(FlinkVector, Array[FlinkVector])]) => {
+                if (iter.hasNext) {
+                  val head = iter.next()
+                  val key = head._2
+                  val neighbors: ArrayBuffer[FlinkVector] = 
ArrayBuffer(head._1)
+
+                  for ((vector, _, _, _) <- iter.take(k - 1)) {
+                    // we already took a first element
+                    neighbors += vector
+                  }
+
+                  out.collect(key, neighbors.toArray)
+                }
+              }
+            }
+
+            result
+          case None => throw new RuntimeException("The KNN model has not been trained. " +
+            "Call fit first before calling the predict operation.")
+
+        }
+      }
+    }
+  }
+
+  def knnQueryWithQuadTree[T <: FlinkVector](
+                                              training: Vector[T],
+                                              testing: Vector[(Long, T)],
+                                              k: Int, metric: DistanceMetric,
+                                              out: Collector[(FlinkVector,
+                                                FlinkVector, Long, Double)]) {
+    // find a bounding box
+    val MinArr = Array.tabulate(training.head.size)(x => x)
+    val MaxArr = Array.tabulate(training.head.size)(x => x)
+
+    val minVecTrain = MinArr.map(i => training.map(x => x(i)).min - 0.01)
+    val minVecTest = MinArr.map(i => testing.map(x => x._2(i)).min - 0.01)
+    val maxVecTrain = MaxArr.map(i => training.map(x => x(i)).max + 0.01)
+    val maxVecTest = MaxArr.map(i => testing.map(x => x._2(i)).max + 0.01)
+
+    val MinVec = DenseVector(MinArr.map(i => math.min(minVecTrain(i), 
minVecTest(i))))
+    val MaxVec = DenseVector(MinArr.map(i => math.max(maxVecTrain(i), 
maxVecTest(i))))
+
+    //default value of max elements/box is set to max(20,k)
+    val maxPerBox = math.max(k, 20)
+    val trainingQuadTree = new QuadTree(MinVec, MaxVec, metric, maxPerBox)
+
+    val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, 
Double)]()(
+      Ordering.by(_._4))
+
+    for (v <- training) {
+      trainingQuadTree.insert(v)
+    }
+
+    for ((id, vector) <- testing) {
+      //  Find siblings' objects and do local kNN there
+      val siblingObjects =
+        trainingQuadTree.searchNeighborsSiblingQueue(vector)
+
+      // do KNN query on siblingObjects and get max distance of kNN
+      // then rad is good choice for a neighborhood to do a refined
+      // local kNN search
+      val knnSiblings = siblingObjects.map(v => metric.distance(vector, v)
+      ).sortWith(_ < _).take(k)
+
+      val rad = knnSiblings.last
+      val trainingFiltered = trainingQuadTree.searchNeighbors(vector, rad)
+
+      for (b <- trainingFiltered) {
+        // (training vector, input vector, input key, distance)
+        queue.enqueue((b, vector, id, metric.distance(b, vector)))
+        if (queue.size > k) {
+          queue.dequeue()
+        }
+      }
+      for (v <- queue) {
+        out.collect(v)
+      }
+    }
+  }
+
+  def knnQueryBasic[T <: FlinkVector](
+                                       training: Vector[T],
+                                       testing: Vector[(Long, T)],
+                                       k: Int, metric: DistanceMetric,
+                                       out: Collector[(FlinkVector, 
FlinkVector, Long, Double)]) {
+
+    val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, 
Double)]()(
+      Ordering.by(_._4))
+    
+    for ((id, vector) <- testing) {
+      for (b <- training) {
+        // (training vector, input vector, input key, distance)
+        queue.enqueue((b, vector, id, metric.distance(b, vector)))
+        if (queue.size > k) {
+          queue.dequeue()
+        }
+      }
+      for (v <- queue) {
+        out.collect(v)
+      }
+    }
+  }
+
+}
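
Note on two ideas used in KNN.scala above: the automatic quadtree decision and the bounded priority queue. The sketch below restates both in plain Scala. The object and method names (`KnnSketch`, `quadTreeLooksCheaper`, `kSmallest`) are illustrative assumptions, not identifiers from the commit, and the heap is created fresh for each call here, which is a choice of this sketch.

```scala
import scala.collection.mutable

object KnnSketch {
  // The code comment states the rule as 4^dim * log(nTrain) < nTrain.
  // Taking logarithms of both sides gives the form used in the diff and
  // avoids computing 4^dim directly for large dimensions.
  def quadTreeLooksCheaper(dim: Int, nTrain: Int): Boolean =
    math.log(4.0) * dim + math.log(math.log(nTrain)) < math.log(nTrain)

  // Keeps the k smallest distances seen so far. The default ordering makes
  // the queue a max-heap, so dequeue() evicts the largest retained distance.
  def kSmallest(distances: Seq[Double], k: Int): List[Double] = {
    val queue = mutable.PriorityQueue.empty[Double]
    for (d <- distances) {
      queue.enqueue(d)
      if (queue.size > k) {
        queue.dequeue()
      }
    }
    // The heap iterates in no particular order, so sort before returning.
    queue.toList.sorted
  }
}
```

For example, with 1000 training points the rule favors the quadtree in 2 dimensions but not in 10, which matches the remark in the docs that the quadtree scales poorly with the dimension.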

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala 
b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
new file mode 100644
index 0000000..d08dcdd
--- /dev/null
+++ 
b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.ml.math.{Breeze, Vector}
+import Breeze._
+
+import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
+EuclideanDistanceMetric, DistanceMetric}
+
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.PriorityQueue
+
+/**
+ * n-dimensional QuadTree data structure; partitions
+ * spatial data for faster queries (e.g. KNN query)
+ * The skeleton of the data structure was initially
+ * based off of the 2D Quadtree found here:
+ * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+ *
+ * Many additional methods were added to the class both for
+ * efficient KNN queries and generalizing to n-dim.
+ *
+ * @param minVec vector of the corner of the bounding box with smallest coordinates
+ * @param maxVec vector of the corner of the bounding box with largest coordinates
+ * @param distMetric metric, must be Euclidean or squareEuclidean
+ * @param maxPerBox threshold for number of points in each box before splitting a box
+ */
+class QuadTree(
+  minVec: Vector,
+  maxVec: Vector,
+  distMetric: DistanceMetric,
+  maxPerBox: Int) {
+
+  class Node(
+    center: Vector,
+    width: Vector,
+    var children: Seq[Node]) {
+
+    val nodeElements = new ListBuffer[Vector]
+
+    /** for testing purposes only; used in QuadTreeSuite.scala
+      *
+      * @return center and width of the box
+      */
+    def getCenterWidth(): (Vector, Vector) = {
+      (center, width)
+    }
+
+    /** Tests whether the queryPoint is in the node, or a child of that node
+      *
+      * @param queryPoint
+      * @return
+      */
+    def contains(queryPoint: Vector): Boolean = {
+      overlap(queryPoint, 0.0)
+    }
+
+    /** Tests if queryPoint is within a radius of the node
+      *
+      * @param queryPoint
+      * @param radius
+      * @return
+      */
+    def overlap(
+      queryPoint: Vector,
+      radius: Double): Boolean = {
+      (0 until queryPoint.size).forall{ i =>
+          (queryPoint(i) - radius < center(i) + width(i) / 2) &&
+            (queryPoint(i) + radius > center(i) - width(i) / 2)
+      }
+    }
+
+    /** Tests if queryPoint is near a node
+      *
+      * @param queryPoint
+      * @param radius
+      * @return
+      */
+    def isNear(
+      queryPoint: Vector,
+      radius: Double): Boolean = {
+      minDist(queryPoint) < radius
+    }
+
+    /**
+     * minDist is defined so that every point in the box
+     * has distance to queryPoint of at least minDist
+     * (minDist adopted from "Nearest Neighbors Queries" by N. Roussopoulos et al.)
+     *
+     * @param queryPoint
+     * @return
+     */
+    def minDist(queryPoint: Vector): Double = {
+      val minDist = (0 until queryPoint.size).map { i =>
+        if (queryPoint(i) < center(i) - width(i) / 2) {
+          math.pow(queryPoint(i) - center(i) + width(i) / 2, 2)
+        } else if (queryPoint(i) > center(i) + width(i) / 2) {
+          math.pow(queryPoint(i) - center(i) - width(i) / 2, 2)
+        } else {
+          0
+        }
+      }.sum
+
+      distMetric match {
+        case _: SquaredEuclideanDistanceMetric => minDist
+        case _: EuclideanDistanceMetric => math.sqrt(minDist)
+        case _ => throw new IllegalArgumentException(s" Error: metric must be" 
+
+          s" Euclidean or SquaredEuclidean!")
+      }
+    }
+
+    /**
+     * Finds which child queryPoint lies in.  node.children is a Seq[Node], and
+     * whichChild finds the appropriate index of that Seq.
+     * @param queryPoint
+     * @return
+     */
+    def whichChild(queryPoint: Vector): Int = {
+      (0 until queryPoint.size).map { i =>
+        if (queryPoint(i) > center(i)) {
+          scala.math.pow(2, queryPoint.size - 1 - i).toInt
+        } else {
+          0
+        }
+      }.sum
+    }
+
+    /** Makes children nodes by partitioning the box into equal sub-boxes
+      * and adding a node for each sub-box
+      */
+    def makeChildren() {
+      val centerClone = center.copy
+      val cPart = partitionBox(centerClone, width)
+      val mappedWidth = 0.5 * width.asBreeze
+      children = cPart.map(p => new Node(p, mappedWidth.fromBreeze, null))
+    }
+
+    /**
+     * Recursive function that partitions a n-dim box by taking the (n-1) 
dimensional
+     * plane through the center of the box keeping the n-th coordinate fixed,
+     * then shifting it in the n-th direction up and down
+     * and recursively applying partitionBox to the two shifted (n-1) 
dimensional planes.
+     *
+     * @param center the center of the box
+     * @param width a vector of lengths of each dimension of the box
+     * @return
+     */
+    def partitionBox(
+      center: Vector,
+      width: Vector): Seq[Vector] = {
+      def partitionHelper(
+        box: Seq[Vector],
+        dim: Int): Seq[Vector] = {
+        if (dim >= width.size) {
+          box
+        } else {
+          val newBox = box.flatMap {
+            vector =>
+              val (up, down) = (vector.copy, vector)
+              up.update(dim, up(dim) - width(dim) / 4)
+              down.update(dim, down(dim) + width(dim) / 4)
+
+              Seq(up, down)
+          }
+          partitionHelper(newBox, dim + 1)
+        }
+      }
+      partitionHelper(Seq(center), 0)
+    }
+  }
+
+
+  val root = new Node(((minVec.asBreeze + maxVec.asBreeze) * 0.5).fromBreeze,
+    (maxVec.asBreeze - minVec.asBreeze).fromBreeze, null)
+
+  /**
+   * simple printing of tree for testing/debugging
+   */
+  def printTree(): Unit = {
+    printTreeRecur(root)
+  }
+
+  def printTreeRecur(node: Node) {
+    if (node.children != null) {
+      for (c <- node.children) {
+        printTreeRecur(c)
+      }
+    } else {
+      println("printing tree: n.nodeElements " + node.nodeElements)
+    }
+  }
+
+  /**
+   * Recursively adds an object to the tree
+   * @param queryPoint
+   */
+  def insert(queryPoint: Vector) {
+    insertRecur(queryPoint, root)
+  }
+
+  private def insertRecur(
+    queryPoint: Vector,
+    node: Node) {
+    if (node.children == null) {
+      if (node.nodeElements.length < maxPerBox) {
+        node.nodeElements += queryPoint
+      } else {
+        node.makeChildren()
+        for (o <- node.nodeElements) {
+          insertRecur(o, node.children(node.whichChild(o)))
+        }
+        node.nodeElements.clear()
+        insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
+      }
+    } else {
+      insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
+    }
+  }
+
+  /**
+   * Used to zoom in on a region near a test point for a fast KNN query.
+   * This capability is used in the KNN query to find k "near" neighbors n_1,...,n_k, from
+   * which one computes the max distance D_s to queryPoint.  D_s is then used during the
+   * kNN query to find all points within a radius D_s of queryPoint using searchNeighbors.
+   * To find the "near" neighbors, a min-heap is defined on the leaf nodes
+   * of the minimal bounding box of the queryPoint. The priority of a leaf node
+   * is an appropriate notion of the distance between the test point and the node,
+   * which is defined by minDist(queryPoint).
+   *
+   * @param queryPoint a test point for which the method finds the minimal bounding
+   *                   box that queryPoint lies in and returns elements in that box's
+   *                   siblings' leaf nodes
+   * @return
+   */
+  def searchNeighborsSiblingQueue(queryPoint: Vector): ListBuffer[Vector] = {
+    val ret = new ListBuffer[Vector]
+    // edge case when the main box has not been partitioned at all
+    if (root.children == null) {
+      root.nodeElements.clone()
+    } else {
+      val nodeQueue = new PriorityQueue[(Double, Node)]()(Ordering.by(x => 
x._1))
+      searchRecurSiblingQueue(queryPoint, root, nodeQueue)
+
+      var count = 0
+      while (count < maxPerBox) {
+        val dq = nodeQueue.dequeue()
+        if (dq._2.nodeElements.nonEmpty) {
+          ret ++= dq._2.nodeElements
+          count += dq._2.nodeElements.length
+        }
+      }
+      ret
+    }
+  }
+
+  /**
+   *
+   * @param queryPoint point under consideration
+   * @param node node that queryPoint lies in
+   * @param nodeQueue defined in searchNeighborsSiblingQueue, this stores nodes based on their
+   *                  distance to queryPoint as defined by minDist
+   */
+  private def searchRecurSiblingQueue(
+    queryPoint: Vector,
+    node: Node,
+    nodeQueue: PriorityQueue[(Double, Node)]) {
+    if (node.children != null) {
+      for (child <- node.children; if child.contains(queryPoint)) {
+        if (child.children == null) {
+          for (c <- node.children) {
+            minNodes(queryPoint, c, nodeQueue)
+          }
+        } else {
+          searchRecurSiblingQueue(queryPoint, child, nodeQueue)
+        }
+      }
+    }
+  }
+
+  /**
+   * Goes down to the leaf nodes below node and adds them to nodeQueue
+   *
+   * @param queryPoint point under consideration
+   * @param node node whose leaf nodes are added to nodeQueue
+   * @param nodeQueue PriorityQueue of leaf nodes, keyed by negated minDist
+   */
+  private def minNodes(
+    queryPoint: Vector,
+    node: Node,
+    nodeQueue: PriorityQueue[(Double, Node)]) {
+    if (node.children == null) {
+      nodeQueue += ((-node.minDist(queryPoint), node))
+    } else {
+      for (c <- node.children) {
+        minNodes(queryPoint, c, nodeQueue)
+      }
+    }
+  }
+
+  /** Finds all objects within a neighborhood of queryPoint of a specified radius
+    * scope is modified from original 2D version in:
+    * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+    *
+    * original version only looks in minimal box; for the KNN Query, we look at
+    * all nearby boxes. The radius is determined from searchNeighborsSiblingQueue
+    * by defining a min-heap on the leaf nodes
+    *
+    * @param queryPoint point at the center of the neighborhood
+    * @param radius radius of the neighborhood
+    * @return all points in leaf boxes closer to queryPoint than radius
+    */
+  def searchNeighbors(
+    queryPoint: Vector,
+    radius: Double): ListBuffer[Vector] = {
+    val ret = new ListBuffer[Vector]
+    searchRecur(queryPoint, radius, root, ret)
+    ret
+  }
+
+  private def searchRecur(
+    queryPoint: Vector,
+    radius: Double,
+    node: Node,
+    ret: ListBuffer[Vector]) {
+    if (node.children == null) {
+      ret ++= node.nodeElements
+    } else {
+      for (child <- node.children; if child.isNear(queryPoint, radius)) {
+        searchRecur(queryPoint, radius, child, ret)
+      }
+    }
+  }
+}
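
The sibling search above relies on `scala.collection.mutable.PriorityQueue` being a max-heap: `minNodes` enqueues `-minDist` so that the closest leaf is dequeued first. A minimal, self-contained sketch of that trick follows. `Leaf` and `nearestFirst` are hypothetical names used only for illustration and are not part of this commit.

```scala
import scala.collection.mutable.PriorityQueue

object MinHeapSketch {
  // Hypothetical stand-in for a leaf node: a label plus its minDist to the query point.
  final case class Leaf(name: String, minDist: Double)

  /** Returns leaf names ordered from closest to farthest. */
  def nearestFirst(leaves: Seq[Leaf]): List[String] = {
    // PriorityQueue dequeues the element with the largest key first,
    // so each distance is negated to get the smallest distance out first.
    val queue = new PriorityQueue[(Double, Leaf)]()(
      Ordering.fromLessThan[(Double, Leaf)]((a, b) => a._1 < b._1))
    for (leaf <- leaves) {
      queue += ((-leaf.minDist, leaf))
    }
    var names = List.empty[String]
    while (queue.nonEmpty) {
      names = queue.dequeue()._2.name :: names
    }
    names.reverse
  }
}
```

The sketch drains the queue only while it is non-empty, which is an assumption of this example rather than something the committed loop does.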

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
 
b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
new file mode 100644
index 0000000..63e412a
--- /dev/null
+++ 
b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.classification.Classification
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.metrics.distances.{ManhattanDistanceMetric,
+SquaredEuclideanDistanceMetric}
+import org.apache.flink.test.util.FlinkTestBase
+import org.scalatest.{FlatSpec, Matchers}
+
+class KNNITSuite extends FlatSpec with Matchers with FlinkTestBase {
+  behavior of "The KNN Join Implementation"
+
+  it should "throw an exception when the given K is not valid" in {
+    intercept[IllegalArgumentException] {
+      KNN().setK(0)
+    }
+  }
+
+  it should "throw an exception when the given count of blocks is not valid" 
in {
+    intercept[IllegalArgumentException] {
+      KNN().setBlocks(0)
+    }
+  }
+
+  val env = ExecutionEnvironment.getExecutionEnvironment
+
+  // prepare data
+  val trainingSet = 
env.fromCollection(Classification.trainingData).map(_.vector)
+  val testingSet = env.fromElements(DenseVector(0.0, 0.0))
+
+  // calculate answer
+  val answer = Classification.trainingData.map {
+    v => (v.vector, SquaredEuclideanDistanceMetric().distance(DenseVector(0.0, 
0.0), v.vector))
+  }.sortBy(_._2).take(3).map(_._1).toArray
+
+  it should "calculate kNN join correctly without using a Quadtree" in {
+
+    val knn = KNN()
+      .setK(3)
+      .setBlocks(10)
+      .setDistanceMetric(SquaredEuclideanDistanceMetric())
+      .setUseQuadTree(false)
+      .setSizeHint(CrossHint.SECOND_IS_SMALL)
+
+    // run knn join
+    knn.fit(trainingSet)
+    val result = knn.predict(testingSet).collect()
+
+    result.size should be(1)
+    result.head._1 should be(DenseVector(0.0, 0.0))
+    result.head._2 should be(answer)
+  }
+
+  it should "calculate kNN join correctly with a Quadtree" in {
+
+    val knn = KNN()
+      .setK(3)
+      .setBlocks(2) // blocks set to 2 to make sure initial quadtree box is 
partitioned
+      .setDistanceMetric(SquaredEuclideanDistanceMetric())
+      .setUseQuadTree(true)
+      .setSizeHint(CrossHint.SECOND_IS_SMALL)
+
+    // run knn join
+    knn.fit(trainingSet)
+    val result = knn.predict(testingSet).collect()
+
+    result.size should be(1)
+    result.head._1 should be(DenseVector(0.0, 0.0))
+    result.head._2 should be(answer)
+  }
+
+  it should "throw an exception when using a Quadtree with an incompatible 
metric" in {
+    intercept[IllegalArgumentException] {
+      val knn = KNN()
+        .setK(3)
+        .setBlocks(10)
+        .setDistanceMetric(ManhattanDistanceMetric())
+        .setUseQuadTree(true)
+
+      // run knn join
+      knn.fit(trainingSet)
+      val result = knn.predict(testingSet).collect()
+
+    }
+  }
+
+}
+
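
Both join tests above set `setUseQuadTree` explicitly. When the parameter is left unset, `predictValues` in KNN.scala falls back to a size heuristic combined with a check that the metric is Euclidean or squared Euclidean. The sketch below mirrors that expression in plain Scala; `log4`, `preferQuadTree` and `euclideanLike` are hypothetical names introduced here for illustration.

```scala
object QuadTreeHeuristicSketch {
  /** Logarithm to base 4. */
  def log4(x: Double): Double = math.log(x) / math.log(4.0)

  /** Mirrors the default used when the quadtree flag is unset:
    * dim + ln(log4(nTrain)) < log4(nTrain), and the metric must be
    * Euclidean or squared Euclidean (passed in here as a Boolean).
    */
  def preferQuadTree(dim: Int, nTrain: Int, euclideanLike: Boolean): Boolean = {
    require(dim > 0 && nTrain > 0, "dim and nTrain must be positive")
    euclideanLike && dim + math.log(log4(nTrain)) < log4(nTrain)
  }
}
```

Under this expression a low-dimensional block with many training points favors the quadtree, while a high-dimensional block of the same size does not.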

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
 
b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
new file mode 100644
index 0000000..8be5c6e
--- /dev/null
+++ 
b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric
+import org.apache.flink.test.util.FlinkTestBase
+import org.apache.flink.ml.math.{Vector, DenseVector}
+
+import org.scalatest.{Matchers, FlatSpec}
+
+/** Test of Quadtree class
+  * Constructor for the Quadtree class:
+  * class QuadTree(minVec: Vector, maxVec: Vector, distMetric: DistanceMetric, maxPerBox: Int)
+  *
+  */
+
+class QuadTreeSuite extends FlatSpec with Matchers with FlinkTestBase {
+  behavior of "The QuadTree Class"
+
+  it should "partition into equal size sub-boxes and search for nearby objects 
properly" in {
+
+    val minVec = DenseVector(-1.0, -0.5)
+    val maxVec = DenseVector(1.0, 0.5)
+
+    val myTree = new QuadTree(minVec, maxVec, EuclideanDistanceMetric(), 3)
+
+    myTree.insert(DenseVector(-0.25, 0.3).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.20, 0.31).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.21, 0.29).asInstanceOf[Vector])
+
+    var a = myTree.root.getCenterWidth()
+
+    /** Tree will partition once the 4th point is added
+      */
+
+    myTree.insert(DenseVector(0.2, 0.27).asInstanceOf[Vector])
+    myTree.insert(DenseVector(0.2, 0.26).asInstanceOf[Vector])
+
+    myTree.insert(DenseVector(-0.21, 0.289).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.1, 0.289).asInstanceOf[Vector])
+
+    myTree.insert(DenseVector(0.7, 0.45).asInstanceOf[Vector])
+
+    /**
+     * Exact values of (centers,dimensions) of root + children nodes, to test
+     * partitionBox and makeChildren methods; exact values are given to avoid
+     * essentially copying and pasting the code to automatically generate them
+     * from minVec/maxVec
+     */
+
+    val knownCentersLengths = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 
1.0)),
+      (DenseVector(-0.5, -0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(-0.5, 0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(0.5, -0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(0.5, 0.25), DenseVector(1.0, 0.5))
+    )
+
+    /**
+     * (centers,dimensions) computed from QuadTree.makeChildren
+     */
+
+    var computedCentersLength = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 
1.0)))
+    for (child <- myTree.root.children) {
+      computedCentersLength += 
child.getCenterWidth().asInstanceOf[(DenseVector, DenseVector)]
+    }
+
+
+    /**
+     * Tests search for nearby neighbors: makes sure the right object is
+     * contained in the neighbor search; the search may contain more points
+     */
+    val neighborsComputed = myTree.searchNeighbors(DenseVector(0.7001, 0.45001), 0.001)
+    val isNeighborInSearch = neighborsComputed.contains(DenseVector(0.7, 0.45))
+
+    /**
+     * Tests the ability to get all objects in the minimal bounding box plus the
+     * objects in its siblings' boxes. Drawing a picture of the QuadTree shows that
+     * (-0.2, 0.31), (-0.21, 0.29), (-0.21, 0.289)
+     * are objects near (-0.2001, 0.31001)
+     */
+
+    val siblingsObjectsComputed = 
myTree.searchNeighborsSiblingQueue(DenseVector(-0.2001, 0.31001))
+    val isSiblingsInSearch = 
siblingsObjectsComputed.contains(DenseVector(-0.2, 0.31)) &&
+      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.29)) &&
+      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.289))
+
+    computedCentersLength should be(knownCentersLengths)
+    isNeighborInSearch should be(true)
+    isSiblingsInSearch should be(true)
+  }
+}
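
The neighbor search exercised by this test prunes boxes with `minDist`, a lower bound on the distance from the query point to any point in a box. A minimal sketch of the squared-Euclidean variant on plain arrays is given below, using the same center/width box description as `getCenterWidth`. The object and method names are hypothetical and the code is an illustration, not the class from this commit.

```scala
object MinDistSketch {
  /** Squared Euclidean distance from `point` to the axis-aligned box given by
    * `center` and `width`; 0.0 when the point lies inside the box.
    */
  def squaredMinDist(
      point: Array[Double],
      center: Array[Double],
      width: Array[Double]): Double = {
    require(point.length == center.length && center.length == width.length)
    point.indices.map { i =>
      val lower = center(i) - width(i) / 2
      val upper = center(i) + width(i) / 2
      // Per-coordinate gap between the point and the box; zero inside the slab.
      val gap =
        if (point(i) < lower) lower - point(i)
        else if (point(i) > upper) point(i) - upper
        else 0.0
      gap * gap
    }.sum
  }
}
```

For the child box with center (0.5, 0.25) and width (1.0, 0.5) from the test, a query point inside the box gives 0.0, so that box is never pruned for any positive radius.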

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
----------------------------------------------------------------------
diff --git 
a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala 
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
index 35073b6..6d563e9 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
@@ -20,14 +20,20 @@ package org.apache.flink.ml.nn
 
 import org.apache.flink.api.common.operators.Order
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.scala.DataSetUtils._
+//import org.apache.flink.api.scala.DataSetUtils._
+import org.apache.flink.api.scala.utils._
 import org.apache.flink.api.scala._
 import org.apache.flink.ml.common._
-import org.apache.flink.ml.math.Vector
-import org.apache.flink.ml.metrics.distances.{DistanceMetric, 
EuclideanDistanceMetric}
+import org.apache.flink.ml.math.{Vector => FlinkVector, DenseVector}
+import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
+DistanceMetric, EuclideanDistanceMetric}
 import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, 
Predictor}
 import org.apache.flink.util.Collector
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
 
+import org.apache.flink.ml.nn.util.QuadTree
+
+import scala.collection.immutable.Vector
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
@@ -38,17 +44,17 @@ import scala.reflect.ClassTag
   *
   * @example
   * {{{
-  *     val trainingDS: DataSet[Vector] = ...
-  *     val testingDS: DataSet[Vector] = ...
+  *       val trainingDS: DataSet[Vector] = ...
+  *       val testingDS: DataSet[Vector] = ...
   *
-  *     val knn = KNN()
-  *       .setK(10)
-  *       .setBlocks(5)
-  *       .setDistanceMetric(EuclideanDistanceMetric())
+  *       val knn = KNN()
+  *         .setK(10)
+  *         .setBlocks(5)
+  *         .setDistanceMetric(EuclideanDistanceMetric())
   *
-  *     knn.fit(trainingDS)
+  *       knn.fit(trainingDS)
   *
-  *     val predictionDS: DataSet[(Vector, Array[Vector])] = 
knn.predict(testingDS)
+  *       val predictionDS: DataSet[(Vector, Array[Vector])] = 
knn.predict(testingDS)
   * }}}
   *
   * =Parameters=
@@ -67,11 +73,12 @@ import scala.reflect.ClassTag
   * (Default value: '''EuclideanDistanceMetric()''')
   *
   */
+
 class KNN extends Predictor[KNN] {
 
   import KNN._
 
-  var trainingSet: Option[DataSet[Block[Vector]]] = None
+  var trainingSet: Option[DataSet[Block[FlinkVector]]] = None
 
   /** Sets K
     * @param k the number of selected points as neighbors
@@ -98,6 +105,25 @@ class KNN extends Predictor[KNN] {
     parameters.add(Blocks, n)
     this
   }
+
+  /**
+   * Sets the Boolean variable that decides whether to use the QuadTree or not
+   */
+  def setUseQuadTree(UseQuadTree: Boolean): KNN = {
+    parameters.add(UseQuadTreeParam, UseQuadTree)
+    this
+  }
+
+  /**
+   * Parameter a user can specify if one of the training or test sets is small
+   * @param sizeHint hint for the cross operation: FIRST_IS_SMALL or SECOND_IS_SMALL
+   * @return this KNN instance
+   */
+  def setSizeHint(sizeHint: CrossHint): KNN = {
+    parameters.add(SizeHint, sizeHint)
+    this
+  }
+
 }
 
 object KNN {
@@ -114,6 +140,14 @@ object KNN {
     val defaultValue: Option[Int] = None
   }
 
+  case object UseQuadTreeParam extends Parameter[Boolean] {
+    val defaultValue: Option[Boolean] = None
+  }
+
+  case object SizeHint extends Parameter[CrossHint] {
+    val defaultValue: Option[CrossHint] = None
+  }
+
   def apply(): KNN = {
     new KNN()
   }
@@ -121,18 +155,18 @@ object KNN {
   /** [[FitOperation]] which trains a KNN based on the given training data set.
     * @tparam T Subtype of [[org.apache.flink.ml.math.Vector]]
     */
-  implicit def fitKNN[T <: Vector : TypeInformation] = new FitOperation[KNN, 
T] {
+  implicit def fitKNN[T <: FlinkVector : TypeInformation] = new 
FitOperation[KNN, T] {
     override def fit(
-        instance: KNN,
-        fitParameters: ParameterMap,
-        input: DataSet[T]): Unit = {
+      instance: KNN,
+      fitParameters: ParameterMap,
+      input: DataSet[T]): Unit = {
       val resultParameters = instance.parameters ++ fitParameters
 
       require(resultParameters.get(K).isDefined, "K is needed for calculation")
 
       val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
       val partitioner = FlinkMLTools.ModuloKeyPartitioner
-      val inputAsVector = input.asInstanceOf[DataSet[Vector]]
+      val inputAsVector = input.asInstanceOf[DataSet[FlinkVector]]
 
       instance.trainingSet = Some(FlinkMLTools.block(inputAsVector, blocks, 
Some(partitioner)))
     }
@@ -143,12 +177,13 @@ object KNN {
     * @tparam T Subtype of [[Vector]]
     * @return The given testing data set with k-nearest neighbors
     */
-  implicit def predictValues[T <: Vector : ClassTag : TypeInformation] = {
-    new PredictDataSetOperation[KNN, T, (Vector, Array[Vector])] {
+  implicit def predictValues[T <: FlinkVector : ClassTag : TypeInformation] = {
+    new PredictDataSetOperation[KNN, T, (FlinkVector, Array[FlinkVector])] {
       override def predictDataSet(
-          instance: KNN,
-          predictParameters: ParameterMap,
-          input: DataSet[T]): DataSet[(Vector, Array[Vector])] = {
+        instance: KNN,
+        predictParameters: ParameterMap,
+        input: DataSet[T]): DataSet[(FlinkVector,
+        Array[FlinkVector])] = {
         val resultParameters = instance.parameters ++ predictParameters
 
         instance.trainingSet match {
@@ -164,24 +199,40 @@ object KNN {
             // split data into multiple blocks
             val inputSplit = FlinkMLTools.block(inputWithId, blocks, 
Some(partitioner))
 
+            val sizeHint = resultParameters.get(SizeHint)
+            val crossTuned = sizeHint match {
+              case Some(hint) if hint == CrossHint.FIRST_IS_SMALL =>
+                trainingSet.crossWithHuge(inputSplit)
+              case Some(hint) if hint == CrossHint.SECOND_IS_SMALL =>
+                trainingSet.crossWithTiny(inputSplit)
+              case _ => trainingSet.cross(inputSplit)
+            }
+
             // join input and training set
-            val crossed = trainingSet.cross(inputSplit).mapPartition {
-              (iter, out: Collector[(Vector, Vector, Long, Double)]) => {
+            val crossed = crossTuned.mapPartition {
+              (iter, out: Collector[(FlinkVector, FlinkVector, Long, Double)]) 
=> {
                 for ((training, testing) <- iter) {
-                  val queue = mutable.PriorityQueue[(Vector, Vector, Long, 
Double)]()(
+                  val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, 
Long, Double)]()(
                     Ordering.by(_._4))
 
-                  for (a <- testing.values; b <- training.values) {
-                    // (training vector, input vector, input key, distance)
-                    queue.enqueue((b, a._2, a._1, metric.distance(b, a._2)))
+                  // use a quadtree if (4^dim)Ntest*log(Ntrain)
+                  // < Ntest*Ntrain, and distance is Euclidean
+                  val useQuadTree = resultParameters.get(UseQuadTreeParam).getOrElse(
+                    training.values.head.size + math.log(math.log(training.values.length) /
+                      math.log(4.0)) < math.log(training.values.length) / math.log(4.0) &&
+                      (metric.isInstanceOf[EuclideanDistanceMetric] ||
+                        metric.isInstanceOf[SquaredEuclideanDistanceMetric]))
 
-                    if (queue.size > k) {
-                      queue.dequeue()
+                  if (useQuadTree) {
+                    if (metric.isInstanceOf[EuclideanDistanceMetric] ||
+                      metric.isInstanceOf[SquaredEuclideanDistanceMetric]){
+                      knnQueryWithQuadTree(training.values, testing.values, k, 
metric, queue, out)
+                    } else {
+                      throw new IllegalArgumentException(s" Error: metric must 
be" +
+                        s" Euclidean or SquaredEuclidean!")
                     }
-                  }
-
-                  for (v <- queue) {
-                    out.collect(v)
+                  } else {
+                    knnQueryBasic(training.values, testing.values, k, metric, 
queue, out)
                   }
                 }
               }
@@ -189,13 +240,14 @@ object KNN {
 
             // group by input vector id and pick k nearest neighbor for each 
group
             val result = crossed.groupBy(2).sortGroup(3, 
Order.ASCENDING).reduceGroup {
-              (iter, out: Collector[(Vector, Array[Vector])]) => {
+              (iter, out: Collector[(FlinkVector, Array[FlinkVector])]) => {
                 if (iter.hasNext) {
                   val head = iter.next()
                   val key = head._2
-                  val neighbors: ArrayBuffer[Vector] = ArrayBuffer(head._1)
+                  val neighbors: ArrayBuffer[FlinkVector] = 
ArrayBuffer(head._1)
 
-                  for ((vector, _, _, _) <- iter.take(k - 1)) { // we already 
took a first element
+                  for ((vector, _, _, _) <- iter.take(k - 1)) {
+                    // we already took a first element
                     neighbors += vector
                   }
 
@@ -206,9 +258,88 @@ object KNN {
 
             result
           case None => throw new RuntimeException("The KNN model has not been 
trained." +
-              "Call first fit before calling the predict operation.")
+            "Call first fit before calling the predict operation.")
+
         }
       }
     }
   }
+
+  def knnQueryWithQuadTree[T <: FlinkVector](
+    training: Vector[T],
+    testing: Vector[(Long, T)],
+    k: Int, metric: DistanceMetric,
+    queue: mutable.PriorityQueue[(FlinkVector,
+      FlinkVector, Long, Double)],
+    out: Collector[(FlinkVector,
+      FlinkVector, Long, Double)]) {
+    /// find a bounding box
+    val MinArr = Array.tabulate(training.head.size)(x => x)
+    val MaxArr = Array.tabulate(training.head.size)(x => x)
+
+    val minVecTrain = MinArr.map(i => training.map(x => x(i)).min - 0.01)
+    val minVecTest = MinArr.map(i => testing.map(x => x._2(i)).min - 0.01)
+    val maxVecTrain = MaxArr.map(i => training.map(x => x(i)).max + 0.01)
+    val maxVecTest = MaxArr.map(i => testing.map(x => x._2(i)).max + 0.01)
+
+    val MinVec = DenseVector(MinArr.map(i => Array(minVecTrain(i), 
minVecTest(i)).min))
+    val MaxVec = DenseVector(MinArr.map(i => Array(maxVecTrain(i), 
maxVecTest(i)).max))
+
+    //default value of max elements/box is set to max(20,k)
+    val maxPerBox = Array(k, 20).max
+    val trainingQuadTree = new QuadTree(MinVec, MaxVec, metric, maxPerBox)
+
+    for (v <- training) {
+      trainingQuadTree.insert(v)
+    }
+
+    for ((id, vector) <- testing) {
+      //  Find siblings' objects and do local kNN there
+      val siblingObjects =
+        trainingQuadTree.searchNeighborsSiblingQueue(vector)
+
+      // do KNN query on siblingObjects and get max distance of kNN
+      // then rad is good choice for a neighborhood to do a refined
+      // local kNN search
+      val knnSiblings = siblingObjects.map(v => metric.distance(vector, v)
+      ).sortWith(_ < _).take(k)
+
+      val rad = knnSiblings.last
+      val trainingFiltered = trainingQuadTree.searchNeighbors(vector, rad)
+
+      for (b <- trainingFiltered) {
+        // (training vector, input vector, input key, distance)
+        queue.enqueue((b, vector, id, metric.distance(b, vector)))
+        if (queue.size > k) {
+          queue.dequeue()
+        }
+      }
+      for (v <- queue) {
+        out.collect(v)
+      }
+    }
+  }
+
+  def knnQueryBasic[T <: FlinkVector](
+    training: Vector[T],
+    testing: Vector[(Long, T)],
+    k: Int, metric: DistanceMetric,
+    queue: mutable.PriorityQueue[(FlinkVector,
+      FlinkVector, Long, Double)],
+    out: Collector[(FlinkVector, FlinkVector, Long, Double)]) {
+
+    for ((id, vector) <- testing) {
+      for (b <- training) {
+        // (training vector, input vector, input key, distance)
+        queue.enqueue((b, vector, id, metric.distance(b, vector)))
+        if (queue.size > k) {
+          queue.dequeue()
+        }
+      }
+      for (v <- queue) {
+        out.collect(v)
+      }
+    }
+  }
+
 }
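
`knnQueryBasic` above keeps at most k candidates in a max-heap keyed by distance, so the farthest candidate is the one dropped whenever the heap grows past k. The following is a minimal sketch of that bounded-heap technique on plain arrays, assuming squared Euclidean distance. `Point`, `squaredDistance` and `kNearest` are hypothetical names, and the sketch deliberately builds a fresh queue for each query point.

```scala
import scala.collection.mutable

object BoundedHeapKnnSketch {
  type Point = Array[Double]

  def squaredDistance(a: Point, b: Point): Double = {
    require(a.length == b.length)
    a.indices.map { i => val d = a(i) - b(i); d * d }.sum
  }

  /** Returns up to k training points closest to `query`, nearest first. */
  def kNearest(training: Seq[Point], query: Point, k: Int): Seq[Point] = {
    require(k > 0, "k must be positive")
    // Max-heap on distance: the head is the worst candidate kept so far.
    val queue = new mutable.PriorityQueue[(Double, Point)]()(
      Ordering.fromLessThan[(Double, Point)]((a, b) => a._1 < b._1))
    for (p <- training) {
      queue.enqueue((squaredDistance(p, query), p))
      if (queue.size > k) {
        queue.dequeue() // drop the farthest candidate
      }
    }
    queue.toSeq.sortWith(_._1 < _._1).map(_._2)
  }
}
```

Because the queue is local to `kNearest`, no candidates carry over from one query point to the next. Each query costs O(n log k) comparisons for n training points.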

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
----------------------------------------------------------------------
diff --git 
a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala 
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
new file mode 100644
index 0000000..0b37313
--- /dev/null
+++ 
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn.util
+
+import org.apache.flink.ml.math.{Breeze, Vector}
+import Breeze._
+
+import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
+EuclideanDistanceMetric, DistanceMetric}
+
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.PriorityQueue
+
+/**
+ * n-dimensional QuadTree data structure; partitions
+ * spatial data for faster queries (e.g. KNN query)
+ * The skeleton of the data structure was initially
+ * based off of the 2D Quadtree found here:
+ * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+ *
+ * Many additional methods were added to the class both for
+ * efficient KNN queries and generalizing to n-dim.
+ *
+ * @param minVec vector of the corner of the bounding box with smallest coordinates
+ * @param maxVec vector of the corner of the bounding box with largest coordinates
+ * @param distMetric metric, must be Euclidean or squareEuclidean
+ * @param maxPerBox threshold for number of points in each box before splitting a box
+ */
+class QuadTree(
+  minVec: Vector,
+  maxVec: Vector,
+  distMetric: DistanceMetric,
+  maxPerBox: Int) {
+
+  class Node(
+    center: Vector,
+    width: Vector,
+    var children: Seq[Node]) {
+
+    val nodeElements = new ListBuffer[Vector]
+
+    /** for testing purposes only; used in QuadTreeSuite.scala
+      *
+      * @return center and width of the box
+      */
+    def getCenterWidth(): (Vector, Vector) = {
+      (center, width)
+    }
+
+    def contains(queryPoint: Vector): Boolean = {
+      overlap(queryPoint, 0.0)
+    }
+
+    /** Tests if queryPoint is within a radius of the node
+      *
+      * @param queryPoint
+      * @param radius
+      * @return
+      */
+    def overlap(
+      queryPoint: Vector,
+      radius: Double): Boolean = {
+      val count = (0 until queryPoint.size).filter { i =>
+        (queryPoint(i) - radius < center(i) + width(i) / 2) &&
+          (queryPoint(i) + radius > center(i) - width(i) / 2)
+      }.size
+
+      count == queryPoint.size
+    }
+
+    /** Tests if queryPoint is near a node
+      *
+      * @param queryPoint
+      * @param radius
+      * @return
+      */
+    def isNear(
+      queryPoint: Vector,
+      radius: Double): Boolean = {
+      minDist(queryPoint) < radius
+    }
+
+    /**
+     * minDist is defined so that every point in the box
+     * has distance to queryPoint at least minDist
+     * (minDist adopted from "Nearest Neighbors Queries" by N. Roussopoulos et al.)
+     *
+     * @param queryPoint point to measure the distance from
+     * @return lower bound on the distance from queryPoint to any point in the box
+     */
+    def minDist(queryPoint: Vector): Double = {
+      val minDist = (0 until queryPoint.size).map { i =>
+        if (queryPoint(i) < center(i) - width(i) / 2) {
+          math.pow(queryPoint(i) - center(i) + width(i) / 2, 2)
+        } else if (queryPoint(i) > center(i) + width(i) / 2) {
+          math.pow(queryPoint(i) - center(i) - width(i) / 2, 2)
+        } else {
+          0
+        }
+      }.sum
+
+      distMetric match {
+        case _: SquaredEuclideanDistanceMetric => minDist
+        case _: EuclideanDistanceMetric => math.sqrt(minDist)
+        case _ => throw new IllegalArgumentException(s" Error: metric must be" +
+          s" Euclidean or SquaredEuclidean!")
+      }
+    }
+
+    /**
+     * Finds which child queryPoint lies in.  node.children is a Seq[Node], and
+     * whichChild finds the appropriate index of that Seq.
+     * @param queryPoint
+     * @return
+     */
+    def whichChild(queryPoint: Vector): Int = {
+      (0 until queryPoint.size).map { i =>
+        if (queryPoint(i) > center(i)) {
+          Math.pow(2, queryPoint.size - 1 - i).toInt
+        } else {
+          0
+        }
+      }.sum
+    }
+
+    def makeChildren() {
+      val centerClone = center.copy
+      val cPart = partitionBox(centerClone, width)
+      val mappedWidth = 0.5 * width.asBreeze
+      children = cPart.map(p => new Node(p, mappedWidth.fromBreeze, null))
+    }
+
+    /**
+     * Recursive function that partitions a n-dim box by taking the (n-1) 
dimensional
+     * plane through the center of the box keeping the n-th coordinate fixed,
+     * then shifting it in the n-th direction up and down
+     * and recursively applying partitionBox to the two shifted (n-1) 
dimensional planes.
+     *
+     * @param center the center of the box
+     * @param width a vector of lengths of each dimension of the box
+     * @return
+     */
+    def partitionBox(
+      center: Vector,
+      width: Vector): Seq[Vector] = {
+      def partitionHelper(
+        box: Seq[Vector],
+        dim: Int): Seq[Vector] = {
+        if (dim >= width.size) {
+          box
+        } else {
+          val newBox = box.flatMap {
+            vector =>
+              val (up, down) = (vector.copy, vector)
+              up.update(dim, up(dim) - width(dim) / 4)
+              down.update(dim, down(dim) + width(dim) / 4)
+
+              Seq(up, down)
+          }
+          partitionHelper(newBox, dim + 1)
+        }
+      }
+      partitionHelper(Seq(center), 0)
+    }
+  }
+
+
+  val root = new Node(((minVec.asBreeze + maxVec.asBreeze) * 0.5).fromBreeze,
+    (maxVec.asBreeze - minVec.asBreeze).fromBreeze, null)
+
+  /**
+   * simple printing of tree for testing/debugging
+   */
+  def printTree(): Unit = {
+    printTreeRecur(root)
+  }
+
+  def printTreeRecur(node: Node) {
+    if (node.children != null) {
+      for (c <- node.children) {
+        printTreeRecur(c)
+      }
+    } else {
+      println("printing tree: n.nodeElements " + node.nodeElements)
+    }
+  }
+
+  /**
+   * Recursively adds an object to the tree
+   * @param queryPoint
+   */
+  def insert(queryPoint: Vector) {
+    insertRecur(queryPoint, root)
+  }
+
+  private def insertRecur(
+    queryPoint: Vector,
+    node: Node) {
+    if (node.children == null) {
+      if (node.nodeElements.length < maxPerBox) {
+        node.nodeElements += queryPoint
+      } else {
+        node.makeChildren()
+        for (o <- node.nodeElements) {
+          insertRecur(o, node.children(node.whichChild(o)))
+        }
+        node.nodeElements.clear()
+        insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
+      }
+    } else {
+      insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
+    }
+  }
+
+  /**
+   * Used to zoom in on a region near a test point for a fast KNN query.
+   * The KNN query uses it to find k "near" neighbors n_1,...,n_k, from
+   * which it computes the max distance D_s to queryPoint.  D_s is then used
+   * to find all points within a radius D_s of queryPoint via searchNeighbors.
+   * To find the "near" neighbors, a min-heap is defined on the leaf nodes
+   * under the parent of the minimal bounding box of queryPoint. The priority
+   * of a leaf node is its distance to the test point, as defined by
+   * minDist(queryPoint).
+   *
+   * @param queryPoint a test point for which the method finds the minimal
+   *                   bounding box that queryPoint lies in and returns the
+   *                   elements in that box's siblings' leaf nodes
+   * @return elements of the closest leaf nodes, collected until at least maxPerBox are found
+   */
+  def searchNeighborsSiblingQueue(queryPoint: Vector): ListBuffer[Vector] = {
+    val ret = new ListBuffer[Vector]
+    // edge case when the main box has not been partitioned at all
+    if (root.children == null) {
+      root.nodeElements.clone()
+    } else {
+      val nodeQueue = new PriorityQueue[(Double, Node)]()(Ordering.by(x => x._1))
+      searchRecurSiblingQueue(queryPoint, root, nodeQueue)
+
+      var count = 0
+      while (count < maxPerBox) {
+        val dq = nodeQueue.dequeue()
+        if (dq._2.nodeElements.nonEmpty) {
+          ret ++= dq._2.nodeElements
+          count += dq._2.nodeElements.length
+        }
+      }
+      ret
+    }
+  }
+
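+  /**
+   * Descends to the leaf node containing queryPoint and enqueues the leaf nodes
+   * under that leaf's siblings (and the leaf itself) via minNodes
+   *
+   * @param queryPoint point under consideration
+   * @param node node that queryPoint lies in
+   * @param nodeQueue defined in searchNeighborsSiblingQueue, this stores nodes based on their
+   *                  distance to queryPoint as defined by minDist
+   */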
+  private def searchRecurSiblingQueue(
+    queryPoint: Vector,
+    node: Node,
+    nodeQueue: PriorityQueue[(Double, Node)]) {
+    if (node.children != null) {
+      for (child <- node.children; if child.contains(queryPoint)) {
+        if (child.children == null) {
+          for (c <- node.children) {
+            minNodes(queryPoint, c, nodeQueue)
+          }
+        } else {
+          searchRecurSiblingQueue(queryPoint, child, nodeQueue)
+        }
+      }
+    }
+  }
+
+  /**
+   * Goes down to the leaf nodes below node and adds them to nodeQueue
+   *
+   * @param queryPoint point under consideration
+   * @param node node whose leaf nodes are added to nodeQueue
+   * @param nodeQueue PriorityQueue that stores the leaf nodes keyed by the negated
+   *                  minDist to queryPoint, so that the nearest leaf is dequeued first
+   */
+  private def minNodes(
+    queryPoint: Vector,
+    node: Node,
+    nodeQueue: PriorityQueue[(Double, Node)]) {
+    if (node.children == null) {
+      nodeQueue += ((-node.minDist(queryPoint), node))
+    } else {
+      for (c <- node.children) {
+        minNodes(queryPoint, c, nodeQueue)
+      }
+    }
+  }
+
+  /** Finds all objects within a neighborhood of queryPoint of a specified radius
+    * scope is modified from original 2D version in:
+    * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+    *
+    * original version only looks in minimal box; for the KNN Query, we look at
+    * all nearby boxes. The radius is determined from searchNeighborsSiblingQueue
+    * by defining a min-heap on the leaf nodes
+    *
+    * @param queryPoint center of the search region
+    * @param radius radius of the search region
+    * @return all points in leaf boxes that are near queryPoint for the given radius
+    */
+  def searchNeighbors(
+    queryPoint: Vector,
+    radius: Double): ListBuffer[Vector] = {
+    val ret = new ListBuffer[Vector]
+    searchRecur(queryPoint, radius, root, ret)
+    ret
+  }
+
+  private def searchRecur(
+    queryPoint: Vector,
+    radius: Double,
+    node: Node,
+    ret: ListBuffer[Vector]) {
+    if (node.children == null) {
+      ret ++= node.nodeElements
+    } else {
+      for (child <- node.children; if child.isNear(queryPoint, radius)) {
+        searchRecur(queryPoint, radius, child, ret)
+      }
+    }
+  }
+}
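
Note on the two query methods above: the comments describe a two-phase lookup, where a cheap candidate set yields a radius and that radius then bounds an exact search. The following is only a minimal, self-contained sketch of that idea under stated assumptions, not the code from this commit. The names TwoPhaseKnnSketch, candidateRadius and knn are hypothetical, plain 2D tuples stand in for Vector, and a linear filter stands in for searchNeighbors. It also shows why minNodes negates minDist: scala.collection.mutable.PriorityQueue dequeues the largest key first, so negated distances come out nearest-first.

```scala
import scala.collection.mutable.PriorityQueue

object TwoPhaseKnnSketch {
  // Hypothetical stand-in for a vector: a point in the plane.
  type Point = (Double, Double)

  // Plain Euclidean distance (assumption: the real code takes a DistanceMetric).
  def dist(a: Point, b: Point): Double =
    math.sqrt(math.pow(a._1 - b._1, 2) + math.pow(a._2 - b._2, 2))

  // Phase 1: the k-th smallest distance among the candidates is an upper bound
  // on the true k-th nearest-neighbor distance (assumes at least k candidates).
  def candidateRadius(query: Point, candidates: Seq[Point], k: Int): Double = {
    // PriorityQueue is a max-heap; storing negated distances makes it
    // behave like a min-heap on the distance.
    val heap = new PriorityQueue[(Double, Point)]()(Ordering.by(_._1))
    candidates.foreach(p => heap += ((-dist(query, p), p)))
    var radius = 0.0
    var taken = 0
    while (taken < k && heap.nonEmpty) {
      radius = -heap.dequeue()._1
      taken += 1
    }
    radius
  }

  // Phase 2: every true k-nearest neighbor lies within `radius`, so restricting
  // the search to that ball and sorting gives the exact answer.
  def knn(query: Point, all: Seq[Point], candidates: Seq[Point], k: Int): Seq[Point] = {
    val radius = candidateRadius(query, candidates, k)
    all.filter(p => dist(query, p) <= radius).sortBy(p => dist(query, p)).take(k)
  }

  def main(args: Array[String]): Unit = {
    val all = Seq((0.1, 0.0), (0.2, 0.0), (1.0, 0.0), (0.15, 0.0), (3.0, 3.0))
    val candidates = Seq((0.2, 0.0), (1.0, 0.0))
    println(knn((0.0, 0.0), all, candidates, 2))
  }
}
```

The result stays exact even when the candidates are poor, because a loose radius only enlarges the set that phase 2 has to sort; it never excludes a true neighbor.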

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
index 107724b..350af95 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
@@ -53,9 +53,10 @@ class KNNITSuite extends FlatSpec with Matchers with FlinkTestBase {
     }.sortBy(_._2).take(3).map(_._1).toArray
 
     val knn = KNN()
-        .setK(3)
-        .setBlocks(10)
-        .setDistanceMetric(SquaredEuclideanDistanceMetric())
+      .setK(3)
+      .setBlocks(10)
+      .setDistanceMetric(SquaredEuclideanDistanceMetric())
+      .setUseQuadTree(true)
 
     // run knn join
     knn.fit(trainingSet)

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
new file mode 100644
index 0000000..9b84a80
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric
+import org.apache.flink.ml.nn.util.QuadTree
+import org.apache.flink.test.util.FlinkTestBase
+import org.apache.flink.ml.math.{Breeze, Vector, DenseVector}
+
+import org.scalatest.{Matchers, FlatSpec}
+
+/** Test of Quadtree class
+  * Constructor for the Quadtree class, as called below:
+  * new QuadTree(minVec, maxVec, distanceMetric, maxPerBox)
+  *
+  */
+
+class QuadTreeSuite extends FlatSpec with Matchers with FlinkTestBase {
+  behavior of "The QuadTree Class"
+
+  it should "partition into equal size sub-boxes and search for nearby objects properly" in {
+
+    val minVec = DenseVector(-1.0, -0.5)
+    val maxVec = DenseVector(1.0, 0.5)
+
+    val myTree = new QuadTree(minVec, maxVec, EuclideanDistanceMetric(), 3)
+
+    myTree.insert(DenseVector(-0.25, 0.3).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.20, 0.31).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.21, 0.29).asInstanceOf[Vector])
+
+    var a = myTree.root.getCenterWidth()
+
+    /** Tree will partition once the 4th point is added
+      */
+
+    myTree.insert(DenseVector(0.2, 0.27).asInstanceOf[Vector])
+    myTree.insert(DenseVector(0.2, 0.26).asInstanceOf[Vector])
+
+    myTree.insert(DenseVector(-0.21, 0.289).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.1, 0.289).asInstanceOf[Vector])
+
+    myTree.insert(DenseVector(0.7, 0.45).asInstanceOf[Vector])
+
+    /**
+     * Exact values of (centers,dimensions) of root + children nodes, to test
+     * partitionBox and makeChildren methods; exact values are given to avoid
+     * essentially copying and pasting the code to automatically generate them
+     * from minVec/maxVec
+     */
+
+    val knownCentersLengths = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)),
+      (DenseVector(-0.5, -0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(-0.5, 0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(0.5, -0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(0.5, 0.25), DenseVector(1.0, 0.5))
+    )
+
+    /**
+     * (centers,dimensions) computed from QuadTree.makeChildren
+     */
+
+    var computedCentersLength = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)))
+    for (child <- myTree.root.children) {
+      computedCentersLength += child.getCenterWidth().asInstanceOf[(DenseVector, DenseVector)]
+    }
+
+
+    /**
+     * Tests search for nearby neighbors: makes sure the right object is contained in
+     * the neighbor search; the neighbor search may also contain more points
+     */
+    val neighborsComputed = myTree.searchNeighbors(DenseVector(0.7001, 0.45001), 0.001)
+    val isNeighborInSearch = neighborsComputed.contains(DenseVector(0.7, 0.45))
+
+    /**
+     * Tests the ability to get all objects in the minimal bounding box plus the objects
+     * in its siblings' boxes. In this case, drawing a picture of the QuadTree shows that
+     * (-0.2, 0.31), (-0.21, 0.29), (-0.21, 0.289)
+     * are objects near (-0.2001, 0.31001)
+     */
+
+    val siblingsObjectsComputed = myTree.searchNeighborsSiblingQueue(DenseVector(-0.2001, 0.31001))
+    val isSiblingsInSearch = siblingsObjectsComputed.contains(DenseVector(-0.2, 0.31)) &&
+      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.29)) &&
+      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.289))
+
+    computedCentersLength should be(knownCentersLengths)
+    isNeighborInSearch should be(true)
+    isSiblingsInSearch should be(true)
+  }
+}
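
Note on the hard-coded (center, width) pairs in the test above: they can be checked by hand with a little arithmetic. The sketch below is an illustration only and does not reproduce partitionBox or makeChildren. It assumes each box is split in half along every axis, so a child has half the parent's width and a center offset by a quarter of the parent's width in each dimension. Box and children are hypothetical names, and plain IndexedSeq[Double] stands in for DenseVector.

```scala
object ChildBoxSketch {
  // A box described by its center and its full side lengths, one entry per dimension.
  final case class Box(center: IndexedSeq[Double], width: IndexedSeq[Double])

  // Splits a d-dimensional box into 2^d equal children (assumed splitting rule).
  def children(parent: Box): Seq[Box] = {
    val d = parent.center.length
    val childWidth = parent.width.map(_ / 2.0)
    (0 until (1 << d)).map { mask =>
      // Bit i of mask selects the upper (+) or lower (-) half along axis i.
      val center = IndexedSeq.tabulate(d) { i =>
        val sign = if (((mask >> i) & 1) == 1) 1.0 else -1.0
        parent.center(i) + sign * parent.width(i) / 4.0
      }
      Box(center, childWidth)
    }
  }

  def main(args: Array[String]): Unit = {
    // Root box of the test: minVec = (-1.0, -0.5), maxVec = (1.0, 0.5).
    val root = Box(IndexedSeq(0.0, 0.0), IndexedSeq(2.0, 1.0))
    children(root).foreach(println)
  }
}
```

For the root box of the test, this gives child centers (±0.5, ±0.25) with widths (1.0, 0.5), matching the four child entries in knownCentersLengths.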
