http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala deleted file mode 100644 index a943c5f..0000000 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.math.scalabindings - -import org.scalatest.FunSuite -import RLikeOps._ -import org.apache.mahout.test.MahoutSuite - -class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite { - - test("multiplication") { - - val a = dense((1, 2, 3), (3, 4, 5)) - val b = dense(1, 4, 5) - val m = a %*% b - - assert(m(0, 0) == 24) - assert(m(1, 0) == 44) - println(m.toString) - } - - test("Hadamard") { - val a = dense( - (1, 2, 3), - (3, 4, 5) - ) - val b = dense( - (1, 1, 2), - (2, 1, 1) - ) - - val c = a * b - - printf("C=\n%s\n", c) - - assert(c(0, 0) == 1) - assert(c(1, 2) == 5) - println(c.toString) - - val d = a * 5.0 - assert(d(0, 0) == 5) - assert(d(1, 1) == 20) - - a *= b - assert(a(0, 0) == 1) - assert(a(1, 2) == 5) - println(a.toString) - - } - - /** Test dsl overloads over scala operations over matrices */ - test ("scalarOps") { - val a = dense( - (1, 2, 3), - (3, 4, 5) - ) - - (10 * a - (10 *: a)).norm shouldBe 0 - (10 + a - (10 +: a)).norm shouldBe 0 - (10 - a - (10 -: a)).norm shouldBe 0 - (10 / a - (10 /: a)).norm shouldBe 0 - - } - -}
http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala deleted file mode 100644 index 832937b..0000000 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.math.scalabindings - -import org.scalatest.FunSuite -import org.apache.mahout.math.Vector -import RLikeOps._ -import org.apache.mahout.test.MahoutSuite - -class RLikeVectorOpsSuite extends FunSuite with MahoutSuite { - - test("Hadamard") { - val a: Vector = (1, 2, 3) - val b = (3, 4, 5) - - val c = a * b - println(c) - assert(c ===(3, 8, 15)) - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala deleted file mode 100644 index 037f562..0000000 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.math.scalabindings - -import org.scalatest.FunSuite -import org.apache.mahout.math.{RandomAccessSparseVector, Vector} -import RLikeOps._ -import org.apache.mahout.test.MahoutSuite - -/** VectorOps Suite */ -class VectorOpsSuite extends FunSuite with MahoutSuite { - - test("inline create") { - - val sparseVec = svec((5 -> 1) :: (10 -> 2.0) :: Nil) - println(sparseVec) - - val sparseVec2: Vector = (5 -> 1.0) :: (10 -> 2.0) :: Nil - println(sparseVec2) - - val sparseVec3: Vector = new RandomAccessSparseVector(100) := (5 -> 1.0) :: Nil - println(sparseVec3) - - val denseVec1: Vector = (1.0, 1.1, 1.2) - println(denseVec1) - - val denseVec2 = dvec(1, 0, 1.1, 1.2) - println(denseVec2) - } - - test("plus minus") { - - val a: Vector = (1, 2, 3) - val b: Vector = (0 -> 3) :: (1 -> 4) :: (2 -> 5) :: Nil - - val c = a + b - val d = b - a - val e = -b - a - - assert(c ===(4, 6, 8)) - assert(d ===(2, 2, 2)) - assert(e ===(-4, -6, -8)) - - } - - test("dot") { - - val a: Vector = (1, 2, 3) - val b = (3, 4, 5) - - val c = a dot b - println(c) - assert(c == 26) - - } - - test ("scalarOps") { - val a = dvec(1 to 5):Vector - - 10 * a shouldBe 10 *: a - 10 + a shouldBe 10 +: a - 10 - a shouldBe 10 -: a - 10 / a shouldBe 10 /: a - - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/math-scala/src/test/scala/org/apache/mahout/nlp/tfidf/TFIDFtestBase.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/nlp/tfidf/TFIDFtestBase.scala b/math-scala/src/test/scala/org/apache/mahout/nlp/tfidf/TFIDFtestBase.scala deleted file mode 100644 index 3ec5ec1..0000000 --- a/math-scala/src/test/scala/org/apache/mahout/nlp/tfidf/TFIDFtestBase.scala +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.nlp.tfidf - -import org.apache.mahout.math._ -import org.apache.mahout.math.scalabindings._ -import org.apache.mahout.test.DistributedMahoutSuite -import org.scalatest.{FunSuite, Matchers} -import scala.collection._ -import RLikeOps._ -import scala.math._ - - -trait TFIDFtestBase extends DistributedMahoutSuite with Matchers { - this: FunSuite => - - val epsilon = 1E-6 - - val documents: List[(Int, String)] = List( - (1, "the first document contains 5 terms"), - (2, "document two document contains 4 terms"), - (3, "document three three terms"), - (4, "each document including this document contain the term document")) - - def createDictionaryAndDfMaps(documents: List[(Int, String)]): (Map[String, Int], Map[Int, Int]) = { - - // get a tf count for the entire dictionary - val dictMap = documents.unzip._2.mkString(" ").toLowerCase.split(" ").groupBy(identity).mapValues(_.length) - - // create a dictionary with an index for each term - val dictIndex = dictMap.zipWithIndex.map(x => x._1._1 -> x._2).toMap - - val docFrequencyCount = new Array[Int](dictMap.size) - - for (token <- dictMap) { - for (doc <- documents) { - // parse the string and get a word then increment the df count for that word - if (doc._2.toLowerCase.split(" ").contains(token._1)) { - 
docFrequencyCount(dictIndex(token._1)) += 1 - } - } - } - - val docFrequencyMap = docFrequencyCount.zipWithIndex.map(x => x._2 -> x._1).toMap - - (dictIndex, docFrequencyMap) - } - - def vectorizeDocument(document: String, - dictionaryMap: Map[String, Int], - dfMap: Map[Int, Int], weight: TermWeight = new TFIDF): Vector = { - - val wordCounts = document.toLowerCase.split(" ").groupBy(identity).mapValues(_.length) - - val vec = new RandomAccessSparseVector(dictionaryMap.size) - - val totalDFSize = dictionaryMap.size - val docSize = wordCounts.size - - for (word <- wordCounts) { - val term = word._1 - if (dictionaryMap.contains(term)) { - val termFreq = word._2 - val dictIndex = dictionaryMap(term) - val docFreq = dfMap(dictIndex) - val currentWeight = weight.calculate(termFreq, docFreq.toInt, docSize, totalDFSize.toInt) - vec(dictIndex)= currentWeight - } - } - vec - } - - test("TF test") { - - val (dictionary, dfMap) = createDictionaryAndDfMaps(documents) - - val tf: TermWeight = new TF() - - val vectorizedDocuments: Matrix = new SparseMatrix(documents.size, dictionary.size) - - for (doc <- documents) { - vectorizedDocuments(doc._1 - 1, ::) := vectorizeDocument(doc._2, dictionary, dfMap, tf) - } - - // corpus: - // (1, "the first document contains 5 terms"), - // (2, "document two document contains 4 terms"), - // (3, "document three three terms"), - // (4, "each document including this document contain the term document") - - // dictonary: - // (this -> 0, 4 -> 1, three -> 2, document -> 3, two -> 4, term -> 5, 5 -> 6, contain -> 7, - // each -> 8, first -> 9, terms -> 10, contains -> 11, including -> 12, the -> 13) - - // dfMap: - // (0 -> 1, 5 -> 1, 10 -> 3, 1 -> 1, 6 -> 1, 9 -> 1, 13 -> 2, 2 -> 1, 12 -> 1, 7 -> 1, 3 -> 4, - // 11 -> 2, 8 -> 1, 4 -> 1) - - vectorizedDocuments(0, 0).toInt should be (0) - vectorizedDocuments(0, 13).toInt should be (1) - vectorizedDocuments(1, 3).toInt should be (2) - vectorizedDocuments(3, 3).toInt should be (3) - - } - - - 
test("TFIDF test") { - val (dictionary, dfMap) = createDictionaryAndDfMaps(documents) - - val tfidf: TermWeight = new TFIDF() - - val vectorizedDocuments: Matrix = new SparseMatrix(documents.size, dictionary.size) - - for (doc <- documents) { - vectorizedDocuments(doc._1 - 1, ::) := vectorizeDocument(doc._2, dictionary, dfMap, tfidf) - } - - // corpus: - // (1, "the first document contains 5 terms"), - // (2, "document two document contains 4 terms"), - // (3, "document three three terms"), - // (4, "each document including this document contain the term document") - - // dictonary: - // (this -> 0, 4 -> 1, three -> 2, document -> 3, two -> 4, term -> 5, 5 -> 6, contain -> 7, - // each -> 8, first -> 9, terms -> 10, contains -> 11, including -> 12, the -> 13) - - // dfMap: - // (0 -> 1, 5 -> 1, 10 -> 3, 1 -> 1, 6 -> 1, 9 -> 1, 13 -> 2, 2 -> 1, 12 -> 1, 7 -> 1, 3 -> 4, - // 11 -> 2, 8 -> 1, 4 -> 1) - - abs(vectorizedDocuments(0, 0) - 0.0) should be < epsilon - abs(vectorizedDocuments(0, 13) - 2.540445) should be < epsilon - abs(vectorizedDocuments(1, 3) - 2.870315) should be < epsilon - abs(vectorizedDocuments(3, 3) - 3.515403) should be < epsilon - } - - test("MLlib TFIDF test") { - val (dictionary, dfMap) = createDictionaryAndDfMaps(documents) - - val tfidf: TermWeight = new MLlibTFIDF() - - val vectorizedDocuments: Matrix = new SparseMatrix(documents.size, dictionary.size) - - for (doc <- documents) { - vectorizedDocuments(doc._1 - 1, ::) := vectorizeDocument(doc._2, dictionary, dfMap, tfidf) - } - - // corpus: - // (1, "the first document contains 5 terms"), - // (2, "document two document contains 4 terms"), - // (3, "document three three terms"), - // (4, "each document including this document contain the term document") - - // dictonary: - // (this -> 0, 4 -> 1, three -> 2, document -> 3, two -> 4, term -> 5, 5 -> 6, contain -> 7, - // each -> 8, first -> 9, terms -> 10, contains -> 11, including -> 12, the -> 13) - - // dfMap: - // (0 -> 1, 5 -> 1, 10 -> 3, 
1 -> 1, 6 -> 1, 9 -> 1, 13 -> 2, 2 -> 1, 12 -> 1, 7 -> 1, 3 -> 4, - // 11 -> 2, 8 -> 1, 4 -> 1) - - abs(vectorizedDocuments(0, 0) - 0.0) should be < epsilon - abs(vectorizedDocuments(0, 13) - 1.609437) should be < epsilon - abs(vectorizedDocuments(1, 3) - 2.197224) should be < epsilon - abs(vectorizedDocuments(3, 3) - 3.295836) should be < epsilon - } - -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/math-scala/src/test/scala/org/apache/mahout/test/DistributedMahoutSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/test/DistributedMahoutSuite.scala b/math-scala/src/test/scala/org/apache/mahout/test/DistributedMahoutSuite.scala deleted file mode 100644 index 3538991..0000000 --- a/math-scala/src/test/scala/org/apache/mahout/test/DistributedMahoutSuite.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.test - -import org.apache.mahout.math.drm.DistributedContext -import org.scalatest.{Suite, FunSuite, Matchers} - -/** - * Unit tests that use a distributed context to run - */ -trait DistributedMahoutSuite extends MahoutSuite { this: Suite => - protected implicit var mahoutCtx: DistributedContext -} http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/math-scala/src/test/scala/org/apache/mahout/test/LoggerConfiguration.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/test/LoggerConfiguration.scala b/math-scala/src/test/scala/org/apache/mahout/test/LoggerConfiguration.scala deleted file mode 100644 index 7a34aa2..0000000 --- a/math-scala/src/test/scala/org/apache/mahout/test/LoggerConfiguration.scala +++ /dev/null @@ -1,16 +0,0 @@ -package org.apache.mahout.test - -import org.scalatest._ -import org.apache.log4j.{Level, Logger, BasicConfigurator} - -trait LoggerConfiguration extends BeforeAndAfterAllConfigMap { - this: Suite => - - override protected def beforeAll(configMap: ConfigMap): Unit = { - super.beforeAll(configMap) - BasicConfigurator.resetConfiguration() - BasicConfigurator.configure() - Logger.getRootLogger.setLevel(Level.ERROR) - Logger.getLogger("org.apache.mahout.math.scalabindings").setLevel(Level.DEBUG) - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/math-scala/src/test/scala/org/apache/mahout/test/MahoutSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/test/MahoutSuite.scala b/math-scala/src/test/scala/org/apache/mahout/test/MahoutSuite.scala deleted file mode 100644 index d3b8a38..0000000 --- a/math-scala/src/test/scala/org/apache/mahout/test/MahoutSuite.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.mahout.test - -import java.io.File -import org.scalatest._ -import org.apache.mahout.common.RandomUtils - -trait MahoutSuite extends BeforeAndAfterEach with LoggerConfiguration with Matchers { - this: Suite => - - final val TmpDir = "tmp/" - - override protected def beforeEach() { - super.beforeEach() - RandomUtils.useTestSeed() - } - - override protected def beforeAll(configMap: ConfigMap) { - super.beforeAll(configMap) - - // just in case there is an existing tmp dir clean it before every suite - deleteDirectory(new File(TmpDir)) - } - - override protected def afterEach() { - - // clean the tmp dir after every test - deleteDirectory(new File(TmpDir)) - - super.afterEach() - } - - /** Delete directory no symlink checking and exceptions are not caught */ - private def deleteDirectory(path: File): Unit = { - if (path.isDirectory) - for (files <- path.listFiles) deleteDirectory(files) - path.delete - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 74da44e..7151414 100644 --- a/pom.xml +++ b/pom.xml @@ -212,12 +212,12 @@ </dependency> <dependency> - <artifactId>mahout-math-scala_${scala.compat.version}</artifactId> + 
<artifactId>mahout-samsara_${scala.compat.version}</artifactId> <groupId>${project.groupId}</groupId> <version>${project.version}</version> </dependency> <dependency> - <artifactId>mahout-math-scala_${scala.compat.version}</artifactId> + <artifactId>mahout-samsara_${scala.compat.version}</artifactId> <groupId>${project.groupId}</groupId> <version>${project.version}</version> <classifier>tests</classifier> @@ -772,7 +772,7 @@ <module>integration</module> <module>examples</module> <module>distribution</module> - <module>math-scala</module> + <module>samsara</module> <module>spark</module> <module>spark-shell</module> <module>h2o</module> http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/samsara/pom.xml ---------------------------------------------------------------------- diff --git a/samsara/pom.xml b/samsara/pom.xml new file mode 100644 index 0000000..5f80b68 --- /dev/null +++ b/samsara/pom.xml @@ -0,0 +1,194 @@ +<?xml version="1.0" encoding="UTF-8"?> + +<!-- + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+--> + +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <parent> + <groupId>org.apache.mahout</groupId> + <artifactId>mahout</artifactId> + <version>0.11.0-SNAPSHOT</version> + <relativePath>../pom.xml</relativePath> + </parent> + + <artifactId>mahout-samsara_${scala.compat.version}</artifactId> + <name>Mahout Samsara</name> + <description>Mahout Math Scala bindings</description> + + <packaging>jar</packaging> + + <build> + <plugins> + <!-- create test jar so other modules can reuse the samsara test utility classes. --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <executions> + <execution> + <goals> + <goal>test-jar</goal> + </goals> + <phase>package</phase> + </execution> + </executions> + </plugin> + + <plugin> + <artifactId>maven-javadoc-plugin</artifactId> + </plugin> + + <plugin> + <artifactId>maven-source-plugin</artifactId> + </plugin> + + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + <executions> + <execution> + <id>add-scala-sources</id> + <phase>initialize</phase> + <goals> + <goal>add-source</goal> + </goals> + </execution> + <execution> + <id>scala-compile</id> + <phase>process-resources</phase> + <goals> + <goal>compile</goal> + </goals> + </execution> + <execution> + <id>scala-test-compile</id> + <phase>process-test-resources</phase> + <goals> + <goal>testCompile</goal> + </goals> + </execution> + </executions> + </plugin> + + <!--this is what scalatest recommends to do to enable scala tests --> + + <!-- disable surefire --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-plugin</artifactId> + <configuration> + <skipTests>true</skipTests> + </configuration> + </plugin> + <!-- enable scalatest --> + <plugin> + 
<groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <executions> + <execution> + <id>test</id> + <goals> + <goal>test</goal> + </goals> + </execution> + </executions> + </plugin> + + </plugins> + </build> + + <dependencies> + + <dependency> + <groupId>org.apache.mahout</groupId> + <artifactId>mahout-math</artifactId> + </dependency> + + <!-- 3rd-party --> + <dependency> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + </dependency> + + <dependency> + <groupId>com.github.scopt</groupId> + <artifactId>scopt_${scala.compat.version}</artifactId> + <version>3.3.0</version> + </dependency> + + <!-- scala stuff --> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-compiler</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-reflect</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-library</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-actors</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scalap</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.compat.version}</artifactId> + </dependency> + + </dependencies> + + <profiles> + <profile> + <id>mahout-release</id> + <build> + <plugins> + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + <executions> + <execution> + <id>generate-scaladoc</id> + <goals> + <goal>doc</goal> + </goals> + </execution> + <execution> + <id>attach-scaladoc-jar</id> + <goals> + <goal>doc-jar</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + </profiles> 
+</project> http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala new file mode 100644 index 0000000..5de0733 --- /dev/null +++ b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala @@ -0,0 +1,119 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +package org.apache.mahout.classifier.naivebayes + +import org.apache.mahout.math.Vector +import scala.collection.JavaConversions._ + +/** + * Abstract Classifier base for Complementary and Standard Classifiers + * @param nbModel a trained NBModel + */ +abstract class AbstractNBClassifier(nbModel: NBModel) extends java.io.Serializable { + + // Trained Naive Bayes Model + val model = nbModel + + /** scoring method for standard and complementary classifiers */ + protected def getScoreForLabelFeature(label: Int, feature: Int): Double + + /** getter for model */ + protected def getModel: NBModel= { + model + } + + /** + * Compute the score for a Vector of weighted TF-IDF features + * @param label Label to be scored + * @param instance Vector of weights used to calculate the score + * @return score for this Label + */ + protected def getScoreForLabelInstance(label: Int, instance: Vector): Double = { + var result: Double = 0.0 + for (e <- instance.nonZeroes) { + result += e.get * getScoreForLabelFeature(label, e.index) + } + result + } + + /** number of categories the model has been trained on */ + def numCategories: Int = { + model.numLabels + } + + /** + * get a scoring vector for a vector of TF or TF-IDF weights + * @param instance vector of TF or TF-IDF weights to be classified + * @return a vector of scores. 
+ */ + def classifyFull(instance: Vector): Vector = { + classifyFull(model.createScoringVector, instance) + } + + /** helper method for classifyFull(Vector) */ + def classifyFull(r: Vector, instance: Vector): Vector = { + var label: Int = 0 + for (label <- 0 until model.numLabels) { + r.setQuick(label, getScoreForLabelInstance(label, instance)) + } + r + } +} + +/** + * Standard Multinomial Naive Bayes Classifier + * @param nbModel a trained NBModel + */ +class StandardNBClassifier(nbModel: NBModel) extends AbstractNBClassifier(nbModel: NBModel) with java.io.Serializable{ + override def getScoreForLabelFeature(label: Int, feature: Int): Double = { + val model: NBModel = getModel + StandardNBClassifier.computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI, model.numFeatures) + } +} + +/** helper object for StandardNBClassifier */ +object StandardNBClassifier extends java.io.Serializable { + /** Compute Standard Multinomial Naive Bayes Weights See Rennie et. al. 
Section 2.1 */ + def computeWeight(featureLabelWeight: Double, labelWeight: Double, alphaI: Double, numFeatures: Double): Double = { + val numerator: Double = featureLabelWeight + alphaI + val denominator: Double = labelWeight + alphaI * numFeatures + return Math.log(numerator / denominator) + } +} + +/** + * Complementary Naive Bayes Classifier + * @param nbModel a trained NBModel + */ +class ComplementaryNBClassifier(nbModel: NBModel) extends AbstractNBClassifier(nbModel: NBModel) with java.io.Serializable { + override def getScoreForLabelFeature(label: Int, feature: Int): Double = { + val model: NBModel = getModel + val weight: Double = ComplementaryNBClassifier.computeWeight(model.featureWeight(feature), model.weight(label, feature), model.totalWeightSum, model.labelWeight(label), model.alphaI, model.numFeatures) + return weight / model.thetaNormalizer(label) + } +} + +/** helper object for ComplementaryNBClassifier */ +object ComplementaryNBClassifier extends java.io.Serializable { + + /** Compute Complementary weights See Rennie et. al. 
Section 3.1 */ + def computeWeight(featureWeight: Double, featureLabelWeight: Double, totalWeight: Double, labelWeight: Double, alphaI: Double, numFeatures: Double): Double = { + val numerator: Double = featureWeight - featureLabelWeight + alphaI + val denominator: Double = totalWeight - labelWeight + alphaI * numFeatures + return -Math.log(numerator / denominator) + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala new file mode 100644 index 0000000..3ceae96 --- /dev/null +++ b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.classifier.naivebayes + +import org.apache.mahout.math._ + +import org.apache.mahout.math.{drm, scalabindings} + +import scalabindings._ +import scalabindings.RLikeOps._ +import drm.RLikeDrmOps._ +import drm._ +import scala.collection.JavaConverters._ +import scala.language.asInstanceOf +import scala.collection._ +import JavaConversions._ + +/** + * + * @param weightsPerLabelAndFeature Aggregated matrix of weights of labels x features + * @param weightsPerFeature Vector of summation of all feature weights. + * @param weightsPerLabel Vector of summation of all label weights. + * @param perlabelThetaNormalizer Vector of weight normalizers per label (used only for complementary models) + * @param labelIndex HashMap of labels and their corresponding row in the weightMatrix + * @param alphaI Laplace smoothing factor. + * @param isComplementary Whether or not this is a complementary model. + */ +class NBModel(val weightsPerLabelAndFeature: Matrix = null, + val weightsPerFeature: Vector = null, + val weightsPerLabel: Vector = null, + val perlabelThetaNormalizer: Vector = null, + val labelIndex: Map[String, Integer] = null, + val alphaI: Float = 1.0f, + val isComplementary: Boolean= false) extends java.io.Serializable { + + + val numFeatures: Double = weightsPerFeature.getNumNondefaultElements + val totalWeightSum: Double = weightsPerLabel.zSum + val alphaVector: Vector = null + + validate() + + // todo: Maybe it is a good idea to move the dfsWrite and dfsRead out + // todo: of the model and into a helper + + // TODO: weightsPerLabelAndFeature, a sparse (numFeatures x numLabels) matrix should fit + // TODO: upfront in memory and should not require a DRM; decide if we want this to scale out. + + + /** getter for summed label weights. Used by legacy classifier */ + def labelWeight(label: Int): Double = { + weightsPerLabel.getQuick(label) + } + + /** getter for weight normalizers. 
Used by legacy classifier */ + def thetaNormalizer(label: Int): Double = { + perlabelThetaNormalizer.get(label) + } + + /** getter for summed feature weights. Used by legacy classifier */ + def featureWeight(feature: Int): Double = { + weightsPerFeature.getQuick(feature) + } + + /** getter for individual aggregated weights. Used by legacy classifier */ + def weight(label: Int, feature: Int): Double = { + weightsPerLabelAndFeature.getQuick(label, feature) + } + + /** getter for a single empty vector of weights */ + def createScoringVector: Vector = { + weightsPerLabel.like + } + + /** getter for a the number of labels to consider */ + def numLabels: Int = { + weightsPerLabel.size + } + + /** + * Write a trained model to the filesystem as a series of DRMs + * @param pathToModel Directory to which the model will be written + */ + def dfsWrite(pathToModel: String)(implicit ctx: DistributedContext): Unit = { + //todo: write out as smaller partitions or possibly use reader and writers to + //todo: write something other than a DRM for label Index, is Complementary, alphaI. 
+ + // add a directory to put all of the DRMs in + val fullPathToModel = pathToModel + NBModel.modelBaseDirectory + + drmParallelize(weightsPerLabelAndFeature).dfsWrite(fullPathToModel + "/weightsPerLabelAndFeatureDrm.drm") + drmParallelize(sparse(weightsPerFeature)).dfsWrite(fullPathToModel + "/weightsPerFeatureDrm.drm") + drmParallelize(sparse(weightsPerLabel)).dfsWrite(fullPathToModel + "/weightsPerLabelDrm.drm") + drmParallelize(sparse(perlabelThetaNormalizer)).dfsWrite(fullPathToModel + "/perlabelThetaNormalizerDrm.drm") + drmParallelize(sparse(svec((0,alphaI)::Nil))).dfsWrite(fullPathToModel + "/alphaIDrm.drm") + + // isComplementry is true if isComplementaryDrm(0,0) == 1 else false + val isComplementaryDrm = sparse(0 to 1, 0 to 1) + if(isComplementary){ + isComplementaryDrm(0,0) = 1.0 + } else { + isComplementaryDrm(0,0) = 0.0 + } + drmParallelize(isComplementaryDrm).dfsWrite(fullPathToModel + "/isComplementaryDrm.drm") + + // write the label index as a String-Keyed DRM. + val labelIndexDummyDrm = weightsPerLabelAndFeature.like() + labelIndexDummyDrm.setRowLabelBindings(labelIndex) + // get a reverse map of [Integer, String] and set the value of firsr column of the drm + // to the corresponding row number for it's Label (the rows may not be read back in the same order) + val revMap = labelIndex.map(x => x._2 -> x._1) + for(i <- 0 until labelIndexDummyDrm.numRows() ){ + labelIndexDummyDrm.set(labelIndex(revMap(i)), 0, i.toDouble) + } + + drmParallelizeWithRowLabels(labelIndexDummyDrm).dfsWrite(fullPathToModel + "/labelIndex.drm") + } + + /** Model Validation */ + def validate() { + assert(alphaI > 0, "alphaI has to be greater than 0!") + assert(numFeatures > 0, "the vocab count has to be greater than 0!") + assert(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!") + assert(weightsPerLabel != null, "the number of labels has to be defined!") + assert(weightsPerLabel.getNumNondefaultElements > 0, "the number of labels has to be greater than 0!") 
+ assert(weightsPerFeature != null, "the feature sums have to be defined") + assert(weightsPerFeature.getNumNondefaultElements > 0, "the feature sums have to be greater than 0!") + if (isComplementary) { + assert(perlabelThetaNormalizer != null, "the theta normalizers have to be defined") + assert(perlabelThetaNormalizer.getNumNondefaultElements > 0, "the number of theta normalizers has to be greater than 0!") + assert(Math.signum(perlabelThetaNormalizer.minValue) == Math.signum(perlabelThetaNormalizer.maxValue), "Theta normalizers do not all have the same sign") + assert(perlabelThetaNormalizer.getNumNonZeroElements == perlabelThetaNormalizer.size, "Weight normalizers can not have zero value.") + } + assert(labelIndex.size == weightsPerLabel.getNumNondefaultElements, "label index must have entries for all labels") + } +} + +object NBModel extends java.io.Serializable { + + val modelBaseDirectory = "/naiveBayesModel" + + /** + * Read a trained model in from from the filesystem. + * @param pathToModel directory from which to read individual model components + * @return a valid NBModel + */ + def dfsRead(pathToModel: String)(implicit ctx: DistributedContext): NBModel = { + //todo: Takes forever to read we need a more practical method of writing models. Readers/Writers? 
+ + // read from a base directory for all drms + val fullPathToModel = pathToModel + modelBaseDirectory + + val weightsPerFeatureDrm = drmDfsRead(fullPathToModel + "/weightsPerFeatureDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val weightsPerFeature = weightsPerFeatureDrm.collect(0, ::) + weightsPerFeatureDrm.uncache() + + val weightsPerLabelDrm = drmDfsRead(fullPathToModel + "/weightsPerLabelDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val weightsPerLabel = weightsPerLabelDrm.collect(0, ::) + weightsPerLabelDrm.uncache() + + val alphaIDrm = drmDfsRead(fullPathToModel + "/alphaIDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val alphaI: Float = alphaIDrm.collect(0, 0).toFloat + alphaIDrm.uncache() + + // isComplementry is true if isComplementaryDrm(0,0) == 1 else false + val isComplementaryDrm = drmDfsRead(fullPathToModel + "/isComplementaryDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val isComplementary = isComplementaryDrm.collect(0, 0).toInt == 1 + isComplementaryDrm.uncache() + + var perLabelThetaNormalizer= weightsPerFeature.like() + if (isComplementary) { + val perLabelThetaNormalizerDrm = drm.drmDfsRead(fullPathToModel + "/perlabelThetaNormalizerDrm.drm") + .checkpoint(CacheHint.MEMORY_ONLY) + perLabelThetaNormalizer = perLabelThetaNormalizerDrm.collect(0, ::) + } + + val dummyLabelDrm= drmDfsRead(fullPathToModel + "/labelIndex.drm") + .checkpoint(CacheHint.MEMORY_ONLY) + val labelIndexMap:java.util.Map[String, Integer] = dummyLabelDrm.getRowLabelBindings + dummyLabelDrm.uncache() + + // map the labels to the corresponding row numbers of weightsPerFeatureDrm (values in dummyLabelDrm) + val scalaLabelIndexMap: mutable.Map[String, Integer] = + labelIndexMap.map(x => x._1 -> dummyLabelDrm.get(labelIndexMap(x._1), 0) + .toInt + .asInstanceOf[Integer]) + + val weightsPerLabelAndFeatureDrm = drmDfsRead(fullPathToModel + "/weightsPerLabelAndFeatureDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val weightsPerLabelAndFeature = weightsPerLabelAndFeatureDrm.collect + 
weightsPerLabelAndFeatureDrm.uncache() + + // model validation is triggered automatically by constructor + val model: NBModel = new NBModel(weightsPerLabelAndFeature, + weightsPerFeature, + weightsPerLabel, + perLabelThetaNormalizer, + scalaLabelIndexMap, + alphaI, + isComplementary) + + model + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala new file mode 100644 index 0000000..a15ca09 --- /dev/null +++ b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes + +import org.apache.mahout.classifier.stats.{ResultAnalyzer, ClassifierResult} +import org.apache.mahout.math._ +import scalabindings._ +import scalabindings.RLikeOps._ +import drm.RLikeDrmOps._ +import drm._ +import scala.reflect.ClassTag +import scala.language.asInstanceOf +import collection._ +import scala.collection.JavaConversions._ + +/** + * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et.al.: Tackling the poor + * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf + */ +trait NaiveBayes extends java.io.Serializable{ + + /** default value for the Laplacian smoothing parameter */ + def defaultAlphaI = 1.0f + + // function to extract categories from string keys + type CategoryParser = String => String + + /** Default: seqdirectory/seq2Sparse Categories are Stored in Drm Keys as: /Category/document_id */ + def seq2SparseCategoryParser: CategoryParser = x => x.split("/")(1) + + + /** + * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et.al.: Tackling the poor + * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf + * + * @param observationsPerLabel a DrmLike[Int] matrix containing term frequency counts for each label. 
+ * @param trainComplementary whether or not to train a complementary Naive Bayes model + * @param alphaI Laplace smoothing parameter + * @return trained naive bayes model + */ + def train(observationsPerLabel: DrmLike[Int], + labelIndex: Map[String, Integer], + trainComplementary: Boolean = true, + alphaI: Float = defaultAlphaI): NBModel = { + + // Summation of all weights per feature + val weightsPerFeature = observationsPerLabel.colSums + + // Distributed summation of all weights per label + val weightsPerLabel = observationsPerLabel.rowSums + + // Collect a matrix to pass to the NaiveBayesModel + val inCoreTFIDF = observationsPerLabel.collect + + // perLabelThetaNormalizer Vector is expected by NaiveBayesModel. We can pass a null value + // or Vector of zeroes in the case of a standard NB model. + var thetaNormalizer = weightsPerFeature.like() + + // Instantiate a trainer and retrieve the perLabelThetaNormalizer Vector from it in the case of + // a complementary NB model + if (trainComplementary) { + val thetaTrainer = new ComplementaryNBThetaTrainer(weightsPerFeature, + weightsPerLabel, + alphaI) + // local training of the theta normalization + for (labelIndex <- 0 until inCoreTFIDF.nrow) { + thetaTrainer.train(labelIndex, inCoreTFIDF(labelIndex, ::)) + } + thetaNormalizer = thetaTrainer.retrievePerLabelThetaNormalizer + } + + new NBModel(inCoreTFIDF, + weightsPerFeature, + weightsPerLabel, + thetaNormalizer, + labelIndex, + alphaI, + trainComplementary) + } + + /** + * Extract label Keys from raw TF or TF-IDF Matrix generated by seqdirectory/seq2sparse + * and aggregate TF or TF-IDF values by their label + * Override this method in engine specific modules to optimize + * + * @param stringKeyedObservations DrmLike matrix; Output from seq2sparse + * in form K = eg./Category/document_title + * V = TF or TF-IDF values per term + * @param cParser a String => String function used to extract categories from + * Keys of the stringKeyedObservations DRM. 
The default + * CategoryParser will extract "Category" from: '/Category/document_id' + * @return (labelIndexMap,aggregatedByLabelObservationDrm) + * labelIndexMap is a HashMap [String, Integer] K = label row index + * V = label + * aggregatedByLabelObservationDrm is a DrmLike[Int] of aggregated + * TF or TF-IDF counts per label + */ + def extractLabelsAndAggregateObservations[K: ClassTag](stringKeyedObservations: DrmLike[K], + cParser: CategoryParser = seq2SparseCategoryParser) + (implicit ctx: DistributedContext): + (mutable.HashMap[String, Integer], DrmLike[Int])= { + + stringKeyedObservations.checkpoint() + + val numDocs=stringKeyedObservations.nrow + val numFeatures=stringKeyedObservations.ncol + + // Extract categories from labels assigned by seq2sparse + // Categories are Stored in Drm Keys as eg.: /Category/document_id + + // Get a new DRM with a single column so that we don't have to collect the + // DRM into memory upfront. + val strippedObeservations= stringKeyedObservations.mapBlock(ncol=1){ + case(keys, block) => + val blockB = block.like(keys.size, 1) + keys -> blockB + } + + // Extract the row label bindings (the String keys) from the slim Drm + // strip the document_id from the row keys keeping only the category. + // Sort the bindings alphabetically into a Vector + val labelVectorByRowIndex = strippedObeservations + .getRowLabelBindings + .map(x => x._2 -> cParser(x._1)) + .toVector.sortWith(_._1 < _._1) + + //TODO: add a .toIntKeyed(...) method to DrmLike? 
+ + // Copy stringKeyedObservations to an Int-Keyed Drm so that we can compute transpose + // Copy the Collected Matrices up front for now until we hav a distributed way of converting + val inCoreStringKeyedObservations = stringKeyedObservations.collect + val inCoreIntKeyedObservations = new SparseMatrix( + stringKeyedObservations.nrow.toInt, + stringKeyedObservations.ncol) + for (i <- 0 until inCoreStringKeyedObservations.nrow.toInt) { + inCoreIntKeyedObservations(i, ::) = inCoreStringKeyedObservations(i, ::) + } + + val intKeyedObservations= drmParallelize(inCoreIntKeyedObservations) + + stringKeyedObservations.uncache() + + var labelIndex = 0 + val labelIndexMap = new mutable.HashMap[String, Integer] + val encodedLabelByRowIndexVector = new DenseVector(labelVectorByRowIndex.size) + + // Encode Categories as an Integer (Double) so we can broadcast as a vector + // where each element is an Int-encoded category whose index corresponds + // to its row in the Drm + for (i <- 0 until labelVectorByRowIndex.size) { + if (!(labelIndexMap.contains(labelVectorByRowIndex(i)._2))) { + encodedLabelByRowIndexVector(i) = labelIndex.toDouble + labelIndexMap.put(labelVectorByRowIndex(i)._2, labelIndex) + labelIndex += 1 + } + // don't like this casting but need to use a java.lang.Integer when setting rowLabelBindings + encodedLabelByRowIndexVector(i) = labelIndexMap + .getOrElse(labelVectorByRowIndex(i)._2, -1) + .asInstanceOf[Int].toDouble + } + + // "Combiner": Map and aggregate by Category. Do this by broadcasting the encoded + // category vector and mapping a transposed IntKeyed Drm out so that all categories + // will be present on all nodes as columns and can be referenced by + // BCastEncodedCategoryByRowVector. Iteratively sum all categories. 
+ val nLabels = labelIndex + + val bcastEncodedCategoryByRowVector = drmBroadcast(encodedLabelByRowIndexVector) + + val aggregetedObservationByLabelDrm = intKeyedObservations.t.mapBlock(ncol = nLabels) { + case (keys, blockA) => + val blockB = blockA.like(keys.size, nLabels) + var label : Int = 0 + for (i <- 0 until keys.size) { + blockA(i, ::).nonZeroes().foreach { elem => + label = bcastEncodedCategoryByRowVector.get(elem.index).toInt + blockB(i, label) = blockB(i, label) + blockA(i, elem.index) + } + } + keys -> blockB + }.t + + (labelIndexMap, aggregetedObservationByLabelDrm) + } + + /** + * Test a trained model with a labeled dataset sequentially + * @param model a trained NBModel + * @param testSet a labeled testing set + * @param testComplementary test using a complementary or a standard NB classifier + * @param cParser a String => String function used to extract categories from + * Keys of the testing set DRM. The default + * CategoryParser will extract "Category" from: '/Category/document_id' + * + * *Note*: this method brings the entire test set into upfront memory, + * This method is optimized and parallelized in SparkNaiveBayes + * + * @tparam K implicitly determined Key type of test set DRM: String + * @return a result analyzer with confusion matrix and accuracy statistics + */ + def test[K: ClassTag](model: NBModel, + testSet: DrmLike[K], + testComplementary: Boolean = false, + cParser: CategoryParser = seq2SparseCategoryParser) + (implicit ctx: DistributedContext): ResultAnalyzer = { + + val labelMap = model.labelIndex + + val numLabels = model.numLabels + + testSet.checkpoint() + + val numTestInstances = testSet.nrow.toInt + + // instantiate the correct type of classifier + val classifier = testComplementary match { + case true => new ComplementaryNBClassifier(model) with Serializable + case _ => new StandardNBClassifier(model) with Serializable + } + + if (testComplementary) { + assert(testComplementary == model.isComplementary, + "Complementary 
Label Assignment requires Complementary Training") + } + + + // Sequentially assign labels to the test set: + // *Note* this brings the entire test set into memory upfront: + + // Since we cant broadcast the model as is do it sequentially up front for now + val inCoreTestSet = testSet.collect + + // get the labels of the test set and extract the keys + val testSetLabelMap = testSet.getRowLabelBindings + + // empty Matrix in which we'll set the classification scores + val inCoreScoredTestSet = testSet.like(numTestInstances, numLabels) + + testSet.uncache() + + for (i <- 0 until numTestInstances) { + inCoreScoredTestSet(i, ::) := classifier.classifyFull(inCoreTestSet(i, ::)) + } + + // todo: reverse the labelMaps in training and through the model? + + // reverse the label map and extract the labels + val reverseTestSetLabelMap = testSetLabelMap.map(x => x._2 -> cParser(x._1)) + + val reverseLabelMap = labelMap.map(x => x._2 -> x._1) + + val analyzer = new ResultAnalyzer(labelMap.keys.toList.sorted, "DEFAULT") + + // assign labels- winner takes all + for (i <- 0 until numTestInstances) { + val (bestIdx, bestScore) = argmax(inCoreScoredTestSet(i, ::)) + val classifierResult = new ClassifierResult(reverseLabelMap(bestIdx), bestScore) + analyzer.addInstance(reverseTestSetLabelMap(i), classifierResult) + } + + analyzer + } + + /** + * argmax with values as well + * returns a tuple of index of the max score and the score itself. + * @param v Vector of of scores + * @return (bestIndex, bestScore) + */ + def argmax(v: Vector): (Int, Double) = { + var bestIdx: Int = Integer.MIN_VALUE + var bestScore: Double = Integer.MIN_VALUE.asInstanceOf[Int].toDouble + for(i <- 0 until v.size) { + if(v(i) > bestScore){ + bestScore = v(i) + bestIdx = i + } + } + (bestIdx, bestScore) + } + +} + +object NaiveBayes extends NaiveBayes with java.io.Serializable + +/** + * Trainer for the weight normalization vector used by Transform Weight Normalized Complement + * Naive Bayes. 
See: Rennie et.al.: Tackling the poor assumptions of Naive Bayes Text classifiers, + * ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf Sec. 3.2. + * + * @param weightsPerFeature a Vector of summed TF or TF-IDF weights for each word in dictionary. + * @param weightsPerLabel a Vector of summed TF or TF-IDF weights for each label. + * @param alphaI Laplace smoothing factor. Defaut value of 1. + */ +class ComplementaryNBThetaTrainer(private val weightsPerFeature: Vector, + private val weightsPerLabel: Vector, + private val alphaI: Double = 1.0) { + + private val perLabelThetaNormalizer: Vector = weightsPerLabel.like() + private val totalWeightSum: Double = weightsPerLabel.zSum + private var numFeatures: Double = weightsPerFeature.getNumNondefaultElements + + assert(weightsPerFeature != null, "weightsPerFeature vector can not be null") + assert(weightsPerLabel != null, "weightsPerLabel vector can not be null") + + /** + * Train the weight normalization vector for each label + * @param label + * @param featurePerLabelWeight + */ + def train(label: Int, featurePerLabelWeight: Vector) { + val currentLabelWeight = labelWeight(label) + // sum weights for each label including those with zero word counts + for (i <- 0 until featurePerLabelWeight.size) { + val currentFeaturePerLabelWeight = featurePerLabelWeight(i) + updatePerLabelThetaNormalizer(label, + ComplementaryNBClassifier.computeWeight(featureWeight(i), + currentFeaturePerLabelWeight, + totalWeightSum, + currentLabelWeight, + alphaI, + numFeatures) + ) + } + } + + /** + * getter for summed TF or TF-IDF weights by label + * @param label index of label + * @return sum of word TF or TF-IDF weights for label + */ + def labelWeight(label: Int): Double = { + weightsPerLabel(label) + } + + /** + * getter for summed TF or TF-IDF weights by word. + * @param feature index of word. + * @return sum of TF or TF-IDF weights for word. 
+ */ + def featureWeight(feature: Int): Double = { + weightsPerFeature(feature) + } + + /** + * add the magnitude of the current weight to the current + * label's corresponding Vector element. + * @param label index of label to update. + * @param weight weight to add. + */ + def updatePerLabelThetaNormalizer(label: Int, weight: Double) { + perLabelThetaNormalizer(label) = perLabelThetaNormalizer(label) + Math.abs(weight) + } + + /** + * Getter for the weight normalizer vector as indexed by label + * @return a copy of the weight normalizer vector. + */ + def retrievePerLabelThetaNormalizer: Vector = { + perLabelThetaNormalizer.cloned + } + + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/samsara/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala b/samsara/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala new file mode 100644 index 0000000..8f1413a --- /dev/null +++ b/samsara/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala @@ -0,0 +1,467 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/

package org.apache.mahout.classifier.stats

import java.text.{DecimalFormat, NumberFormat}
import java.util
import org.apache.mahout.math.stats.OnlineSummarizer


/**
 * Outcome of classifying one document: the assigned label, its score and an optional
 * log-likelihood. A log-likelihood equal to Integer.MAX_VALUE means that none was supplied.
 */
class ClassifierResult (private var label: String = null,
                        private var score: Double = 0.0,
                        private var logLikelihood: Double = Integer.MAX_VALUE.toDouble) {

  def getLabel: String = label

  def setLabel(lbl: String): Unit = {
    label = lbl
  }

  def getScore: Double = score

  def setScore(sc: Double): Unit = {
    score = sc
  }

  def getLogLikelihood: Double = logLikelihood

  def setLogLikelihood(llh: Double): Unit = {
    logLikelihood = llh
  }

  override def toString: String =
    "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}'

}

/**
 * Collects classification outcomes and renders them as a textual report.
 * @param labelSet Set of labels to be considered in classification
 * @param defaultLabel the default label for an unknown classification
 */
class ResultAnalyzer(private val labelSet: util.Collection[String], defaultLabel: String) {

  val confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel)
  val summarizer = new OnlineSummarizer

  private var hasLL: Boolean = false
  private var correctlyClassified: Int = 0
  private var incorrectlyClassified: Int = 0


  def getConfusionMatrix: ConfusionMatrix = confusionMatrix

  /**
   * Record one classified instance.
   *
   * @param correctLabel
   * The correct label
   * @param classifiedResult
   * The classified result
   * @return whether the instance was correct or not
   */
  def addInstance(correctLabel: String, classifiedResult: ClassifierResult): Boolean = {
    val isCorrect = correctLabel == classifiedResult.getLabel
    if (isCorrect) correctlyClassified += 1 else incorrectlyClassified += 1

    confusionMatrix.addInstance(correctLabel, classifiedResult)

    // Integer.MAX_VALUE marks a result that carries no log-likelihood
    val logLikelihood = classifiedResult.getLogLikelihood
    if (logLikelihood != Integer.MAX_VALUE.toDouble) {
      summarizer.add(logLikelihood)
      hasLL = true
    }

    isCorrect
  }

  /** Dump the resulting statistics to a string */
  override def toString: String = {
    val report = new StringBuilder
    val fmt: NumberFormat = new DecimalFormat("0.####")

    val totalClassified = correctlyClassified + incorrectlyClassified
    val percentageCorrect = 100.0 * correctlyClassified / totalClassified
    val percentageIncorrect = 100.0 * incorrectlyClassified / totalClassified

    report.append('\n')
    report.append("=======================================================\n")
    report.append("Summary\n")
    report.append("-------------------------------------------------------\n")
    report.append("Correctly Classified Instances")
      .append(": ")
      .append(Integer.toString(correctlyClassified))
      .append('\t')
      .append(fmt.format(percentageCorrect))
      .append("%\n")
    report.append("Incorrectly Classified Instances")
      .append(": ")
      .append(Integer.toString(incorrectlyClassified))
      .append('\t')
      .append(fmt.format(percentageIncorrect))
      .append("%\n")
    report.append("Total Classified Instances")
      .append(": ")
      .append(Integer.toString(totalClassified))
      .append('\n')
    report.append('\n')

    report.append(confusionMatrix)
    report.append("=======================================================\n")
    report.append("Statistics\n")
    report.append("-------------------------------------------------------\n")

    val normStats: RunningAverageAndStdDev = confusionMatrix.getNormalizedStats
    report.append("Kappa: \t")
      .append(fmt.format(confusionMatrix.getKappa))
      .append('\n')
    report.append("Accuracy: \t")
      .append(fmt.format(confusionMatrix.getAccuracy))
      .append("%\n")
    report.append("Reliability: \t")
      .append(fmt.format(normStats.getAverage * 100.00000001))
      .append("%\n")
    report.append("Reliability (std dev): \t")
      .append(fmt.format(normStats.getStandardDeviation))
      .append('\n')
    report.append("Weighted precision: \t")
      .append(fmt.format(confusionMatrix.getWeightedPrecision))
      .append('\n')
    report.append("Weighted recall: \t")
      .append(fmt.format(confusionMatrix.getWeightedRecall))
      .append('\n')
    report.append("Weighted F1 score: \t")
      .append(fmt.format(confusionMatrix.getWeightedF1score))
      .append('\n')

    // the log-likelihood section is only printed when at least one result carried a value
    if (hasLL) {
      report.append("Log-likelihood: \t")
        .append("mean : \t")
        .append(fmt.format(summarizer.getMean))
        .append('\n')
      report.append("25%-ile : \t")
        .append(fmt.format(summarizer.getQuartile(1)))
        .append('\n')
      report.append("75%-ile : \t")
        .append(fmt.format(summarizer.getQuartile(3)))
        .append('\n')
    }

    report.toString()
  }


}

/**
 *
 * Interface for classes that can keep track of a running average of a series of numbers. One can add to or
 * remove from the series, as well as update a datum in the series. The class does not actually keep track of
 * the series of values, just its running average, so it doesn't even matter if you remove/change a value that
 * wasn't added.
+ * + * Ported from org.apache.mahout.cf.taste.impl.common.RunningAverage.java + */ +trait RunningAverage { + + /** + * @param datum + * new item to add to the running average + * @throws IllegalArgumentException + * if datum is { @link Double#NaN} + */ + def addDatum(datum: Double) + + /** + * @param datum + * item to remove to the running average + * @throws IllegalArgumentException + * if datum is { @link Double#NaN} + * @throws IllegalStateException + * if count is 0 + */ + def removeDatum(datum: Double) + + /** + * @param delta + * amount by which to change a datum in the running average + * @throws IllegalArgumentException + * if delta is { @link Double#NaN} + * @throws IllegalStateException + * if count is 0 + */ + def changeDatum(delta: Double) + + def getCount: Int + + def getAverage: Double + + /** + * @return a (possibly immutable) object whose average is the negative of this object's + */ + def inverse: RunningAverage +} + +/** + * + * Extends {@link RunningAverage} by adding standard deviation too. 
+ * + * Ported from org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev.java + */ +trait RunningAverageAndStdDev extends RunningAverage { + + /** @return standard deviation of data */ + def getStandardDeviation: Double + + /** + * @return a (possibly immutable) object whose average is the negative of this object's + */ + def inverse: RunningAverageAndStdDev +} + + +class InvertedRunningAverage(private val delegate: RunningAverage) extends RunningAverage { + + override def addDatum(datum: Double) { + throw new UnsupportedOperationException + } + + override def removeDatum(datum: Double) { + throw new UnsupportedOperationException + } + + override def changeDatum(delta: Double) { + throw new UnsupportedOperationException + } + + override def getCount: Int = { + delegate.getCount + } + + override def getAverage: Double = { + -delegate.getAverage + } + + override def inverse: RunningAverage = { + delegate + } +} + + +/** + * + * A simple class that can keep track of a running average of a series of numbers. One can add to or remove + * from the series, as well as update a datum in the series. The class does not actually keep track of the + * series of values, just its running average, so it doesn't even matter if you remove/change a value that + * wasn't added. 
+ * + * Ported from org.apache.mahout.cf.taste.impl.common.FullRunningAverage.java + */ +class FullRunningAverage(private var count: Int = 0, + private var average: Double = Double.NaN ) extends RunningAverage { + + /** + * @param datum + * new item to add to the running average + */ + override def addDatum(datum: Double) { + count += 1 + if (count == 1) { + average = datum + } + else { + average = average * (count - 1) / count + datum / count + } + } + + /** + * @param datum + * item to remove from the running average + * @throws IllegalStateException + * if count is 0 + */ + override def removeDatum(datum: Double) { + if (count == 0) { + throw new IllegalStateException + } + count -= 1 + if (count == 0) { + average = Double.NaN + } + else { + average = average * (count + 1) / count - datum / count + } + } + + /** + * @param delta + * amount by which to change a datum in the running average + * @throws IllegalStateException + * if count is 0 + */ + override def changeDatum(delta: Double) { + if (count == 0) { + throw new IllegalStateException + } + average += delta / count + } + + override def getCount: Int = { + count + } + + override def getAverage: Double = { + average + } + + override def inverse: RunningAverage = { + new InvertedRunningAverage(this) + } + + override def toString: String = { + String.valueOf(average) + } +} + + +/** + * + * Extends {@link FullRunningAverage} to add a running standard deviation computation. 
+ * Uses Welford's method, as described at http://www.johndcook.com/standard_deviation.html + * + * Ported from org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev.java + */ +class FullRunningAverageAndStdDev(private var count: Int = 0, + private var average: Double = 0.0, + private var mk: Double = 0.0, + private var sk: Double = 0.0) extends FullRunningAverage with RunningAverageAndStdDev { + + var stdDev: Double = 0.0 + + recomputeStdDev + + def getMk: Double = { + mk + } + + def getSk: Double = { + sk + } + + override def getStandardDeviation: Double = { + stdDev + } + + override def addDatum(datum: Double) { + super.addDatum(datum) + val count: Int = getCount + if (count == 1) { + mk = datum + sk = 0.0 + } + else { + val oldmk: Double = mk + val diff: Double = datum - oldmk + mk += diff / count + sk += diff * (datum - mk) + } + recomputeStdDev + } + + override def removeDatum(datum: Double) { + val oldCount: Int = getCount + super.removeDatum(datum) + val oldmk: Double = mk + mk = (oldCount * oldmk - datum) / (oldCount - 1) + sk -= (datum - mk) * (datum - oldmk) + recomputeStdDev + } + + /** + * @throws UnsupportedOperationException + */ + override def changeDatum(delta: Double) { + throw new UnsupportedOperationException + } + + private def recomputeStdDev { + val count: Int = getCount + stdDev = if (count > 1) Math.sqrt(sk / (count - 1)) else Double.NaN + } + + override def inverse: RunningAverageAndStdDev = { + new InvertedRunningAverageAndStdDev(this) + } + + override def toString: String = { + String.valueOf(String.valueOf(getAverage) + ',' + stdDev) + } + +} + + +/** + * + * @param delegate RunningAverageAndStdDev instance + * + * Ported from org.apache.mahout.cf.taste.impl.common.InvertedRunningAverageAndStdDev.java + */ +class InvertedRunningAverageAndStdDev(private val delegate: RunningAverageAndStdDev) extends RunningAverageAndStdDev { + + /** + * @throws UnsupportedOperationException + */ + override def addDatum(datum: Double) { + 
throw new UnsupportedOperationException + } + + /** + * @throws UnsupportedOperationException + */ + + override def removeDatum(datum: Double) { + throw new UnsupportedOperationException + } + + /** + * @throws UnsupportedOperationException + */ + override def changeDatum(delta: Double) { + throw new UnsupportedOperationException + } + + override def getCount: Int = { + delegate.getCount + } + + override def getAverage: Double = { + -delegate.getAverage + } + + override def getStandardDeviation: Double = { + delegate.getStandardDeviation + } + + override def inverse: RunningAverageAndStdDev = { + delegate + } +} + + + + http://git-wip-us.apache.org/repos/asf/mahout/blob/f7b69fab/samsara/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala b/samsara/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala new file mode 100644 index 0000000..328d27b --- /dev/null +++ b/samsara/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala @@ -0,0 +1,460 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/

package org.apache.mahout.classifier.stats

import java.util
import org.apache.commons.math3.stat.descriptive.moment.Mean // This is brought in by mahout-math
import org.apache.mahout.math.{DenseMatrix, Matrix}
import scala.collection.mutable
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

/**
 *
 * Ported from org.apache.mahout.classifier.ConfusionMatrix.java
 *
 * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
 *
 * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
 *
 * See http://en.wikipedia.org/wiki/Confusion_matrix for background
 *
 * Cells are addressed as (correct label, classified label): the row is the true class and the
 * column is the class the classifier chose. The default label always occupies the last index.
 *
 * @param labels The labels to consider for classification; null is treated as "no labels"
 * @param defaultLabel default unknown label
 */
class ConfusionMatrix(private var labels: util.Collection[String] = null,
                      private var defaultLabel: String = "unknown") {
  /**
   * Matrix Constructor
   * @param m a DenseMatrix with RowLabelBindings
   */
//   def this(m: Matrix) {
//     this()
//     confusionMatrix = Array.ofDim[Int](m.numRows, m.numRows)
//     setMatrix(m)
//   }

  // val LOG: Logger = LoggerFactory.getLogger(classOf[ConfusionMatrix])

  // The declared default for `labels` is null; treat it as an empty collection instead of
  // failing with a NullPointerException while sizing the matrix.
  private val initialLabels: List[String] =
    if (labels == null) Nil else labels.asScala.toList

  /** Counts indexed as (correct label id)(classified label id); one extra row/column for the default label. */
  var confusionMatrix = Array.ofDim[Int](initialLabels.size + 1, initialLabels.size + 1)

  /** Maps each label (including the default label) to its row/column index. */
  val labelMap = new mutable.HashMap[String,Integer]()

  // NOTE(review): `samples` is incremented by addInstance AND by putCount whenever a cell goes
  // from zero to non-zero, so it does not look like a plain count of added instances.
  // It feeds getKappa; confirm the intended meaning before relying on it.
  var samples: Int = 0

  // Index counter used while filling labelMap; left as a public member so existing callers
  // that read it keep compiling. It ends up holding the index of the default label.
  var i: Integer = 0
  for (label <- initialLabels) {
    labelMap.put(label, i)
    i += 1
  }
  labelMap.put(defaultLabel, i)


  /** @return the raw count matrix (not a copy) */
  def getConfusionMatrix: Array[Array[Int]] = confusionMatrix

  /** @return all known labels, including the default label, in the map's iteration order */
  def getLabels: List[String] = labelMap.keys.toList

  /** @return number of entries in the label map */
  def numLabels: Int = labelMap.size

  /**
   * Percentage of the instances whose correct label is `label` that were classified as `label`.
   *
   * @param label a label present in the label map
   * @return a value in percent; NaN when the label's row is all zeros
   */
  def getAccuracy(label: String): Double = {
    val labelId: Int = labelMap(label)
    var labelTotal: Int = 0
    var correct: Int = 0
    for (col <- 0 until numLabels) {
      labelTotal += confusionMatrix(labelId)(col)
      if (col == labelId) {
        correct += confusionMatrix(labelId)(col)
      }
    }

    100.0 * correct / labelTotal
  }

  /**
   * Overall percentage of counted instances that lie on the diagonal.
   *
   * @return a value in percent; NaN when the matrix holds no counts
   */
  def getAccuracy: Double = {
    var total: Int = 0
    var correct: Int = 0
    for (row <- 0 until numLabels; col <- 0 until numLabels) {
      total += confusionMatrix(row)(col)
      if (row == col) {
        correct += confusionMatrix(row)(col)
      }
    }

    100.0 * correct / total
  }

  /** Sum of true positives and false negatives */
  private def getActualNumberOfTestExamplesForClass(label: String): Int = {
    val labelId: Int = labelMap(label)
    var sum: Int = 0
    for (col <- 0 until numLabels) {
      sum += confusionMatrix(labelId)(col)
    }
    sum
  }

  /**
   * Precision for one label: true positives / (true positives + false positives).
   *
   * @return a fraction in [0, 1]; 0 when nothing was classified as `label`
   */
  def getPrecision(label: String): Double = {
    val labelId: Int = labelMap(label)
    val truePositives: Int = confusionMatrix(labelId)(labelId)
    var falsePositives: Int = 0

    // False positives sit in the label's column, outside the diagonal.
    for (row <- 0 until numLabels) {
      if (row != labelId) {
        falsePositives += confusionMatrix(row)(labelId)
      }
    }

    if (truePositives + falsePositives == 0) {
      0.0
    } else {
      truePositives.toDouble / (truePositives + falsePositives)
    }
  }


  /** @return mean of the per-label precisions, weighted by each label's number of actual examples */
  def getWeightedPrecision: Double = {
    val precisions: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      precisions(index) = getPrecision(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(precisions, weights)
  }

  /**
   * Recall for one label: true positives / (true positives + false negatives).
   *
   * @return a fraction in [0, 1]; 0 when the label has no actual examples
   */
  def getRecall(label: String): Double = {
    val labelId: Int = labelMap(label)
    val truePositives: Int = confusionMatrix(labelId)(labelId)
    var falseNegatives: Int = 0
    // False negatives sit in the label's row, outside the diagonal.
    for (col <- 0 until numLabels) {
      if (col != labelId) {
        falseNegatives += confusionMatrix(labelId)(col)
      }
    }

    if (truePositives + falseNegatives == 0) {
      0.0
    } else {
      truePositives.toDouble / (truePositives + falseNegatives)
    }
  }

  /** @return mean of the per-label recalls, weighted by each label's number of actual examples */
  def getWeightedRecall: Double = {
    val recalls: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      recalls(index) = getRecall(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(recalls, weights)
  }

  /** @return harmonic mean of precision and recall for `label`; 0 when both are 0 */
  def getF1score(label: String): Double = {
    val precision: Double = getPrecision(label)
    val recall: Double = getRecall(label)
    if (precision + recall == 0) {
      0.0
    } else {
      2 * precision * recall / (precision + recall)
    }
  }

  /** @return mean of the per-label F1 scores, weighted by each label's number of actual examples */
  def getWeightedF1score: Double = {
    val f1Scores: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      f1Scores(index) = getF1score(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(f1Scores, weights)
  }

  /**
   * Sum of the per-label accuracies of all non-default labels, divided by the number of
   * labels in the map.
   */
  def getReliability: Double = {
    var count: Int = 0
    var accuracy: Double = 0
    for (label <- labelMap.keys) {
      if (label != defaultLabel) {
        accuracy += getAccuracy(label)
      }
      // NOTE(review): the divisor also counts the default label, whose accuracy is skipped
      // above; kept as ported — confirm this is intended.
      count += 1
    }
    accuracy / count
  }

  /**
   * Accuracy v.s. randomly classifying all samples.
   * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
   * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
   * Educational And Psychological Measurement 20:37-46.
   *
   * Formula and variable names from:
   * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
   *
   * @return double
   */
  def getKappa: Double = {
    var a: Double = 0.0
    var b: Double = 0.0
    for (idx <- 0 until confusionMatrix.length) {
      a += confusionMatrix(idx)(idx)
      // Row total for this class.
      var br: Int = 0
      for (col <- 0 until confusionMatrix.length) {
        br += confusionMatrix(idx)(col)
      }
      // Column total for this class.
      var bc: Int = 0
      for (row <- confusionMatrix) {
        bc += row(idx)
      }
      // Multiply as doubles: the Int product of two totals overflows once both exceed ~46k.
      b += br.toDouble * bc
    }
    // Same reason: samples * samples overflows Int for more than 46340 samples.
    val n: Double = samples.toDouble
    (n * a - b) / (n * n - b)
  }

  /** @return number of instances of `label` that were classified as `label` */
  def getCorrect(label: String): Int = {
    val labelId: Int = labelMap(label)
    confusionMatrix(labelId)(labelId)
  }

  /** @return number of instances whose correct label is `label` (the row total) */
  def getTotal(label: String): Int = {
    val labelId: Int = labelMap(label)
    var labelTotal: Int = 0
    for (col <- 0 until numLabels) {
      labelTotal += confusionMatrix(labelId)(col)
    }
    labelTotal
  }

  /**
   * Standard deviation of normalized producer accuracy
   * Not a standard score
   * @return double
   */
  def getNormalizedStats: RunningAverageAndStdDev = {
    val summer = new FullRunningAverageAndStdDev()
    for (d <- 0 until confusionMatrix.length) {
      var total: Double = 0.0
      for (j <- 0 until confusionMatrix.length) {
        total += confusionMatrix(d)(j)
      }
      // The small constant keeps an all-zero row from dividing by zero.
      summer.addDatum(confusionMatrix(d)(d) / (total + 0.000001))
    }
    summer
  }

  /** Records one classified instance, taking the classified label from `classifiedResult`. */
  def addInstance(correctLabel: String, classifiedResult: ClassifierResult): Unit = {
    samples += 1
    incrementCount(correctLabel, classifiedResult.getLabel)
  }

  /** Records one classified instance. */
  def addInstance(correctLabel: String, classifiedLabel: String): Unit = {
    samples += 1
    incrementCount(correctLabel, classifiedLabel)
  }

  /**
   * @return the count stored for (correctLabel, classifiedLabel); 0 when `correctLabel` is unknown
   * @throws AssertionError if `correctLabel` is known but `classifiedLabel` is not
   */
  def getCount(correctLabel: String, classifiedLabel: String): Int = {
    if (!labelMap.contains(correctLabel)) {
      // LOG.warn("Label {} did not appear in the training examples", correctLabel)
      0
    } else {
      assert(labelMap.contains(classifiedLabel), "Label not found: " + classifiedLabel)
      val correctId: Int = labelMap(correctLabel)
      val classifiedId: Int = labelMap(classifiedLabel)
      confusionMatrix(correctId)(classifiedId)
    }
  }

  /**
   * Overwrites the count stored for (correctLabel, classifiedLabel). Does nothing when
   * `correctLabel` is unknown.
   *
   * @throws AssertionError if `correctLabel` is known but `classifiedLabel` is not
   */
  def putCount(correctLabel: String, classifiedLabel: String, count: Int): Unit = {
    if (labelMap.contains(correctLabel)) {
      assert(labelMap.contains(classifiedLabel), "Label not found: " + classifiedLabel)
      val correctId: Int = labelMap(correctLabel)
      val classifiedId: Int = labelMap(classifiedLabel)
      // A cell turning non-zero bumps `samples` once, regardless of `count`.
      if (confusionMatrix(correctId)(classifiedId) == 0 && count != 0) {
        samples += 1
      }
      confusionMatrix(correctId)(classifiedId) = count
    }
    // else: LOG.warn("Label {} did not appear in the training examples", correctLabel)
  }

  /** Adds `count` to the cell for (correctLabel, classifiedLabel). */
  def incrementCount(correctLabel: String, classifiedLabel: String, count: Int): Unit = {
    putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel))
  }

  /** Adds 1 to the cell for (correctLabel, classifiedLabel). */
  def incrementCount(correctLabel: String, classifiedLabel: String): Unit = {
    incrementCount(correctLabel, classifiedLabel, 1)
  }

  /** @return the label used for unclassified instances */
  def getDefaultLabel: String = defaultLabel

  /**
   * Adds the counts of `b` into this matrix, cell by cell, for every pair of this matrix's labels.
   *
   * @return this instance
   * @throws AssertionError if the two matrices have a different number of labels
   */
  def merge(b: ConfusionMatrix): ConfusionMatrix = {
    assert(labelMap.size == b.getLabels.size, "The label sizes do not match")
    for (correctLabel <- this.labelMap.keys; classifiedLabel <- this.labelMap.keys) {
      incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel))
    }
    this
  }

  /**
   * Copies the counts into a new DenseMatrix whose row and column label bindings are the
   * entries of the label map.
   */
  def getMatrix: Matrix = {
    val length: Int = confusionMatrix.length
    val m: Matrix = new DenseMatrix(length, length)

    for (r <- 0 until length; c <- 0 until length) {
      m.set(r, c, confusionMatrix(r)(c))
    }

    val bindings: java.util.HashMap[String, Integer] = new java.util.HashMap[String, Integer]()
    for ((label, index) <- labelMap) {
      bindings.put(label, index)
    }
    m.setRowLabelBindings(bindings)
    m.setColumnLabelBindings(bindings)

    m
  }

  /**
   * Overwrites the counts with the rounded cells of `m` and, when `m` carries row (or, failing
   * that, column) label bindings, replaces the label map with them.
   *
   * NOTE(review): only squareness is checked; `m` is read over this matrix's current
   * dimension, so a differently sized `m` is not rejected up front.
   *
   * @throws IllegalArgumentException if `m` is not square
   */
  def setMatrix(m: Matrix) : Unit = {
    val length: Int = confusionMatrix.length
    if (m.numRows != m.numCols) {
      throw new IllegalArgumentException(
        "ConfusionMatrix: matrix(" + m.numRows + ',' + m.numCols + ") must be square")
    }

    for (r <- 0 until length; c <- 0 until length) {
      confusionMatrix(r)(c) = Math.round(m.get(r, c)).toInt
    }

    // Prefer row bindings, fall back to column bindings; either may be null.
    val bindings = Option(m.getRowLabelBindings).orElse(Option(m.getColumnLabelBindings))

    bindings.foreach { labelBindings =>
      val sorted: Array[String] = sortLabels(labelBindings)
      verifyLabels(length, sorted)
      labelMap.clear()
      for (idx <- 0 until length) {
        labelMap.put(sorted(idx), idx)
      }
    }
  }

  /**
   * Asserts that `sorted` holds exactly `length` non-null labels.
   *
   * @throws AssertionError otherwise
   */
  def verifyLabels(length: Int, sorted: Array[String]): Unit = {
    assert(sorted.length == length, "One label, one row")
    for (idx <- 0 until length) {
      assert(sorted(idx) != null, "One label, one row")
    }
  }

  /**
   * Arranges the labels of a label-to-index map into an array, each label at its own index.
   * Indices that no label maps to are left null.
   */
  def sortLabels(labels: java.util.Map[String, Integer]): Array[String] = {
    val sorted: Array[String] = new Array[String](labels.size)
    for ((label, index) <- labels.asScala) {
      sorted(index.intValue) = label
    }

    sorted
  }

  /**
   * This is overloaded. toString() is not a formatted report you print for a manager :)
   * Assume that if there are no default assignments, the default feature was not used
   */
  override def toString: String = {

    val returnString: StringBuilder = new StringBuilder(200)

    returnString.append("=======================================================").append('\n')
    returnString.append("Confusion Matrix\n")
    returnString.append("-------------------------------------------------------").append('\n')

    val unclassified: Int = getTotal(defaultLabel)

    // The default label is left out of the report when nothing was assigned to it.
    def hidden(label: String): Boolean = (label == defaultLabel) && unclassified == 0

    // Header row: one short label per visible column.
    for ((label, index) <- this.labelMap) {
      if (!hidden(label)) {
        returnString.append(getSmallLabel(index) + " ").append('\t')
      }
    }

    returnString.append("<--Classified as").append('\n')

    // One row per visible correct label: counts, row total, short label and full label.
    for ((correctLabel, index) <- this.labelMap) {
      if (!hidden(correctLabel)) {
        var labelTotal: Int = 0

        for (classifiedLabel <- this.labelMap.keySet) {
          if (!hidden(classifiedLabel)) {
            returnString.append(Integer.toString(getCount(correctLabel, classifiedLabel)) + " ")
              .append('\t')
            labelTotal += getCount(correctLabel, classifiedLabel)
          }
        }
        returnString.append(" | ").append(String.valueOf(labelTotal) + " ")
          .append('\t')
          .append(getSmallLabel(index) + " ")
          .append(" = ")
          .append(correctLabel)
          .append('\n')
      }
    }

    if (unclassified > 0) {
      returnString.append("Default Category: ")
        .append(defaultLabel)
        .append(": ")
        .append(unclassified)
        .append('\n')
    }
    returnString.append('\n')

    returnString.toString()
  }


  /**
   * Renders an index in base 26 using the letters a-z, most significant digit first
   * (0 -> "a", 25 -> "z", 26 -> "ba").
   */
  def getSmallLabel(i: Int): String = {
    var value: Int = i
    val returnString: StringBuilder = new StringBuilder
    do {
      val n: Int = value % 26
      returnString.insert(0, ('a' + n).toChar)
      value /= 26
    } while (value > 0)

    returnString.toString()
  }


}
