Github user sachingoel0101 commented on a diff in the pull request:
https://github.com/apache/flink/pull/1156#discussion_r40444317
--- Diff:
flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/MultinomialNaiveBayes.scala
---
@@ -0,0 +1,900 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification
+
+import java.{lang, util}
+
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.scala._
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.fs.FileSystem.WriteMode
+import org.apache.flink.ml.common.{ParameterMap, Parameter}
+import org.apache.flink.ml.pipeline.{PredictDataSetOperation, FitOperation, Predictor}
+import org.apache.flink.util.Collector
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.Map
+
+/**
+ * While building the model, different approaches need to be compared.
+ * For that purpose the fitParameters are used. Every possibility that might enhance
+ * the implementation can be chosen separately by using the following list of parameters:
+ *
+ * Possibility 1: way of calculating document count
+ * P1 = 0 -> use .count() to get the count of all documents
+ * P1 = 1 -> use a reducer and a mapper to create a broadcast data set containing the count of
+ * all documents
+ *
+ * Possibility 2: all words in class (order of operators)
+ * If P2 = 1 improves the speed, many other calculations must switch their operators, too.
+ * P2 = 0 -> first the reducer, then the mapper
+ * P2 = 1 -> first the mapper, then the reducer
+ *
+ * Possibility 3: way of calculating pwc
+ * P3 = 0 -> join singleWordsInClass and allWordsInClass to the wordsInClass data set
+ * P3 = 1 -> work on the singleWordsInClass data set and broadcast the allWordsInClass data set
+ *
+ * Schneider/Rennie 1: ignore/reduce word frequency information
+ * SR1 = 0 -> word frequency information is not ignored
+ * SR1 = 1 -> word frequency information is ignored (Schneider's approach)
+ * SR1 = 2 -> word frequency information is reduced (Rennie's approach)
+ *
+ * Schneider1: ignore P(c_j) in the cMAP formula
+ * S1 = 0 -> normal cMAP formula
+ * S1 = 1 -> cMAP without P(c_j)
+ *
+ * Rennie1: transform document frequency
+ * R1 = 0 -> normal formula
+ * R1 = 1 -> apply inverse document frequency
+ * Note: if R1 = 1 and SR1 = 2, both approaches get applied.
+ *
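+ * The trained model encodes the standard multinomial NB decision rule
+ * cMAP(d) = argmax_c [ log P(c) + sum_t n(d, w_t) * log P(w_t|c) ].
+ *
+ * A minimal usage sketch (the data set name is illustrative), applying both of Rennie's
+ * improvements, which the note above explicitly allows:
+ * {{{
+ *   val mnb = MultinomialNaiveBayes()
+ *     .setSR1(2)
+ *     .setR1(1)
+ *   mnb.fit(trainingTexts) // DataSet[(String, String)]: (class label, document text)
+ * }}}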
+ */
+class MultinomialNaiveBayes extends Predictor[MultinomialNaiveBayes] {
+
+ import MultinomialNaiveBayes._
+
+ //The model that stores all needed information related to one specific word
+ var wordRelatedModelData: Option[DataSet[(String, String, Double)]] =
+ None // (class name -> word -> log P(w|c))
+
+ //The model that stores all needed information related to one specific class
+ var classRelatedModelData: Option[DataSet[(String, Double, Double)]] =
+ None // (class name -> p(c) -> log p(w|c) not in class)
+
+ //A data set that stores additional needed information for some of the improvements
+ var improvementData: Option[DataSet[(String, Double)]] =
+ None // (word -> log(number of documents in all classes / word frequency in all classes))
+
+ // ============================== Parameter configuration ========================================
+
+ def setP1(value: Int): MultinomialNaiveBayes = {
+ parameters.add(P1, value)
+ this
+ }
+
+ def setP2(value: Int): MultinomialNaiveBayes = {
+ parameters.add(P2, value)
+ this
+ }
+
+ def setP3(value: Int): MultinomialNaiveBayes = {
+ parameters.add(P3, value)
+ this
+ }
+
+ def setSR1(value: Int): MultinomialNaiveBayes = {
+ parameters.add(SR1, value)
+ this
+ }
+
+ def setS1(value: Int): MultinomialNaiveBayes = {
+ parameters.add(S1, value)
+ this
+ }
+
+ def setR1(value: Int): MultinomialNaiveBayes = {
+ parameters.add(R1, value)
+ this
+ }
+
+ // =============================================== Methods =======================================
+
+ /**
+ * Save already existing model data created by the NaiveBayes algorithm. Requires the designated
+ * locations. The saved data is a representation of the [[wordRelatedModelData]] and
+ * [[classRelatedModelData]].
+ * @param wordRelated the save location for the wordRelated data
+ * @param classRelated the save location for the classRelated data
+ */
+ def saveModelDataSet(wordRelated: String, classRelated: String) : Unit = {
+ wordRelatedModelData.get.writeAsCsv(wordRelated, "\n", "|", WriteMode.OVERWRITE)
+ classRelatedModelData.get.writeAsCsv(classRelated, "\n", "|", WriteMode.OVERWRITE)
+ }
+
+ /**
+ * Save the improvement data set. Requires the designated save location. The saved data is a
+ * representation of the [[improvementData]] data set.
+ * @param path the save location for the improvement data
+ */
+ def saveImprovementDataSet(path: String) : Unit = {
+ improvementData.get.writeAsCsv(path, "\n", "|", WriteMode.OVERWRITE)
+ }
+
+ /**
+ * Sets the [[wordRelatedModelData]] and the [[classRelatedModelData]] to the given data sets.
+ * @param wordRelated the data set representing the wordRelated model
+ * @param classRelated the data set representing the classRelated model
+ */
+ def setModelDataSet(wordRelated : DataSet[(String, String, Double)],
+ classRelated: DataSet[(String, Double, Double)]) : Unit = {
+ this.wordRelatedModelData = Some(wordRelated)
+ this.classRelatedModelData = Some(classRelated)
+ }
+
+ def setImprovementDataSet(impSet : DataSet[(String, Double)]) : Unit = {
+ this.improvementData = Some(impSet)
+ }
+
+}
+
+object MultinomialNaiveBayes {
+
+ // ========================================== Parameters =========================================
+ case object P1 extends Parameter[Int] {
+ override val defaultValue: Option[Int] = Some(0)
+ }
+
+ case object P2 extends Parameter[Int] {
+ override val defaultValue: Option[Int] = Some(0)
+ }
+
+ case object P3 extends Parameter[Int] {
+ override val defaultValue: Option[Int] = Some(0)
+ }
+
+ case object SR1 extends Parameter[Int] {
+ override val defaultValue: Option[Int] = Some(0)
+ }
+
+ case object S1 extends Parameter[Int] {
+ override val defaultValue: Option[Int] = Some(0)
+ }
+
+ case object R1 extends Parameter[Int] {
+ override val defaultValue: Option[Int] = Some(0)
+ }
+
+ // ======================================== Factory Methods ======================================
+ def apply(): MultinomialNaiveBayes = {
+ new MultinomialNaiveBayes()
+ }
+
+ // ====================================== Operations =============================================
+ /**
+ * Trains the models to fit the training data. The resulting
+ * [[MultinomialNaiveBayes.wordRelatedModelData]] and
+ * [[MultinomialNaiveBayes.classRelatedModelData]] are stored in the [[MultinomialNaiveBayes]]
+ * instance.
+ */
+ implicit val fitNNB = new FitOperation[MultinomialNaiveBayes, (String, String)] {
+ /**
+ * The [[FitOperation]] used to create the model. Requires an instance of
+ * [[MultinomialNaiveBayes]], a [[ParameterMap]] and the input data set. This data set
+ * maps (string -> string) containing (label -> text, words separated by ",")
+ * @param instance of [[MultinomialNaiveBayes]]
+ * @param fitParameters additional parameters
+ * @param input the data set to be processed
+ */
+ override def fit(instance: MultinomialNaiveBayes,
+ fitParameters: ParameterMap,
+ input: DataSet[(String, String)]): Unit = {
+
+ val resultingParameters = instance.parameters ++ fitParameters
+
+ //Count the number of documents for each class.
+ // 1. Map: replace the document text by a 1
+ // 2. Group-Reduce: sum the 1s by class
+ val documentsPerClass: DataSet[(String, Int)] = input.map { input => (input._1, 1)}
+ .groupBy(0).sum(1) // (class name -> count of documents)
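+ // e.g. two documents ("sports", "goal match") and ("sports", "win goal") first become
+ // ("sports", 1) twice and are then summed to ("sports", 2) (example values are illustrative)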
+
+ //Count the number of occurrences of each word for each class.
+ // 1. FlatMap: split the document into its words and add a 1 to each tuple
+ // 2. Group-Reduce: sum the 1s by class, word
+ var singleWordsInClass: DataSet[(String, String, Int)] = input
+ .flatMap(new SingleWordSplitter())
+ .groupBy(0, 1).sum(2) // (class name -> word -> count of that word)
+
+ //POSSIBILITY 2: all words in class (order of operators)
+ //SCHNEIDER/RENNIE 1: ignore/reduce word frequency information
+ //the allWordsInClass data set only contains distinct words for Schneider's approach:
+ //ndw(cj); nothing changes for Rennie's approach
+
+ val p2 = resultingParameters(P2)
+
+ val sr1 = resultingParameters(SR1)
+
+ var allWordsInClass: DataSet[(String, Int)] =
+ null // (class name -> count of all words in that class)
+
+ if (p2 == 0) {
+ if (sr1 == 0 || sr1 == 2) {
+ //Count all the words for each class.
+ // 1. Reduce: add the count for each word in a class together
+ // 2. Map: remove the field that contains the word
+ allWordsInClass = singleWordsInClass.groupBy(0).reduce {
+ (singleWords1, singleWords2) =>
+ (singleWords1._1, singleWords1._2, singleWords1._3 + singleWords2._3)
+ }.map(singleWords =>
+ (singleWords._1, singleWords._3)) // (class name -> count of all words in that class)
+ } else if (sr1 == 1) {
+ //Count all distinct words for each class.
+ // 1. Map: set the word count to 1
+ // 2. Reduce: add the count for each word in a class together
+ // 3. Map: remove the field that contains the word
+ allWordsInClass = singleWordsInClass
+ .map(singleWords => (singleWords._1, singleWords._2, 1))
+ .groupBy(0).reduce {
+ (singleWords1, singleWords2) =>
+ (singleWords1._1, singleWords1._2, singleWords1._3 + singleWords2._3)
+ }.map(singleWords =>
+ (singleWords._1, singleWords._3)) // (class name -> count of distinct words in that class)
+ }
+ } else if (p2 == 1) {
+ if (sr1 == 0 || sr1 == 2) {
+ //Count all the words for each class.
+ // 1. Map: remove the field that contains the word
+ // 2. Reduce: add the count for each word in a class together
+ allWordsInClass = singleWordsInClass.map(singleWords => (singleWords._1, singleWords._3))
+ .groupBy(0).reduce {
+ (singleWords1, singleWords2) => (singleWords1._1, singleWords1._2 + singleWords2._2)
+ } // (class name -> count of all words in that class)
+ } else if (sr1 == 1) {
+ //Count all distinct words for each class.
+ // 1. Map: remove the field that contains the word, set the word count to 1
+ // 2. Reduce: add the count for each word in a class together
+ allWordsInClass = singleWordsInClass.map(singleWords => (singleWords._1, 1))
+ .groupBy(0).reduce {
+ (singleWords1, singleWords2) => (singleWords1._1, singleWords1._2 + singleWords2._2)
+ } // (class name -> count of distinct words in that class)
+ }
+
+ }
+
+ //END SCHNEIDER/RENNIE 1
+ //END POSSIBILITY 2
+
+ //POSSIBILITY 1: way of calculating document count
+ val p1 = resultingParameters(P1)
+
+ var pc: DataSet[(String, Double)] = null // (class name -> P(c) in class)
+
+ if (p1 == 0) {
+ val documentsCount: Double = input.count() //count of all documents
+ //Calculate P(c)
+ // 1. Map: divide the count of documents for a class by the total count of documents
+ pc = documentsPerClass.map(line => (line._1, line._2 / documentsCount))
+
+ } else if (p1 == 1) {
+ //Create a data set that contains only one double value: the count of all documents
+ // 1. Reduce: Add the counts of documents together
+ // 2. Map: Remove field that contains document identifier
+ val documentCount: DataSet[(Double)] = documentsPerClass
+ .reduce((line1, line2) => (line1._1, line1._2 + line2._2))
+ .map(line => line._2) //(count of all documents)
+
+ //calculate P(c)
+ // 1. Map: divide the count of documents for a class by the total count of documents
+ // (only element in documentCount data set)
+ pc = documentsPerClass.map(new RichMapFunction[(String, Int), (String, Double)] {
+
+ var broadcastSet: util.List[Double] = null
+
+ override def open(config: Configuration): Unit = {
+ broadcastSet = getRuntimeContext.getBroadcastVariable[Double]("documentCount")
+ if (broadcastSet.size() != 1) {
+ throw new RuntimeException("The document count data set used by p1 = 1 has the " +
+ "wrong size! Please use p1 = 0 if the problem cannot be solved.")
+ }
+ }
+
+ override def map(value: (String, Int)): (String, Double) = {
+ (value._1, value._2 / broadcastSet.get(0))
+ }
+ }).withBroadcastSet(documentCount, "documentCount")
+ }
+ //END POSSIBILITY 1
+
+ // (list of all words, but distinct)
+ val vocabulary = singleWordsInClass.map(tuple => (tuple._2, 1)).distinct(0)
+ // (count of items in vocabulary list)
+ val vocabularyCount: Double = vocabulary.count()
+
+ //calculate the P(w|c) value for words that are not part of a class, needed for smoothing
+ // 1. Map: use P(w|c) formula with smoothing with n(c_j, w_t) = 0
+ val pwcNotInClass: DataSet[(String, Double)] = allWordsInClass
+ .map(line =>
+ (line._1, 1 / (line._2 + vocabularyCount))) // (class name -> P(w|c) word not in class)
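+ // (for reference: the smoothed formula is P(w|c) = (n(c_j, w_t) + 1) / (n(c_j) + |V|);
+ // with n(c_j, w_t) = 0 it reduces to the 1 / (n(c_j) + |V|) computed above)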
+
+ //SCHNEIDER/RENNIE 1: ignore/reduce word frequency information
+ //The singleWordsInClass data set must be changed before the calculation of pwc starts. For
+ //Schneider it needs this form: class name -> word -> number of documents containing w_t in c_j
+
+ if (sr1 == 1) {
+ //Calculate the required data set (see above)
+ // 1. FlatMap: class -> word -> 1 (one tuple for each document in which this word occurs)
+ // 2. Group-Reduce: sum all 1s where the first two fields are equal
+ singleWordsInClass = input
+ .flatMap(new SingleDistinctWordSplitter())
+ .groupBy(0, 1)
+ .reduce((line1, line2) => (line1._1, line1._2, line1._3 + line2._3))
+ }
+
+ //END SCHNEIDER/RENNIE 1
+
+ //POSSIBILITY 3: way of calculating pwc
+
+ val p3 = resultingParameters(P3)
+
+ var pwc: DataSet[(String, String, Double)] = null // (class name -> word -> P(w|c))
+
+ if (p3 == 0) {
+
+ //Join the singleWordsInClass data set with the allWordsInClass data set to use the
+ //information for the calculation of p(w|c).
+ val wordsInClass = singleWordsInClass
+ .join(allWordsInClass).where(0).equalTo(0) {
+ (single, all) => (single._1, single._2, single._3, all._2)
+ } // (class name -> word -> count of that word -> count of all words in that class)
+
+ //calculate the P(w|c) value for each word in each class
+ // 1. Map: use normal P(w|c) formula
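+ // i.e. with Laplace smoothing: P(w|c) = (n(c_j, w_t) + 1) / (n(c_j) + |V|), where |V| is
+ // the vocabulary size (vocabularyCount)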
+ pwc = wordsInClass.map(line => (line._1, line._2, (line._3 + 1) /
+ (line._4 + vocabularyCount)))
+
+ } else if (p3 == 1) {
+
+ //calculate the P(w|c) value for each word in each class
+ // 1. Map: use the normal P(w|c) formula; allWordsInClass is provided as a broadcast set
+ pwc = singleWordsInClass.map(new RichMapFunction[(String, String, Int),
+ (String, String, Double)] {
+
+ var broadcastMap: mutable.Map[String, Int] = mutable.Map[String,
Int]()
+
+
+ override def open(config: Configuration): Unit = {
+ val collection = getRuntimeContext
+ .getBroadcastVariable[(String, Int)]("allWordsInClass").asScala
+ for (record <- collection) {
+ broadcastMap.put(record._1, record._2)
+ }
+ }
+
+ override def map(value: (String, String, Int)): (String, String, Double) = {
+ (value._1, value._2, (value._3 + 1) / (broadcastMap(value._1) + vocabularyCount))
+ }
+ }).withBroadcastSet(allWordsInClass, "allWordsInClass")
+
+ }
+
+ //END POSSIBILITY 3
+
+ //stores all the word related information in one data set
+ // 1. Map: Calculate logarithms
+ val wordRelatedModelData = pwc.map(line => (line._1, line._2, Math.log(line._3)))
+
+ //store all class related information in one data set
+ // 1. Join: P(c) data set and P(w|c) data set not in class and calculate logarithms
+ val classRelatedModelData = pc.join(pwcNotInClass)
+ .where(0).equalTo(0) {
+ (line1, line2) => (line1._1, Math.log(line1._2), Math.log(line2._2))
+ } // (class name -> log(P(c)) -> log(P(w|c) not in class))
+
+ instance.wordRelatedModelData = Some(wordRelatedModelData)
+ instance.classRelatedModelData = Some(classRelatedModelData)
+
+ //RENNIE 1: transform document frequency
+ //for this, the improvementData set must be set
+ //calculate (word -> log(number of documents in all classes / docs with that word))
+
+ val r1 = resultingParameters(R1)
+
+ if (r1 == 1) {
+ val totalDocumentCount: DataSet[(Double)] = documentsPerClass
+ .reduce((line1, line2) => (line1._1, line1._2 + line2._2))
+ .map(line => line._2) // (count of all documents)
+
+ //number of documents over all classes in which each word occurs
+ val wordCountTotal = input
+ .flatMap(new SingleDistinctWordSplitter())
+ .map(line => (line._2, 1))
+ .groupBy(0)
+ .reduce((line1, line2) => (line1._1, line1._2 + line2._2))
+ // (word -> count of documents with that word)
+
+ val improvementData = wordCountTotal.map(new RichMapFunction[(String, Int),
+ (String, Double)] {
+
+ var broadcastSet: util.List[Double] = null
+
+ override def open(config: Configuration): Unit = {
+ broadcastSet = getRuntimeContext.getBroadcastVariable[Double]("totalDocumentCount")
+ if (broadcastSet.size() != 1) {
+ throw new RuntimeException("The total document count data set used by r1 = 1 has " +
+ "the wrong size! Please use r1 = 0 if the problem cannot be solved.")
+ }
+ }
+
+ override def map(value: (String, Int)): (String, Double) = {
+ (value._1, Math.log(broadcastSet.get(0) / value._2))
+ }
+ }).withBroadcastSet(totalDocumentCount, "totalDocumentCount")
+
+ instance.improvementData = Some(improvementData)
+ }
+
+ }
+ }
+
+ // Model (String, String, Double, Double, Double)
+ implicit def predictNNB = new PredictDataSetOperation[
+ MultinomialNaiveBayes,
+ (Int, String),
+ (Int, String)]() {
+
+ override def predictDataSet(instance: MultinomialNaiveBayes,
+ predictParameters: ParameterMap,
+ input: DataSet[(Int, String)]): DataSet[(Int, String)] = {
+
+ if (instance.wordRelatedModelData.isEmpty || instance.classRelatedModelData.isEmpty) {
+ throw new RuntimeException("The MultinomialNaiveBayes has not been fitted to the " +
+ "data. This is necessary before a prediction on other data can be made.")
+ }
+
+ val wordRelatedModelData = instance.wordRelatedModelData.get
+ val classRelatedModelData = instance.classRelatedModelData.get
+
+ val resultingParameters = instance.parameters ++ predictParameters
+
+ //split the texts from the input data set into their words
+ val words: DataSet[(Int, String)] = input.flatMap {
+ pair => pair._2.split(" ").map { word => (pair._1, word)}
+ }
+
+ //generate word counts for each word with a key
+ // 1. Map: add a 1 to each tuple
+ // 2. Group-Reduce: group by id and word and sum the 1s
+ val wordsAndCount: DataSet[(Int, String, Int)] = words.map(line => (line._1, line._2, 1))
+ .groupBy(0, 1).sum(2) // (id -> word -> word count in text)
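+ // e.g. for id 1 the text "goal match goal" yields (1, "goal", 2) and (1, "match", 1)
+ // (example words are illustrative)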
+
+ //calculate the count of all words for a text identified by its key
+ val wordsInText: DataSet[(Int, Int)] = wordsAndCount.map(line => (line._1, line._3))
+ .groupBy(0).sum(1) //(id -> all words in text)
+
+ //generate a data set containing all words that are in the model for each id, class pair
+ // 1. Join: wordRelatedModelData with wordsAndCount on
+ // words (id -> class -> word -> word count -> log(P(w|c)))
+ val foundWords: DataSet[(Int, String, String, Int, Double)] = wordRelatedModelData
+ .joinWithHuge(wordsAndCount).where(1).equalTo(1) {
+ (wordR, wordsAC) => (wordsAC._1, wordR._1, wordsAC._2, wordsAC._3, wordR._3)
+ }
+
+ //SCHNEIDER/RENNIE 1: ignore/reduce word frequency information
+ //RENNIE 1: transform document frequency
+
+ val sr1 = resultingParameters(SR1)
+ val r1 = resultingParameters(R1)
+ var improvementData: DataSet[(String, Double)] = null
+
+ if (r1 == 1) {
+ //The improvementData data set is needed
+ if (instance.improvementData.isEmpty) {
+ throw new RuntimeException("R1 = 1, for that additional data is
needed, but it was not" +
+ "found. Make sure to set R1 = 1 when fitting the training
data.")
+ }
+ improvementData = instance.improvementData.get
+ }
+
+ if (sr1 == 1 && r1 == 1) {
+ throw new RuntimeException("Parameter sr1 and r1 are both set to
1, which is not allowed.")
+ }
+
+ var sumPwcFoundWords: DataSet[(Int, String, Double)] = null
+
+ if (sr1 == 0 && r1 == 0) {
+ //calculate sumPwc for found words
+ // 1. Map: Remove unneeded information from foundWords and calculate the sumP(w|c) for each
+ // word (id -> class -> word count * log(P(w|c)))
+ // 2. Group-Reduce: on id and class, sum all (word count * log(P(w|c))) results
+ sumPwcFoundWords = foundWords
+ .map(line => (line._1, line._2, line._4 * line._5))
+ .groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) // (id -> class -> sum(log(P(w|c))))
+ } else if (sr1 == 1 && r1 == 0) {
+ //same as sr1 == 0, but there is no multiplication with the word counts
+ sumPwcFoundWords = foundWords
+ .map(line => (line._1, line._2, line._5))
+ .groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) // (id -> class -> sum(log(P(w|c))))
+ } else if (sr1 == 2 && r1 == 0) {
+ //same as sr1 == 0, but multiplication with log(word count + 1)
+ sumPwcFoundWords = foundWords
+ .map(line => (line._1, line._2, Math.log(line._4 + 1) * line._5))
+ .groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) // (id -> class -> sum(log(P(w|c))))
+ } else if (sr1 == 0 && r1 == 1) {
+ //same as r1 = 0, but the word frequency is multiplied with log(n_d(c) / n_d(w_t))
+ //for that a join with the improvementData data set must be performed first to get the
+ //needed additional information.
+ // Join: (id -> class -> word count * log P(w|c) * log(number of documents in class /
+ // word frequency in all classes))
+ sumPwcFoundWords = foundWords
+ .joinWithTiny(improvementData).where(2).equalTo(0) {
+ (found, imp) => (found._1, found._2, found._4 * found._5 * imp._2)
+ }.groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) // (id -> class -> sum(log(P(w|c))))
+ } else if (sr1 == 2 && r1 == 1) {
+ //combination of r1 = 1 and sr1 = 2
+ sumPwcFoundWords = foundWords
+ .joinWithTiny(improvementData).where(2).equalTo(0) {
+ (found, imp) => (found._1, found._2, Math.log(found._4 + 1) * found._5 * imp._2)
+ }.groupBy(0, 1).reduce((line1, line2) =>
+ (line1._1, line1._2, line1._3 + line2._3)) // (id -> class -> sum(log(P(w|c))))
+ }
+
+ var sumPwcNotFoundWords: DataSet[(Int, String, Double)] = null
+
+ if (sr1 == 0 && r1 == 0) {
+ //calculate sumPwc for words that are not in the model in that class
+ // 1. Map: Discard word and log(P(w|c)) from foundWords
+ // 2. Group-Reduce: calculate sum count of found words for each document,
+ // class pair (id -> class -> sum(wordCount))
+ // 3. Join: with wordsInText on id, to get the count of all words per document
+ // 4. Map: calculate sumPwcNotFoundWords (id -> class ->
+ // (all words in document - found words in document) *
+ // log(P(w|c) not in class (provided by broadcast)))
+ sumPwcNotFoundWords = foundWords
+ .map(line => (line._1, line._2, line._4))
+ .groupBy(0, 1)
+ .reduce((line1, line2) => (line1._1, line1._2, line1._3 + line2._3))
+ .join(wordsInText).where(0).equalTo(0) {
+ (foundW, wordsIT) => (foundW._1, foundW._2, foundW._3, wordsIT._2)
+ }.map(new RichMapFunction[(Int, String, Int, Int), (Int, String, Double)] {
+
+ var broadcastMap: mutable.Map[String, Double] = mutable
+ .Map[String, Double]() // class -> log(P(w|c) not found word in class)
+
+ override def open(config: Configuration): Unit = {
+ val collection = getRuntimeContext
+ .getBroadcastVariable[(String, Double, Double)]("classRelatedModelData")
+ .asScala
+ for (record <- collection) {
+ broadcastMap.put(record._1, record._3)
+ }
+ }
+
+ override def map(value: (Int, String, Int, Int)): (Int, String, Double) = {
+ (value._1, value._2, (value._4 - value._3) * broadcastMap(value._2))
+ }
+ }).withBroadcastSet(classRelatedModelData, "classRelatedModelData")
+
+ } else {
+ /** if the word frequency is changed (as in SR1 = 1, SR1 = 2, R1 = 1), the
--- End diff --
`/** */` isn't used inside functions. Use `//` instead.
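
For example, the quoted block could become:

```scala
// if the word frequency is changed (as in SR1 = 1, SR1 = 2, R1 = 1), the
// ...
```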