[ 
https://issues.apache.org/jira/browse/FLINK-1745?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15049554#comment-15049554
 ] 

ASF GitHub Bot commented on FLINK-1745:
---------------------------------------

Github user danielblazevski commented on a diff in the pull request:

    https://github.com/apache/flink/pull/1220#discussion_r47166211
  
    --- Diff: 
flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala ---
    @@ -0,0 +1,340 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.ml.nn.util
    +
    +import org.apache.flink.ml.math.{Breeze, Vector}
    +import Breeze._
    +
    +import 
org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
    +EuclideanDistanceMetric, DistanceMetric}
    +
    +import scala.collection.mutable.ListBuffer
    +import scala.collection.mutable.PriorityQueue
    +
    +/**
    + * n-dimensional QuadTree data structure; partitions
    + * spatial data for faster queries (e.g. KNN query).
    + * The skeleton of the data structure was initially
    + * based on the 2D Quadtree found here:
    + * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
    + *
    + * Many additional methods were added to the class, both for
    + * efficient KNN queries and for generalizing to n dimensions.
    + *
    + * @param minVec vector of the corner of the bounding box with smallest coordinates
    + * @param maxVec vector of the corner of the bounding box with largest coordinates
    + * @param distMetric metric; must be Euclidean or SquaredEuclidean
    + * @param maxPerBox threshold for the number of points in each box before splitting a box
    + */
    +class QuadTree(minVec: Vector, maxVec: Vector, distMetric: DistanceMetric, 
maxPerBox: Int){
    +
    +  /**
    +   * One box of the tree. A leaf keeps its points in `nodeElements`; once split,
    +   * a node of an n-dimensional tree has 2^n children, one per sub-box center
    +   * produced by `partitionBox`.
    +   *
    +   * @param center   center of the box
    +   * @param width    side length of the box in each dimension
    +   * @param children sub-boxes of this node, or null while the node is a leaf
    +   */
    +  class Node(center: Vector, width: Vector, var children: Seq[Node]) {
    +
    +    /** Points stored in this node while it is a leaf. */
    +    val nodeElements = new ListBuffer[Vector]
    +
    +    /** For testing purposes only; used in QuadTreeSuite.scala
    +      *
    +      * @return center and width of the box
    +      */
    +    def getCenterWidth(): (Vector, Vector) = (center, width)
    +
    +    /**
    +     * Tests whether queryPoint lies strictly inside the box. Points exactly on
    +     * the boundary are not contained, because `overlap` uses strict inequalities.
    +     *
    +     * @param queryPoint point to test
    +     * @return true if queryPoint is inside the open box
    +     */
    +    def contains(queryPoint: Vector): Boolean = overlap(queryPoint, 0.0)
    +
    +    /**
    +     * Tests whether the axis-aligned cube of half-side `radius` centred at
    +     * queryPoint intersects the open box: in every dimension i the interval
    +     * (queryPoint(i) - radius, queryPoint(i) + radius) must overlap the box's
    +     * extent in that dimension.
    +     *
    +     * @param queryPoint center of the query cube
    +     * @param radius     half the side length of the query cube
    +     * @return true if the query cube and the box overlap in every dimension
    +     */
    +    def overlap(queryPoint: Vector, radius: Double): Boolean =
    +      (0 until queryPoint.size).forall { i =>
    +        queryPoint(i) - radius < center(i) + width(i) / 2 &&
    +          queryPoint(i) + radius > center(i) - width(i) / 2
    +      }
    +
    +    /**
    +     * Tests whether the box is closer than `radius` to queryPoint, where the
    +     * distance is the lower bound computed by `minDist`.
    +     *
    +     * @param queryPoint point to measure from
    +     * @param radius     distance threshold (exclusive)
    +     * @return true if minDist(queryPoint) < radius
    +     */
    +    def isNear(queryPoint: Vector, radius: Double): Boolean =
    +      minDist(queryPoint) < radius
    +
    +    /**
    +     * Thrown by minDist when distMetric is neither Euclidean nor SquaredEuclidean.
    +     * The lower-case name is kept so existing references keep compiling.
    +     *
    +     * @param message error description
    +     */
    +    case class metricException(message: String) extends Exception(message)
    +
    +    /**
    +     * Lower bound on the distance from queryPoint to the box: every point in
    +     * the box is at least minDist away from queryPoint (minDist adopted from
    +     * "Nearest Neighbor Queries" by N. Roussopoulos et al.). The result is 0
    +     * when queryPoint lies inside or on the boundary of the box.
    +     *
    +     * @param queryPoint point to measure from
    +     * @return the squared gap for SquaredEuclidean, its square root for Euclidean
    +     * @throws metricException if distMetric is neither Euclidean nor SquaredEuclidean
    +     */
    +    def minDist(queryPoint: Vector): Double = {
    +      // Per dimension, add the squared gap between queryPoint and the nearer
    +      // face of the box; dimensions where queryPoint is within the box add 0.
    +      val squaredDist = (0 until queryPoint.size).foldLeft(0.0) { (acc, i) =>
    +        if (queryPoint(i) < center(i) - width(i) / 2) {
    +          acc + math.pow(queryPoint(i) - center(i) + width(i) / 2, 2)
    +        } else if (queryPoint(i) > center(i) + width(i) / 2) {
    +          acc + math.pow(queryPoint(i) - center(i) - width(i) / 2, 2)
    +        } else {
    +          acc
    +        }
    +      }
    +
    +      distMetric match {
    +        case _: SquaredEuclideanDistanceMetric => squaredDist
    +        case _: EuclideanDistanceMetric => math.sqrt(squaredDist)
    +        case _ =>
    +          throw metricException(" Error: metric must be Euclidean or SquaredEuclidean!")
    +      }
    +    }
    +
    +    /**
    +     * Finds which child queryPoint lies in, as an index into `children`.
    +     * Bit (n - 1 - i) of the index is set when queryPoint(i) > center(i), which
    +     * matches the order in which `partitionBox` emits the sub-box centers.
    +     * A coordinate exactly equal to center(i) selects the lower half.
    +     *
    +     * @param queryPoint point to locate
    +     * @return index of the child box for queryPoint
    +     */
    +    def whichChild(queryPoint: Vector): Int = {
    +      val dims = queryPoint.size
    +      (0 until dims).foldLeft(0) { (index, i) =>
    +        if (queryPoint(i) > center(i)) index | (1 << (dims - 1 - i)) else index
    +      }
    +    }
    +
    +    /**
    +     * Splits this node into 2^n children, each centred on one sub-box from
    +     * `partitionBox` and with half of this box's width in every dimension,
    +     * and stores them in `children`.
    +     */
    +    def makeChildren(): Unit = {
    +      val childWidth = 0.5 * width.asBreeze
    +      children = partitionBox(center, width).map { childCenter =>
    +        new Node(childCenter, childWidth.fromBreeze, null)
    +      }
    +    }
    +
    +    /**
    +     * Computes the centers of the 2^n sub-boxes obtained by halving the box
    +     * (given by `center` and `width`) along every dimension. Each dimension in
    +     * turn doubles the list: every center is replaced by a copy shifted down by
    +     * width(dim) / 4 followed by a copy shifted up by the same amount. Hence
    +     * dimension 0 varies slowest and, within a dimension, the lower half comes
    +     * first. The `center` argument itself is never modified.
    +     *
    +     * @param center the center of the box
    +     * @param width a vector of lengths of each dimension of the box
    +     * @return centers of the sub-boxes, in the order used by whichChild
    +     */
    +    def partitionBox(center: Vector, width: Vector): Seq[Vector] =
    +      (0 until width.size).foldLeft(Seq(center)) { (centers, dim) =>
    +        centers.flatMap { c =>
    +          val lower = c.copy
    +          val upper = c.copy
    +          lower.update(dim, lower(dim) - width(dim) / 4)
    +          upper.update(dim, upper(dim) + width(dim) / 4)
    +          Seq(lower, upper)
    +        }
    +      }
    +  }
    +
    +
    +  /** Root box: centred midway between minVec and maxVec, spanning maxVec - minVec. */
    +  val root = {
    +    val lowCorner = minVec.asBreeze
    +    val highCorner = maxVec.asBreeze
    +    new Node(((lowCorner + highCorner) * 0.5).fromBreeze,
    +      (highCorner - lowCorner).fromBreeze, null)
    +  }
    +
    +  /**
    +   * Prints the points stored in each leaf of the tree; intended for testing/debugging.
    +   */
    +  def printTree(): Unit = printTreeRecur(root)
    +
    +  /**
    +   * Recursively prints the points held by every leaf below `node`.
    +   *
    +   * @param node root of the subtree to print; a node with null children is a leaf
    +   */
    +  def printTreeRecur(node: Node): Unit = {
    +    if (node.children != null) {
    +      node.children.foreach(child => printTreeRecur(child))
    +    } else {
    +      println(s"printing tree: n.nodeElements ${node.nodeElements}")
    +    }
    +  }
    +
    +  /**
    +   * Inserts a point into the tree, descending from the root; a leaf that
    +   * already holds maxPerBox points is split and its points redistributed.
    +   *
    +   * @param queryPoint the point to insert
    +   */
    +  def insert(queryPoint: Vector): Unit = insertRecur(queryPoint, root)
    +
    +  private def insertRecur(queryPoint: Vector,node: Node) {
    +    if (node.children == null) {
    +      if (node.nodeElements.length < maxPerBox ) {
    +        node.nodeElements += queryPoint
    +      } else{
    +        node.makeChildren()
    +        for (o <- node.nodeElements){
    +          insertRecur(o, node.children(node.whichChild(o)))
    +        }
    +        node.nodeElements.clear()
    +        insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
    +      }
    +    } else{
    --- End diff --
    
    done


> Add exact k-nearest-neighbours algorithm to machine learning library
> --------------------------------------------------------------------
>
>                 Key: FLINK-1745
>                 URL: https://issues.apache.org/jira/browse/FLINK-1745
>             Project: Flink
>          Issue Type: New Feature
>          Components: Machine Learning Library
>            Reporter: Till Rohrmann
>            Assignee: Daniel Blazevski
>              Labels: ML, Starter
>
> Even though the k-nearest-neighbours (kNN) [1,2] algorithm is quite trivial,
> it is still used as a means to classify data and to do regression. This issue
> focuses on the implementation of an exact kNN (H-BNLJ, H-BRJ) algorithm as
> proposed in [2].
> Could be a starter task.
> Resources:
> [1] [http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm]
> [2] [https://www.cs.utah.edu/~lifeifei/papers/mrknnj.pdf]



--
This message was sent by Atlassian JIRA
(v6.3.4#6332)

Reply via email to