[FLINK-1745] [ml] Add exact k-nearest-neighbor join
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/858ca14c Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/858ca14c Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/858ca14c Branch: refs/heads/master Commit: 858ca14cd2252fd384b86906256049845db9360e Parents: 1212b6d Author: Chiwan Park <[email protected]> Authored: Tue Jun 30 17:41:25 2015 +0900 Committer: Chiwan Park <[email protected]> Committed: Mon May 30 19:32:26 2016 +0900 ---------------------------------------------------------------------- .../main/scala/org/apache/flink/ml/nn/KNN.scala | 214 +++++++++++++++++++ .../org/apache/flink/ml/nn/KNNITSuite.scala | 68 ++++++ 2 files changed, 282 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/858ca14c/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala new file mode 100644 index 0000000..35073b6 --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala.DataSetUtils._
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.common._
+import org.apache.flink.ml.math.Vector
+import org.apache.flink.ml.metrics.distances.{DistanceMetric, EuclideanDistanceMetric}
+import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, Predictor}
+import org.apache.flink.util.Collector
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+/** Implements an exact k-nearest neighbor join.
+  *
+  * Calculates the `k` nearest neighbor points in the training set for each point in the test set.
+  * Both data sets are split into blocks. Every training block is crossed with every testing
+  * block, and the `k` closest training points are kept for each test point of the block pair.
+  * A final group-reduce per test point (sorted by distance) merges these candidates into the
+  * overall `k` nearest neighbors.
+  *
+  * @example
+  * {{{
+  *       val trainingDS: DataSet[Vector] = ...
+  *       val testingDS: DataSet[Vector] = ...
+  *
+  *       val knn = KNN()
+  *         .setK(10)
+  *         .setBlocks(5)
+  *         .setDistanceMetric(EuclideanDistanceMetric())
+  *
+  *       knn.fit(trainingDS)
+  *
+  *       val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS)
+  * }}}
+  *
+  * =Parameters=
+  *
+  *  - [[org.apache.flink.ml.nn.KNN.K]]
+  *  Sets the K which is the number of selected points as neighbors. (Default value: '''5''')
+  *
+  *  - [[org.apache.flink.ml.nn.KNN.Blocks]]
+  *  Sets the number of blocks into which the input data will be split. This number should be set
+  *  at least to the degree of parallelism. If no value is specified, then the parallelism of the
+  *  input [[DataSet]] is used as the number of blocks. (Default value: '''None''')
+  *
+  *  - [[org.apache.flink.ml.nn.KNN.DistanceMetric]]
+  *  Sets the distance metric we use to calculate the distance between two points. If no metric is
+  *  specified, then [[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] is used.
+  *  (Default value: '''EuclideanDistanceMetric()''')
+  *
+  */
+class KNN extends Predictor[KNN] {
+
+  import KNN._
+
+  /** Blocked training data.
+    *
+    * Set by the fit operation and read by the predict operation. `None` until `fit` has been
+    * called.
+    */
+  var trainingSet: Option[DataSet[Block[Vector]]] = None
+
+  /** Sets K
+    *
+    * @param k the number of selected points as neighbors; must be positive
+    * @return this instance, to allow call chaining
+    */
+  def setK(k: Int): KNN = {
+    require(k > 0, "K must be positive.")
+    parameters.add(K, k)
+    this
+  }
+
+  /** Sets the distance metric
+    *
+    * @param metric the distance metric to calculate distance between two points
+    * @return this instance, to allow call chaining
+    */
+  def setDistanceMetric(metric: DistanceMetric): KNN = {
+    parameters.add(DistanceMetric, metric)
+    this
+  }
+
+  /** Sets the number of data blocks/partitions
+    *
+    * @param n the number of data blocks; must be positive
+    * @return this instance, to allow call chaining
+    */
+  def setBlocks(n: Int): KNN = {
+    require(n > 0, "Number of blocks must be positive.")
+    parameters.add(Blocks, n)
+    this
+  }
+}
+
+object KNN {
+
+  case object K extends Parameter[Int] {
+    val defaultValue: Option[Int] = Some(5)
+  }
+
+  case object DistanceMetric extends Parameter[DistanceMetric] {
+    val defaultValue: Option[DistanceMetric] = Some(EuclideanDistanceMetric())
+  }
+
+  case object Blocks extends Parameter[Int] {
+    val defaultValue: Option[Int] = None
+  }
+
+  def apply(): KNN = {
+    new KNN()
+  }
+
+  /** [[FitOperation]] which trains a KNN based on the given training data set.
+    *
+    * "Training" learns no model. The input is only split into blocks, keyed with
+    * `FlinkMLTools.ModuloKeyPartitioner`. The blocked data is stored in the instance's
+    * `trainingSet` field and is crossed with the test data at prediction time.
+    *
+    * @tparam T Subtype of [[org.apache.flink.ml.math.Vector]]
+    */
+  implicit def fitKNN[T <: Vector : TypeInformation] = new FitOperation[KNN, T] {
+    override def fit(
+        instance: KNN,
+        fitParameters: ParameterMap,
+        input: DataSet[T]): Unit = {
+      val resultParameters = instance.parameters ++ fitParameters
+
+      require(resultParameters.get(K).isDefined, "K is needed for calculation")
+
+      val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
+      val partitioner = FlinkMLTools.ModuloKeyPartitioner
+      val inputAsVector = input.asInstanceOf[DataSet[Vector]]
+
+      instance.trainingSet = Some(FlinkMLTools.block(inputAsVector, blocks, Some(partitioner)))
+    }
+  }
+
+  /** [[PredictDataSetOperation]] which calculates k-nearest neighbors of the given testing data
+    * set.
+    *
+    * @tparam T Subtype of [[Vector]]
+    * @return The given testing data set with k-nearest neighbors, as pairs of
+    *         (test vector, neighbors sorted by ascending distance)
+    * @throws RuntimeException if called before the model has been fitted
+    */
+  implicit def predictValues[T <: Vector : ClassTag : TypeInformation] = {
+    new PredictDataSetOperation[KNN, T, (Vector, Array[Vector])] {
+      override def predictDataSet(
+          instance: KNN,
+          predictParameters: ParameterMap,
+          input: DataSet[T]): DataSet[(Vector, Array[Vector])] = {
+        val resultParameters = instance.parameters ++ predictParameters
+
+        instance.trainingSet match {
+          case Some(trainingSet) =>
+            val k = resultParameters.get(K).get
+            val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
+            val metric = resultParameters.get(DistanceMetric).get
+            val partitioner = FlinkMLTools.ModuloKeyPartitioner
+
+            // attach unique id for each data point so its candidates can be regrouped later
+            val inputWithId: DataSet[(Long, T)] = input.zipWithUniqueId
+
+            // split data into multiple blocks
+            val inputSplit = FlinkMLTools.block(inputWithId, blocks, Some(partitioner))
+
+            // join input and training set: local k nearest candidates per (block pair, test point)
+            val crossed = trainingSet.cross(inputSplit).mapPartition {
+              (iter, out: Collector[(Vector, Vector, Long, Double)]) => {
+                for ((training, testing) <- iter) {
+                  // One bounded heap per test point. A heap shared by the whole testing block
+                  // would keep only the k closest (training, test) pairs of the block, so the
+                  // other test points of that block would lose their neighbor candidates.
+                  for ((id, testVector) <- testing.values) {
+                    val queue = mutable.PriorityQueue[(Vector, Vector, Long, Double)]()(
+                      Ordering.by(_._4))
+
+                    for (trainingVector <- training.values) {
+                      val distance = metric.distance(trainingVector, testVector)
+                      // (training vector, input vector, input key, distance)
+                      queue.enqueue((trainingVector, testVector, id, distance))
+
+                      // max-heap on distance: dequeue drops the farthest candidate
+                      if (queue.size > k) {
+                        queue.dequeue()
+                      }
+                    }
+
+                    for (v <- queue) {
+                      out.collect(v)
+                    }
+                  }
+                }
+              }
+            }
+
+            // group by input vector id and pick k nearest neighbor for each group
+            val result = crossed.groupBy(2).sortGroup(3, Order.ASCENDING).reduceGroup {
+              (iter, out: Collector[(Vector, Array[Vector])]) => {
+                if (iter.hasNext) {
+                  val head = iter.next()
+                  val key = head._2
+                  val neighbors: ArrayBuffer[Vector] = ArrayBuffer(head._1)
+
+                  // the head already is the first neighbor, so take k - 1 more
+                  for ((vector, _, _, _) <- iter.take(k - 1)) {
+                    neighbors += vector
+                  }
+
+                  out.collect((key, neighbors.toArray))
+                }
+              }
+            }
+
+            result
+          case None => throw new RuntimeException("The KNN model has not been trained. " +
+            "Call first fit before calling the predict operation.")
+        }
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/858ca14c/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
new file mode 100644
index 0000000..107724b
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.classification.Classification
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
+import org.apache.flink.test.util.FlinkTestBase
+import org.scalatest.{FlatSpec, Matchers}
+
+/** Integration tests for the exact k-nearest-neighbor join [[KNN]]. */
+class KNNITSuite extends FlatSpec with Matchers with FlinkTestBase {
+  behavior of "The KNN Join Implementation"
+
+  it should "throw an exception when the given K is not valid" in {
+    an [IllegalArgumentException] should be thrownBy {
+      KNN().setK(0)
+    }
+  }
+
+  it should "throw an exception when the given count of blocks is not valid" in {
+    an [IllegalArgumentException] should be thrownBy {
+      KNN().setBlocks(0)
+    }
+  }
+
+  it should "calculate kNN join correctly" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val origin = DenseVector(0.0, 0.0)
+    val metric = SquaredEuclideanDistanceMetric()
+
+    // the training points are the vectors of the shared classification fixture;
+    // the single query point is the origin
+    val trainingSet = env.fromCollection(Classification.trainingData).map(_.vector)
+    val testingSet = env.fromElements(origin)
+
+    // expected result, computed locally: the 3 training vectors closest to the origin
+    val expectedNeighbors = Classification.trainingData
+      .sortBy(labeled => metric.distance(origin, labeled.vector))
+      .take(3)
+      .map(_.vector)
+      .toArray
+
+    val knn = KNN()
+      .setK(3)
+      .setBlocks(10)
+      .setDistanceMetric(metric)
+
+    knn.fit(trainingSet)
+    val predictions = knn.predict(testingSet).collect()
+
+    predictions should have size 1
+    val (query, neighbors) = predictions.head
+    query should be(origin)
+    neighbors should be(expectedNeighbors)
+  }
+}
