[
https://issues.apache.org/jira/browse/FLINK-1719?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14908209#comment-14908209
]
ASF GitHub Bot commented on FLINK-1719:
---------------------------------------
Github user sachingoel0101 commented on a diff in the pull request:
https://github.com/apache/flink/pull/1156#discussion_r40444429
--- Diff:
flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/MultinomialNaiveBayes.scala
---
@@ -0,0 +1,900 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification
+
+import java.{lang, util}
+
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.scala._
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.fs.FileSystem.WriteMode
+import org.apache.flink.ml.common.{ParameterMap, Parameter}
+import org.apache.flink.ml.pipeline.{PredictDataSetOperation,
FitOperation, Predictor}
+import org.apache.flink.util.Collector
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.Map
+
/**
 * Multinomial naive Bayes classifier for texts.
 *
 * While building the model, different approaches need to be compared. For that purpose the fit
 * parameters are used: every option that might improve the implementation can be chosen
 * separately with the following parameters.
 *
 * Possibility 1: way of calculating the document count
 *  - P1 = 0 -> use .count() to get the count of all documents
 *  - P1 = 1 -> use a reducer and a mapper to create a broadcast data set containing the count of
 *    all documents
 *
 * Possibility 2: all words in class (order of operators)
 *  If P2 = 1 improves the speed, many other calculations must switch their operators, too.
 *  - P2 = 0 -> first the reducer, then the mapper
 *  - P2 = 1 -> first the mapper, then the reducer
 *
 * Possibility 3: way of calculating P(w|c)
 *  - P3 = 0 -> join the singleWordsInClass and allWordsInClass data sets to a wordsInClass data set
 *  - P3 = 1 -> work on the singleWordsInClass data set and broadcast the allWordsInClass data set
 *
 * Schneider/Rennie 1: ignore/reduce word frequency information
 *  - SR1 = 0 -> word frequency information is not ignored
 *  - SR1 = 1 -> word frequency information is ignored (Schneider's approach)
 *  - SR1 = 2 -> word frequency information is reduced (Rennie's approach)
 *
 * Schneider 1: ignore P(c_j) in the cMAP formula
 *  - S1 = 0 -> normal cMAP formula
 *  - S1 = 1 -> cMAP without P(c_j)
 *
 * Rennie 1: transform document frequency
 *  - R1 = 0 -> normal formula
 *  - R1 = 1 -> apply inverse document frequency
 *  Note: if R1 = 1 and SR1 = 2, both approaches get applied.
 */
class MultinomialNaiveBayes extends Predictor[MultinomialNaiveBayes] {

  import MultinomialNaiveBayes._

  /** Model data related to a single word: (class name, word, log P(w|c)). */
  var wordRelatedModelData: Option[DataSet[(String, String, Double)]] = None

  /** Model data related to a single class: (class name, log P(c), log P(w|c) of a word not in the class). */
  var classRelatedModelData: Option[DataSet[(String, Double, Double)]] = None

  /**
   * Additional data needed by some of the improvements:
   * (word, log(number of documents in all classes / number of documents containing the word)).
   */
  var improvementData: Option[DataSet[(String, Double)]] = None

  // =================================== Parameter configuration ===================================

  /** Sets Possibility 1 (way of calculating the document count): 0 or 1. */
  def setP1(value: Int): MultinomialNaiveBayes = {
    parameters.add(P1, value)
    this
  }

  /** Sets Possibility 2 (order of operators when counting all words in a class): 0 or 1. */
  def setP2(value: Int): MultinomialNaiveBayes = {
    parameters.add(P2, value)
    this
  }

  /** Sets Possibility 3 (way of calculating P(w|c)): 0 or 1. */
  def setP3(value: Int): MultinomialNaiveBayes = {
    parameters.add(P3, value)
    this
  }

  /** Sets Schneider/Rennie 1 (ignore or reduce word frequency information): 0, 1 or 2. */
  def setSR1(value: Int): MultinomialNaiveBayes = {
    parameters.add(SR1, value)
    this
  }

  /** Sets Schneider 1 (ignore P(c_j) in the cMAP formula): 0 or 1. */
  def setS1(value: Int): MultinomialNaiveBayes = {
    parameters.add(S1, value)
    this
  }

  /** Sets Rennie 1 (transform document frequency): 0 or 1. */
  def setR1(value: Int): MultinomialNaiveBayes = {
    parameters.add(R1, value)
    this
  }

  // =========================================== Methods ===========================================

  /**
   * Saves the model data created by the naive Bayes algorithm, i.e. the [[wordRelatedModelData]]
   * and the [[classRelatedModelData]], as "|"-separated CSV, overwriting existing files.
   *
   * NOTE(review): `writeAsCsv` only registers a sink; presumably nothing is written until the
   * execution environment is executed — confirm.
   *
   * @param wordRelated the save location for the word related data
   * @param classRelated the save location for the class related data
   * @throws NoSuchElementException if the model data has neither been fitted nor set
   */
  def saveModelDataSet(wordRelated: String, classRelated: String): Unit = {
    // Resolve both data sets first so that no sink is registered when one of them is missing.
    val wordData = required(wordRelatedModelData, "word related model data")
    val classData = required(classRelatedModelData, "class related model data")
    wordData.writeAsCsv(wordRelated, "\n", "|", WriteMode.OVERWRITE)
    classData.writeAsCsv(classRelated, "\n", "|", WriteMode.OVERWRITE)
  }

  /**
   * Saves the [[improvementData]] data set as "|"-separated CSV, overwriting existing files.
   *
   * @param path the save location for the improvement data
   * @throws NoSuchElementException if the improvement data has neither been fitted (R1 = 1) nor set
   */
  def saveImprovementDataSet(path: String): Unit = {
    required(improvementData, "improvement data")
      .writeAsCsv(path, "\n", "|", WriteMode.OVERWRITE)
  }

  /**
   * Sets the [[wordRelatedModelData]] and the [[classRelatedModelData]] to the given data sets,
   * e.g. to reuse a previously saved model.
   *
   * @param wordRelated the data set representing the word related model
   * @param classRelated the data set representing the class related model
   */
  def setModelDataSet(wordRelated: DataSet[(String, String, Double)],
                      classRelated: DataSet[(String, Double, Double)]): Unit = {
    this.wordRelatedModelData = Some(wordRelated)
    this.classRelatedModelData = Some(classRelated)
  }

  /**
   * Sets the [[improvementData]] to the given data set.
   *
   * @param impSet the data set of (word, log(number of documents / documents containing the word))
   */
  def setImprovementDataSet(impSet: DataSet[(String, Double)]): Unit = {
    this.improvementData = Some(impSet)
  }

  /**
   * Returns the content of `data`, or fails with a descriptive message instead of `None.get`.
   * The exception type matches what `Option.get` throws, so existing callers are unaffected.
   */
  private def required[T](data: Option[T], description: String): T =
    data.getOrElse(throw new NoSuchElementException(
      s"The $description has not been computed or set yet. Fit the MultinomialNaiveBayes " +
        "or set the data set before saving it."))

}
+
+object MultinomialNaiveBayes {
+
// ========================================== Parameters =========================================

/** Possibility 1: way of calculating the document count (0 = count(), 1 = broadcast set). */
case object P1 extends Parameter[Int] { override val defaultValue: Option[Int] = Some(0) }

/** Possibility 2: order of operators when counting all words in a class (0 or 1). */
case object P2 extends Parameter[Int] { override val defaultValue: Option[Int] = Some(0) }

/** Possibility 3: way of calculating P(w|c) (0 = join, 1 = broadcast). */
case object P3 extends Parameter[Int] { override val defaultValue: Option[Int] = Some(0) }

/** Schneider/Rennie 1: 0 = keep, 1 = ignore, 2 = reduce word frequency information. */
case object SR1 extends Parameter[Int] { override val defaultValue: Option[Int] = Some(0) }

/** Schneider 1: 1 = ignore P(c_j) in the cMAP formula. */
case object S1 extends Parameter[Int] { override val defaultValue: Option[Int] = Some(0) }

/** Rennie 1: 1 = apply the inverse document frequency transformation. */
case object R1 extends Parameter[Int] { override val defaultValue: Option[Int] = Some(0) }
+
// ======================================== Factory Methods ======================================

/** Creates a new [[MultinomialNaiveBayes]] instance without any explicitly set parameters. */
def apply(): MultinomialNaiveBayes = new MultinomialNaiveBayes
+
// ========================================= Operations ==========================================

/**
 * Trains a [[MultinomialNaiveBayes]] model on labeled texts.
 *
 * The resulting word related model data (class name, word, log P(w|c)) and class related model
 * data (class name, log P(c), log P(w|c) of a word that does not occur in the class) are stored
 * in the given instance. With R1 = 1 the improvement data
 * (word, log(number of documents / number of documents containing the word)) is stored as well.
 *
 * The training data set maps (label -> text); the texts are split into words by
 * `SingleWordSplitter` and `SingleDistinctWordSplitter`.
 * NOTE(review): the original documentation says words are separated by ",", whereas the
 * prediction splits texts on " " — confirm both sides use the same separator.
 */
implicit val fitNNB: FitOperation[MultinomialNaiveBayes, (String, String)] =
  new FitOperation[MultinomialNaiveBayes, (String, String)] {

    /**
     * Builds the model data sets and stores them in `instance`.
     *
     * @param instance the [[MultinomialNaiveBayes]] that receives the model data sets
     * @param fitParameters additional parameters, merged as `instance.parameters ++ fitParameters`
     * @param input the training data set of (class label, document text) pairs
     * @throws IllegalArgumentException if P1, P2, P3, SR1 or R1 has an unsupported value
     */
    override def fit(
        instance: MultinomialNaiveBayes,
        fitParameters: ParameterMap,
        input: DataSet[(String, String)]): Unit = {

      val resultingParameters = instance.parameters ++ fitParameters

      val p1 = resultingParameters(P1)
      val p2 = resultingParameters(P2)
      val p3 = resultingParameters(P3)
      val sr1 = resultingParameters(SR1)
      val r1 = resultingParameters(R1)

      // Validate up front: unsupported values used to leave intermediate data sets `null`, which
      // only surfaced later as NullPointerExceptions (or, for R1, were silently ignored).
      requireOneOf("P1", p1, 0, 1)
      requireOneOf("P2", p2, 0, 1)
      requireOneOf("P3", p3, 0, 1)
      requireOneOf("SR1", sr1, 0, 1, 2)
      requireOneOf("R1", r1, 0, 1)

      // (class name, count of documents)
      // 1. Map: replace the document text by a 1
      // 2. Group-Reduce: sum the 1s by class
      val documentsPerClass: DataSet[(String, Int)] = input
        .map(document => (document._1, 1))
        .groupBy(0)
        .sum(1)

      // (class name, word, count of that word)
      // 1. FlatMap: split the document into its words and add a 1 to each tuple
      // 2. Group-Reduce: sum the 1s by class and word
      val singleWordsInClass: DataSet[(String, String, Int)] = input
        .flatMap(new SingleWordSplitter())
        .groupBy(0, 1)
        .sum(2)

      // POSSIBILITY 2 (order of operators) combined with SCHNEIDER/RENNIE 1:
      // (class name, count of all words in that class). For Schneider's approach (SR1 = 1) only
      // distinct words are counted, i.e. n_dw(c_j); nothing changes for Rennie's approach.
      val countDistinctWords = sr1 == 1
      val allWordsInClass: DataSet[(String, Int)] =
        if (p2 == 0) {
          // P2 = 0: (for SR1 = 1 set every word count to 1,) reduce per class, then drop the word.
          val perWordCounts =
            if (countDistinctWords) singleWordsInClass.map(sw => (sw._1, sw._2, 1))
            else singleWordsInClass
          perWordCounts
            .groupBy(0)
            .reduce((sw1, sw2) => (sw1._1, sw1._2, sw1._3 + sw2._3))
            .map(sw => (sw._1, sw._3))
        } else {
          // P2 = 1: drop the word (for SR1 = 1 also set the count to 1) first, then reduce per class.
          singleWordsInClass
            .map(sw => (sw._1, if (countDistinctWords) 1 else sw._3))
            .groupBy(0)
            .reduce((sw1, sw2) => (sw1._1, sw1._2 + sw2._2))
        }

      // POSSIBILITY 1: (class name, P(c)) = documents in the class / all documents
      val pc: DataSet[(String, Double)] =
        if (p1 == 0) {
          // P1 = 0: the total document count is obtained directly with count().
          val documentsCount: Double = input.count()
          documentsPerClass.map(line => (line._1, line._2 / documentsCount))
        } else {
          // P1 = 1: the total document count is provided as a single-element broadcast set.
          documentsPerClass.map(new RichMapFunction[(String, Int), (String, Double)] {

            var broadcastSet: util.List[Double] = null

            override def open(config: Configuration): Unit = {
              broadcastSet = getRuntimeContext.getBroadcastVariable[Double]("documentCount")
              if (broadcastSet.size() != 1) {
                throw new RuntimeException("The document count data set used by p1 = 1 has the " +
                  "wrong size! Please use p1 = 0 if the problem can not be solved.")
              }
            }

            override def map(value: (String, Int)): (String, Double) =
              (value._1, value._2 / broadcastSet.get(0))
          }).withBroadcastSet(totalDocumentCount(documentsPerClass), "documentCount")
        }

      // Number of distinct words over all classes.
      val vocabularyCount: Double = singleWordsInClass
        .map(tuple => (tuple._2, 1))
        .distinct(0)
        .count()

      // (class name, P(w|c) of a word that is not part of the class), needed for smoothing:
      // the P(w|c) formula with n(c_j, w_t) = 0.
      val pwcNotInClass: DataSet[(String, Double)] = allWordsInClass
        .map(line => (line._1, 1 / (line._2 + vocabularyCount)))

      // SCHNEIDER/RENNIE 1: for Schneider's approach P(w|c) is computed from
      // (class name, word, number of documents in the class containing the word).
      // 1. FlatMap: class -> word -> 1 (one tuple for each document in which the word occurs)
      // 2. Group-Reduce: sum the counts where class and word are equal
      val wordCountsForPwc: DataSet[(String, String, Int)] =
        if (sr1 == 1) {
          input
            .flatMap(new SingleDistinctWordSplitter())
            .groupBy(0, 1)
            .reduce((line1, line2) => (line1._1, line1._2, line1._3 + line2._3))
        } else {
          singleWordsInClass
        }

      // POSSIBILITY 3: (class name, word, P(w|c)) with add-one smoothing:
      // (n(c, w) + 1) / (n(c) + vocabulary size)
      val pwc: DataSet[(String, String, Double)] =
        if (p3 == 0) {
          // P3 = 0: join with allWordsInClass to get
          // (class name, word, count of that word, count of all words in that class).
          val wordsInClass = wordCountsForPwc
            .join(allWordsInClass).where(0).equalTo(0) {
              (single, all) => (single._1, single._2, single._3, all._2)
            }
          wordsInClass.map(line =>
            (line._1, line._2, (line._3 + 1) / (line._4 + vocabularyCount)))
        } else {
          // P3 = 1: look up the per-class word count in the broadcast allWordsInClass data set.
          wordCountsForPwc.map(new RichMapFunction[(String, String, Int), (String, String, Double)] {

            // class name -> count of words in that class
            val broadcastMap: mutable.Map[String, Int] = mutable.Map[String, Int]()

            override def open(config: Configuration): Unit = {
              getRuntimeContext
                .getBroadcastVariable[(String, Int)]("allWordsInClass")
                .asScala
                .foreach(record => broadcastMap.put(record._1, record._2))
            }

            override def map(value: (String, String, Int)): (String, String, Double) =
              (value._1, value._2, (value._3 + 1) / (broadcastMap(value._1) + vocabularyCount))
          }).withBroadcastSet(allWordsInClass, "allWordsInClass")
        }

      // (class name, word, log P(w|c))
      val wordRelatedModelData = pwc.map(line => (line._1, line._2, Math.log(line._3)))

      // (class name, log P(c), log P(w|c) of a word not in the class)
      val classRelatedModelData = pc
        .join(pwcNotInClass).where(0).equalTo(0) {
          (classPc, classNotIn) => (classPc._1, Math.log(classPc._2), Math.log(classNotIn._2))
        }

      instance.wordRelatedModelData = Some(wordRelatedModelData)
      instance.classRelatedModelData = Some(classRelatedModelData)

      // RENNIE 1: transform the document frequency. The improvement data set holds
      // (word, log(number of documents in all classes / number of documents with that word)).
      if (r1 == 1) {
        // (word, count of documents with that word)
        val documentsPerWord: DataSet[(String, Int)] = input
          .flatMap(new SingleDistinctWordSplitter())
          .map(line => (line._2, 1))
          .groupBy(0)
          .reduce((line1, line2) => (line1._1, line1._2 + line2._2))

        val improvementData = documentsPerWord.map(
          new RichMapFunction[(String, Int), (String, Double)] {

            var broadcastSet: util.List[Double] = null

            override def open(config: Configuration): Unit = {
              broadcastSet = getRuntimeContext.getBroadcastVariable[Double]("totalDocumentCount")
              if (broadcastSet.size() != 1) {
                throw new RuntimeException("The total document count data set used by r1 = 1 " +
                  "has the wrong size! Please use r1 = 0 if the problem can not be solved.")
              }
            }

            override def map(value: (String, Int)): (String, Double) =
              (value._1, Math.log(broadcastSet.get(0) / value._2))
          }).withBroadcastSet(totalDocumentCount(documentsPerClass), "totalDocumentCount")

        instance.improvementData = Some(improvementData)
      }
    }

    /** Throws an IllegalArgumentException unless `value` is one of the `allowed` values. */
    private def requireOneOf(parameter: String, value: Int, allowed: Int*): Unit = {
      if (!allowed.contains(value)) {
        throw new IllegalArgumentException(s"Unsupported value $value for parameter $parameter. " +
          s"Allowed values are: ${allowed.mkString(", ")}.")
      }
    }

    /**
     * Single-element data set holding the total number of documents.
     * 1. Reduce: add the document counts of all classes together
     * 2. Map: drop the class name field
     */
    private def totalDocumentCount(documentsPerClass: DataSet[(String, Int)]): DataSet[Double] =
      documentsPerClass
        .reduce((line1, line2) => (line1._1, line1._2 + line2._2))
        .map(line => line._2.toDouble)
  }
+
+ // Model (String, String, Double, Double, Double)
+ implicit def predictNNB = new PredictDataSetOperation[
+ MultinomialNaiveBayes,
+ (Int, String),
+ (Int, String)]() {
+
+ override def predictDataSet(instance: MultinomialNaiveBayes,
+ predictParameters: ParameterMap,
+ input: DataSet[(Int, String)]):
DataSet[(Int, String)] = {
+
+ if (instance.wordRelatedModelData.isEmpty ||
instance.classRelatedModelData.isEmpty) {
+ throw new RuntimeException("The NormalNaiveBayes has not been
fitted to the " +
+ "data. This is necessary before a prediction on other data can
be made.")
+ }
+
+ val wordRelatedModelData = instance.wordRelatedModelData.get
+ val classRelatedModelData = instance.classRelatedModelData.get
+
+ val resultingParameters = instance.parameters ++ predictParameters
+
+ //split the texts from the input data set into its words
+ val words: DataSet[(Int, String)] = input.flatMap {
+ pair => pair._2.split(" ").map { word => (pair._1, word)}
+ }
+
+ //genreate word counts for each word with a key
+ // 1. Map: put a 1 to each key
+ // 2. Group-Reduce: group by id and word and sum the 1s
+ val wordsAndCount: DataSet[(Int, String, Int)] = words.map(line =>
(line._1, line._2, 1))
+ .groupBy(0, 1).sum(2) // (id -> word -> word count in text)
+
+ //calculate the count of all words for a text identified by its key
+ val wordsInText: DataSet[(Int, Int)] = wordsAndCount.map(line =>
(line._1, line._3))
+ .groupBy(0).sum(1) //(id -> all words in text)
+
+ //generate a data set containing all words that are in model for
each id, class pair
+ // 1. Join: wordRelatedModelData with wordsAndCount on
+ // words (id -> class -> word -> word count -> log(P(w|c))
+ val foundWords: DataSet[(Int, String, String, Int, Double)] =
wordRelatedModelData
+ .joinWithHuge(wordsAndCount).where(1).equalTo(1) {
+ (wordR, wordsAC) => (wordsAC._1, wordR._1, wordsAC._2, wordsAC._3,
wordR._3)
+ }
+
+ //SCHNEIDER/RENNIE 1: ignore/reduce word frequency information
+ //RENNIE 1: transform document frequency
+
+ val sr1 = resultingParameters(SR1)
+ val r1 = resultingParameters(R1)
+ var improvementData: DataSet[(String, Double)] = null
+
+ if (r1 == 1) {
+ //The improvementData data set is needed
+ if (instance.improvementData.isEmpty) {
+ throw new RuntimeException("R1 = 1, for that additional data is
needed, but it was not" +
+ "found. Make sure to set R1 = 1 when fitting the training
data.")
+ }
+ improvementData = instance.improvementData.get
+ }
+
+ if (sr1 == 1 && r1 == 1) {
+ throw new RuntimeException("Parameter sr1 and r1 are both set to
1, which is not allowed.")
+ }
+
+ var sumPwcFoundWords: DataSet[(Int, String, Double)] = null
+
+ if (sr1 == 0 && r1 == 0) {
+ //calculate sumpwc for found words
+ // 1. Map: Remove unneded information from foundWords and
calculate the sumP(w|c) for each
+ // word (id -> class -> word count * log(P(w|c))
+ // 2. Group-Reduce: on id and class, sum all (word count *
log(P(w|c))) results
+ sumPwcFoundWords = foundWords
+ .map(line => (line._1, line._2, line._4 * line._5))
+ .groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) //(id -> class ->
sum(log(P(w|c))
+ } else if (sr1 == 1 && r1 == 0) {
+ //same as sr1 == 0, but there is no multiplication with the word
counts
+ sumPwcFoundWords = foundWords
+ .map(line => (line._1, line._2, line._5))
+ .groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) //(id -> class ->
sum(log(P(w|c))
+ } else if (sr1 == 2 && r1 == 0) {
+ //same es sr1 == 0, but multiplication with log(wordcount + 1)
+ sumPwcFoundWords = foundWords
+ .map(line => (line._1, line._2, Math.log(line._4 + 1) * line._5))
+ .groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) //(id -> class ->
sum(log(P(w|c))
+ } else if (sr1 == 0 && r1 == 1) {
+ //same as r1 = 0, but the word frequency is multiplied with with
log (n_d(c) / n_d(w_t))
+ //for that a join with the improvementData data set must be
performed first to get the
+ //needed additional information.
+ // Join: (id -> class -> word count * log P(w|c) * log (number of
documents in class /
+ // word
frequency in all classes)
+ sumPwcFoundWords = foundWords
+ .joinWithTiny(improvementData).where(2).equalTo(0) {
+ (found, imp) => (found._1, found._2, found._4 * found._5 *
imp._2)
+ }.groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) //(id -> class ->
sum(log(P(w|c)) */
+ } else if (sr1 == 2 && r1 == 1) {
+ //combination of r1 = 1 and sr1 =2
+ sumPwcFoundWords = foundWords
+ .joinWithTiny(improvementData).where(2).equalTo(0) {
+ (found, imp) => (found._1, found._2, Math.log(found._4 + 1) *
found._5 * imp._2)
+ }.groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) //(id -> class ->
sum(log(P(w|c)) */
+ }
+
+ var sumPwcNotFoundWords: DataSet[(Int, String, Double)] = null
+
+ if (sr1 == 0 && r1 == 0) {
+ //calculate sumwpc for words that are not in model in that class
+ // 1. Map: Discard word and log(P(w|c) from foundWords
+ // 2. Group-Reduce: calculate sum count of found words for each
document,
+ // class pair (id -> class -> sum(wordCount))
+ // 3. Join: with wordsInText on id, to get the count of all words
per document
+ // 4. Map: calculate sumPWcNotFound (id -> class ->
+ // (all words in document - found word in document) *
+ // log(P(w|c) not in class (provided by broadcast))
+ sumPwcNotFoundWords = foundWords
+ .map(line => (line._1, line._2, line._4))
+ .groupBy(0, 1)
+ .reduce((line1, line2) => (line1._1, line1._2, line1._3 +
line2._3))
+ .join(wordsInText).where(0).equalTo(0) {
+ (foundW, wordsIT) => (foundW._1, foundW._2, foundW._3,
wordsIT._2)
+ }.map(new RichMapFunction[(Int, String, Int, Int), (Int, String,
Double)] {
+
+ var broadcastMap: mutable.Map[String, Double] = mutable
+ .Map[String, Double]() //class -> log(P(w|c) not found word in
class)
+
+ override def open(config: Configuration): Unit = {
+ val collection = getRuntimeContext
+ .getBroadcastVariable[(String, Double,
Double)]("classRelatedModelData")
+ .asScala
+ for (record <- collection) {
+ broadcastMap.put(record._1, record._3)
+ }
+ }
+
+ override def map(value: (Int, String, Int, Int)): (Int, String,
Double) = {
+ (value._1, value._2, (value._4 - value._3) *
broadcastMap(value._2))
+ }
+ }).withBroadcastSet(classRelatedModelData, "classRelatedModelData")
+
+ } else {
+ /** if the word frequency is changed (as in SR1 = 1, SR1 = 2, R1 =
1), the
+ * sumPwcNotFoundWords can not be calculated as above. They must
be calculated the same way
+ * as for the sumPwcFoundWords data set (because each word
frequency for each word is
+ * important).
+ *
+ * Prepare foundWords data set for differenz
+ * 1. Map: discard log(P(w|c) and wordCount from foundWords (id
-> class -> word)
+ * = all words that are known in each document for each class
+ *
+ * Prepare wordsAndCounts data set for difference
+ * 1. FlatMap:create for every tuple a (document id -> class ->
word -> word count)
+ * tuple for every class
+ *
+ * Create notFoundWords data set
+ * 1: CoGroup: Only tuples (id, class, word) that are not in
preparedFoundWords
+ *
+ * The result is a data set, that contains all words for each
document for each class
+ * that are not part of that class and the word counts.
+ *
+ * Then calcualte sumPwcNotfoundWords
+ */
+
+ val preparedFoundWords: DataSet[(Int, String, String)] = foundWords
+ .map(line => (line._1, line._2, line._3))
+
+ val wordsAndCountsExtended: DataSet[(Int, String, String, Int)] =
wordsAndCount
+ .flatMap(new Extender())
+ .withBroadcastSet(classRelatedModelData, "classRelatedModelData")
+
+ val notFoundWords: DataSet[(Int, String, String, Int)] =
wordsAndCountsExtended
+ .coGroup(preparedFoundWords).where(0, 1, 2)
+ .equalTo(0, 1, 2)(new DifferenceCoGrouper) //(id -> class ->
word -> word count)
+
+ if ((sr1 == 1 && r1 == 0) || (sr1 == 2 && r1 == 0)) {
+ //calculate the sum of all Pwc for every not found word (id ->
class -> sumPwcNotFound)
+ // 1. Map: calculate the pwc value for every word (id -> class
-> pwc)
+ // 2. Sum: Sum these pwc values for each class and document
+ sumPwcNotFoundWords = notFoundWords
+ .map(new RichMapFunction[(Int, String, String, Int), (Int,
String, Double)] {
+
+ var broadcastMap: mutable.Map[String, Double] = mutable
+ .Map[String, Double]() //class -> log(P(w|c) not found word
in class)
+
+ override def open(config: Configuration): Unit = {
+ val collection = getRuntimeContext
+ .getBroadcastVariable[(String, Double,
Double)]("classRelatedModelData")
+ .asScala
+ for (record <- collection) {
+ broadcastMap.put(record._1, record._3)
+ }
+ }
+
+ override def map(value: (Int, String, String, Int)): (Int,
String, Double) = {
+ if (sr1 == 1 && r1 == 0) {
+ //same as sr1 == 0, but there is no multiplication with
the word counts
+ return (value._1, value._2, broadcastMap(value._2))
+ } else if (sr1 == 2 && r1 == 0) {
+ //same es sr1 == 0, but multiplication with log(wordcount
+ 1)
+ return (value._1, value._2, Math.log(value._4 + 1) *
broadcastMap(value._2))
+ }
+ throw new RuntimeException("sumPwcNotFound could not be
calculated because you" +
+ "choosed a not allowed parameter combination.")
+ }
+ }).withBroadcastSet(classRelatedModelData,
"classRelatedModelData")
+ .groupBy(0, 1).sum(2)
+ } else if ((sr1 == 0 && r1 == 1) || (sr1 == 2 && r1 == 1)) {
+
+ //same as r1 = 0, but the word frequency is multiplied with
+ //log (n_d(c) / n_d(w_t)) for that a join with the
improvementData data set must be
+ //performed first to get the needed additional information.
+ // 1. Join with improvement data: (id -> class -> word count ->
+ // log (number of documents in class / word frequency in all
classes))
+ // 2a. Map: (id -> class -> word count * log (see above) *
log(P(w|c))
+ // 2b. Map: (id -> class -> log(word count + 1) * log (see
above) * log(P(w|c))
+ sumPwcNotFoundWords = notFoundWords
+ .join(improvementData).where(2).equalTo(0) {
+ (nf, imp) => (nf._1, nf._2, nf._4, imp._2)
+ }.map(new RichMapFunction[(Int, String, Int, Double), (Int,
String, Double)] {
+
+ var broadcastMap: mutable.Map[String, Double] = mutable
+ .Map[String, Double]() //class -> log(P(w|c) not found word
in class)
+
+ override def open(config: Configuration): Unit = {
--- End diff --
Same as above.
> Add naive Bayes classification algorithm to machine learning library
> --------------------------------------------------------------------
>
> Key: FLINK-1719
> URL: https://issues.apache.org/jira/browse/FLINK-1719
> Project: Flink
> Issue Type: New Feature
> Components: Machine Learning Library
> Reporter: Till Rohrmann
> Assignee: Jonathan Hasenburg
> Labels: ML
>
> Add naive Bayes algorithm to Flink's machine learning library as a basic
> classification algorithm. Maybe we can incorporate some of the improvements
> developed by [Karl-Michael
> Schneider|http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.59.2085&rep=rep1&type=pdf],
> [Sang-Bum Kim et
> al.|http://ieeexplore.ieee.org/xpl/articleDetails.jsp?arnumber=1704799] or
> [Jason Rennie et
> al.|http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf] into the
> implementation.
--
This message was sent by Atlassian JIRA
(v6.3.4#6332)