http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/engine/CategoricalNaiveBayesTest.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/io/prediction/e2/engine/CategoricalNaiveBayesTest.scala b/e2/src/test/scala/io/prediction/e2/engine/CategoricalNaiveBayesTest.scala deleted file mode 100644 index 2e3eadd..0000000 --- a/e2/src/test/scala/io/prediction/e2/engine/CategoricalNaiveBayesTest.scala +++ /dev/null @@ -1,132 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prediction.e2.engine - -import io.prediction.e2.fixture.{NaiveBayesFixture, SharedSparkContext} -import org.scalatest.{Matchers, FlatSpec} - -import scala.language.reflectiveCalls - -class CategoricalNaiveBayesTest extends FlatSpec with Matchers -with SharedSparkContext with NaiveBayesFixture { - val Tolerance = .0001 - val labeledPoints = fruit.labeledPoints - - "Model" should "have log priors and log likelihoods" in { - val labeledPointsRdd = sc.parallelize(labeledPoints) - val model = CategoricalNaiveBayes.train(labeledPointsRdd) - - model.priors(fruit.Banana) should be(-.7885 +- Tolerance) - model.priors(fruit.Orange) should be(-1.7047 +- Tolerance) - model.priors(fruit.OtherFruit) should be(-1.0116 +- Tolerance) - - model.likelihoods(fruit.Banana)(0)(fruit.Long) should - be(-.2231 +- Tolerance) - model.likelihoods(fruit.Banana)(0)(fruit.NotLong) should - be(-1.6094 +- Tolerance) - model.likelihoods(fruit.Banana)(1)(fruit.Sweet) should - be(-.2231 +- Tolerance) - model.likelihoods(fruit.Banana)(1)(fruit.NotSweet) should - be(-1.6094 +- Tolerance) - model.likelihoods(fruit.Banana)(2)(fruit.Yellow) should - be(-.2231 +- Tolerance) - model.likelihoods(fruit.Banana)(2)(fruit.NotYellow) should - be(-1.6094 +- Tolerance) - - model.likelihoods(fruit.Orange)(0) should not contain key(fruit.Long) - model.likelihoods(fruit.Orange)(0)(fruit.NotLong) should be(0.0) - model.likelihoods(fruit.Orange)(1)(fruit.Sweet) should - be(-.6931 +- Tolerance) - model.likelihoods(fruit.Orange)(1)(fruit.NotSweet) should - be(-.6931 +- Tolerance) - model.likelihoods(fruit.Orange)(2)(fruit.NotYellow) should be(0.0) - model.likelihoods(fruit.Orange)(2) should not contain key(fruit.Yellow) - - model.likelihoods(fruit.OtherFruit)(0)(fruit.Long) should - be(-.6931 +- Tolerance) - model.likelihoods(fruit.OtherFruit)(0)(fruit.NotLong) should - be(-.6931 +- Tolerance) - model.likelihoods(fruit.OtherFruit)(1)(fruit.Sweet) should - be(-.2877 +- Tolerance) - model.likelihoods(fruit.OtherFruit)(1)(fruit.NotSweet) should - be(-1.3863 +- Tolerance) - model.likelihoods(fruit.OtherFruit)(2)(fruit.Yellow) should - be(-1.3863 +- Tolerance) - model.likelihoods(fruit.OtherFruit)(2)(fruit.NotYellow) should - be(-.2877 +- Tolerance) - } - - "Model's log score" should "be the log score of the given point" in { - val labeledPointsRdd = sc.parallelize(labeledPoints) - val model = 
CategoricalNaiveBayes.train(labeledPointsRdd) - - val score = model.logScore(LabeledPoint( - fruit.Banana, - Array(fruit.Long, fruit.NotSweet, fruit.NotYellow)) - ) - - score should not be None - score.get should be(-4.2304 +- Tolerance) - } - - it should "be negative infinity for a point with a non-existing feature" in { - val labeledPointsRdd = sc.parallelize(labeledPoints) - val model = CategoricalNaiveBayes.train(labeledPointsRdd) - - val score = model.logScore(LabeledPoint( - fruit.Banana, - Array(fruit.Long, fruit.NotSweet, "Not Exist")) - ) - - score should not be None - score.get should be(Double.NegativeInfinity) - } - - it should "be none for a point with a non-existing label" in { - val labeledPointsRdd = sc.parallelize(labeledPoints) - val model = CategoricalNaiveBayes.train(labeledPointsRdd) - - val score = model.logScore(LabeledPoint( - "Not Exist", - Array(fruit.Long, fruit.NotSweet, fruit.Yellow)) - ) - - score should be(None) - } - - it should "use the provided default likelihood function" in { - val labeledPointsRdd = sc.parallelize(labeledPoints) - val model = CategoricalNaiveBayes.train(labeledPointsRdd) - - val score = model.logScore( - LabeledPoint( - fruit.Banana, - Array(fruit.Long, fruit.NotSweet, "Not Exist") - ), - ls => ls.min - math.log(2) - ) - - score should not be None - score.get should be(-4.9236 +- Tolerance) - } - - "Model predict" should "return the correct label" in { - val labeledPointsRdd = sc.parallelize(labeledPoints) - val model = CategoricalNaiveBayes.train(labeledPointsRdd) - - val label = model.predict(Array(fruit.Long, fruit.Sweet, fruit.Yellow)) - label should be(fruit.Banana) - } -}
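For reference, the hard-coded expectations in the test above (and in the MarkovChainTest that follows) can be reproduced by hand from the NaiveBayesFixture and MarkovChainFixture data shown later in this change. The minimal sketch below is plain Scala with no PredictionIO or Spark dependency; the object name FixtureArithmetic is invented here purely for illustration.

object FixtureArithmetic extends App {
  // Log priors: the fruit fixture contains 5 Banana, 2 Orange and
  // 4 "Other Fruit" labeled points out of 11 in total.
  Seq("Banana" -> 5.0, "Orange" -> 2.0, "Other Fruit" -> 4.0).foreach {
    case (label, count) => println(f"$label: ${math.log(count / 11)}%.4f")
  } // Banana: -0.7885, Orange: -1.7047, Other Fruit: -1.0116

  // Log likelihoods: 4 of the 5 Banana points are Long, 1 is Not Long.
  println(f"${math.log(4.0 / 5)}%.4f") // -0.2231
  println(f"${math.log(1.0 / 5)}%.4f") // -1.6094

  // Markov chain: row 0 of twoByTwoMatrix is (3, 7), which normalizes to
  // (0.3, 0.7); row 1 is (10, 10), giving (0.5, 0.5). Predicting from the
  // state distribution (0.4, 0.6) is a vector-matrix product.
  val p = Array(Array(0.3, 0.7), Array(0.5, 0.5))
  val next = p.indices.map(j => 0.4 * p(0)(j) + 0.6 * p(1)(j))
  println(next) // approximately Vector(0.42, 0.58)
}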
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/engine/MarkovChainTest.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/io/prediction/e2/engine/MarkovChainTest.scala b/e2/src/test/scala/io/prediction/e2/engine/MarkovChainTest.scala deleted file mode 100644 index a33a30a..0000000 --- a/e2/src/test/scala/io/prediction/e2/engine/MarkovChainTest.scala +++ /dev/null @@ -1,48 +0,0 @@ -package io.prediction.e2.engine - -import io.prediction.e2.fixture.{MarkovChainFixture, SharedSparkContext} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix -import org.scalatest.{FlatSpec, Matchers} - -import scala.language.reflectiveCalls - -class MarkovChainTest extends FlatSpec with Matchers with SharedSparkContext -with MarkovChainFixture { - - "Markov chain training" should "produce a model" in { - val matrix = - new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries)) - val model = MarkovChain.train(matrix, 2) - - model.n should be(2) - model.transitionVectors.collect() should contain theSameElementsAs Seq( - (0, Vectors.sparse(2, Array(0, 1), Array(0.3, 0.7))), - (1, Vectors.sparse(2, Array(0, 1), Array(0.5, 0.5))) - ) - } - - it should "contain probabilities of the top N only" in { - val matrix = - new CoordinateMatrix(sc.parallelize(fiveByFiveMatrix.matrixEntries)) - val model = MarkovChain.train(matrix, 2) - - model.n should be(2) - model.transitionVectors.collect() should contain theSameElementsAs Seq( - (0, Vectors.sparse(5, Array(1, 2), Array(.6, .4))), - (1, Vectors.sparse(5, Array(2, 4), Array(9.0 / 25, 8.0 / 25))), - (2, Vectors.sparse(5, Array(1, 4), Array(10.0 / 28, 10.0 / 28))), - (3, Vectors.sparse(5, Array(3, 4), Array(3.0 / 9, 4.0 / 9))), - (4, Vectors.sparse(5, Array(3, 4), Array(8.0 / 25, 0.4))) - ) - } - - "Model predict" should "calculate the probabilities of new states" in { - val matrix = - new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries)) - val model = MarkovChain.train(matrix, 2) - val nextState = model.predict(Seq(0.4, 0.6)) - - nextState should contain theSameElementsInOrderAs Seq(0.42, 0.58) - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/evaluation/CrossValidationTest.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/io/prediction/e2/evaluation/CrossValidationTest.scala b/e2/src/test/scala/io/prediction/e2/evaluation/CrossValidationTest.scala deleted file mode 100644 index ead51b2..0000000 --- a/e2/src/test/scala/io/prediction/e2/evaluation/CrossValidationTest.scala +++ /dev/null @@ -1,111 +0,0 @@ -package io.prediction.e2.evaluation - -import org.scalatest.{Matchers, Inspectors, FlatSpec} -import org.apache.spark.rdd.RDD -import io.prediction.e2.fixture.SharedSparkContext -import io.prediction.e2.engine.LabeledPoint - -object CrossValidationTest { - case class TrainingData(labeledPoints: Seq[LabeledPoint]) - case class Query(features: Array[String]) - case class ActualResult(label: String) - - case class EmptyEvaluationParams() - - def toTrainingData(labeledPoints: RDD[LabeledPoint]) = TrainingData(labeledPoints.collect().toSeq) - def toQuery(labeledPoint: LabeledPoint) = Query(labeledPoint.features) - def toActualResult(labeledPoint: LabeledPoint) = ActualResult(labeledPoint.label) - -} - -
-class CrossValidationTest extends FlatSpec with Matchers with Inspectors -with SharedSparkContext{ - - - val Label1 = "l1" - val Label2 = "l2" - val Label3 = "l3" - val Label4 = "l4" - val Attribute1 = "a1" - val NotAttribute1 = "na1" - val Attribute2 = "a2" - val NotAttribute2 = "na2" - - val labeledPoints = Seq( - LabeledPoint(Label1, Array(Attribute1, Attribute2)), - LabeledPoint(Label2, Array(NotAttribute1, Attribute2)), - LabeledPoint(Label3, Array(Attribute1, NotAttribute2)), - LabeledPoint(Label4, Array(NotAttribute1, NotAttribute2)) - ) - - val dataCount = labeledPoints.size - val evalKs = (1 to dataCount) - val emptyParams = new CrossValidationTest.EmptyEvaluationParams() - type Fold = ( - CrossValidationTest.TrainingData, - CrossValidationTest.EmptyEvaluationParams, - RDD[(CrossValidationTest.Query, CrossValidationTest.ActualResult)]) - - def toTestTrain(dataSplit: Fold): (Seq[LabeledPoint], Seq[LabeledPoint]) = { - val trainingData = dataSplit._1.labeledPoints - val queryActual = dataSplit._3 - val testingData = queryActual.map { case (query, actual) => - LabeledPoint(actual.label, query.features) - } - (trainingData, testingData.collect().toSeq) - } - - def splitData(k: Int, labeledPointsRDD: RDD[LabeledPoint]): Seq[Fold] = { - CommonHelperFunctions.splitData[ - LabeledPoint, - CrossValidationTest.TrainingData, - CrossValidationTest.EmptyEvaluationParams, - CrossValidationTest.Query, - CrossValidationTest.ActualResult]( - k, - labeledPointsRDD, - emptyParams, - CrossValidationTest.toTrainingData, - CrossValidationTest.toQuery, - CrossValidationTest.toActualResult) - } - - "Fold count" should "equal evalK" in { - val labeledPointsRDD = sc.parallelize(labeledPoints) - val lengths = evalKs.map(k => splitData(k, labeledPointsRDD).length) - lengths should be(evalKs) - } - - "Testing data size" should "be within 1 of total / evalK" in { - val labeledPointsRDD = sc.parallelize(labeledPoints) - val splits = evalKs.map(k => k -> splitData(k, labeledPointsRDD)) - val diffs = splits.map { case (k, folds) => - folds.map(fold => fold._3.count() - dataCount / k) - } - forAll(diffs) {foldDiffs => foldDiffs.max should be <= 1L} - diffs.map(folds => folds.sum) should be(evalKs.map(k => dataCount % k)) - } - - "Training + testing" should "equal original dataset" in { - val labeledPointsRDD = sc.parallelize(labeledPoints) - forAll(evalKs) {k => - val split = splitData(k, labeledPointsRDD) - forAll(split) {fold => - val(training, testing) = toTestTrain(fold) - (training ++ testing).toSet should be(labeledPoints.toSet) - } - } - } - - "Training and testing" should "be disjoint" in { - val labeledPointsRDD = sc.parallelize(labeledPoints) - forAll(evalKs) { k => - val split = splitData(k, labeledPointsRDD) - forAll(split) { fold => - val (training, testing) = toTestTrain(fold) - training.toSet.intersect(testing.toSet) should be('empty) - } - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/fixture/BinaryVectorizerFixture.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/io/prediction/e2/fixture/BinaryVectorizerFixture.scala b/e2/src/test/scala/io/prediction/e2/fixture/BinaryVectorizerFixture.scala deleted file mode 100644 index 56ebbd8..0000000 --- a/e2/src/test/scala/io/prediction/e2/fixture/BinaryVectorizerFixture.scala +++ /dev/null @@ -1,59 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.prediction.e2.fixture - -import scala.collection.immutable.HashMap -import scala.collection.immutable.HashSet -import org.apache.spark.mllib.linalg.Vector - -trait BinaryVectorizerFixture { - - def base = { - new { - val maps : Seq[HashMap[String, String]] = Seq( - HashMap("food" -> "orange", "music" -> "rock", "hobby" -> "scala"), - HashMap("food" -> "orange", "music" -> "pop", "hobby" ->"running"), - HashMap("food" -> "banana", "music" -> "rock", "hobby" -> "guitar"), - HashMap("food" -> "banana", "music" -> "rock", "hobby" -> "guitar") - ) - - val properties = HashSet("food", "hobby") - } - } - - - def testArrays = { - new { - // Test case for checking food value not listed in base.maps, and - // property not in properties. - val one = Array(("food", "burger"), ("music", "rock"), ("hobby", "scala")) - - // Test case for making sure indices are preserved. - val twoA = Array(("food", "orange"), ("hobby", "scala")) - val twoB = Array(("food", "banana"), ("hobby", "scala")) - val twoC = Array(("hobby", "guitar")) - } - } - - def vecSum (vec1 : Vector, vec2 : Vector) : Array[Double] = { - (0 until vec1.size).map( - k => vec1(k) + vec2(k) - ).toArray - } - -} - - http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/fixture/MarkovChainFixture.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/io/prediction/e2/fixture/MarkovChainFixture.scala b/e2/src/test/scala/io/prediction/e2/fixture/MarkovChainFixture.scala deleted file mode 100644 index e47d49e..0000000 --- a/e2/src/test/scala/io/prediction/e2/fixture/MarkovChainFixture.scala +++ /dev/null @@ -1,39 +0,0 @@ -package io.prediction.e2.fixture - -import org.apache.spark.mllib.linalg.distributed.MatrixEntry - -trait MarkovChainFixture { - def twoByTwoMatrix = { - new { - val matrixEntries = Seq( - MatrixEntry(0, 0, 3), - MatrixEntry(0, 1, 7), - MatrixEntry(1, 0, 10), - MatrixEntry(1, 1, 10) - ) - } - } - - def fiveByFiveMatrix = { - new { - val matrixEntries = Seq( - MatrixEntry(0, 1, 12), - MatrixEntry(0, 2, 8), - MatrixEntry(1, 0, 3), - MatrixEntry(1, 1, 3), - MatrixEntry(1, 2, 9), - MatrixEntry(1, 3, 2), - MatrixEntry(1, 4, 8), - MatrixEntry(2, 1, 10), - MatrixEntry(2, 2, 8), - MatrixEntry(2, 4, 10), - MatrixEntry(3, 0, 2), - MatrixEntry(3, 3, 3), - MatrixEntry(3, 4, 4), - MatrixEntry(4, 1, 7), - MatrixEntry(4, 3, 8), - MatrixEntry(4, 4, 10) - ) - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/fixture/NaiveBayesFixture.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/io/prediction/e2/fixture/NaiveBayesFixture.scala b/e2/src/test/scala/io/prediction/e2/fixture/NaiveBayesFixture.scala deleted file mode 100644 index 97dd663..0000000 --- a/e2/src/test/scala/io/prediction/e2/fixture/NaiveBayesFixture.scala +++ /dev/null @@ -1,48 
+0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prediction.e2.fixture - -import io.prediction.e2.engine.LabeledPoint - -trait NaiveBayesFixture { - - def fruit = { - new { - val Banana = "Banana" - val Orange = "Orange" - val OtherFruit = "Other Fruit" - val NotLong = "Not Long" - val Long = "Long" - val NotSweet = "Not Sweet" - val Sweet = "Sweet" - val NotYellow = "Not Yellow" - val Yellow = "Yellow" - - val labeledPoints = Seq( - LabeledPoint(Banana, Array(Long, Sweet, Yellow)), - LabeledPoint(Banana, Array(Long, Sweet, Yellow)), - LabeledPoint(Banana, Array(Long, Sweet, Yellow)), - LabeledPoint(Banana, Array(Long, Sweet, Yellow)), - LabeledPoint(Banana, Array(NotLong, NotSweet, NotYellow)), - LabeledPoint(Orange, Array(NotLong, Sweet, NotYellow)), - LabeledPoint(Orange, Array(NotLong, NotSweet, NotYellow)), - LabeledPoint(OtherFruit, Array(Long, Sweet, NotYellow)), - LabeledPoint(OtherFruit, Array(NotLong, Sweet, NotYellow)), - LabeledPoint(OtherFruit, Array(Long, Sweet, Yellow)), - LabeledPoint(OtherFruit, Array(NotLong, NotSweet, NotYellow)) - ) - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/fixture/SharedSparkContext.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/io/prediction/e2/fixture/SharedSparkContext.scala b/e2/src/test/scala/io/prediction/e2/fixture/SharedSparkContext.scala deleted file mode 100644 index 74dd814..0000000 --- a/e2/src/test/scala/io/prediction/e2/fixture/SharedSparkContext.scala +++ /dev/null @@ -1,51 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.prediction.e2.fixture - -import org.apache.spark.{SparkConf, SparkContext} -import org.scalatest.{BeforeAndAfterAll, Suite} - -trait SharedSparkContext extends BeforeAndAfterAll { - self: Suite => - @transient private var _sc: SparkContext = _ - - def sc: SparkContext = _sc - - var conf = new SparkConf(false) - - override def beforeAll() { - _sc = new SparkContext("local", "test", conf) - super.beforeAll() - } - - override def afterAll() { - LocalSparkContext.stop(_sc) - - _sc = null - super.afterAll() - } -} - -object LocalSparkContext { - def stop(sc: SparkContext) { - if (sc != null) { - sc.stop() - } - // To avoid Akka rebinding to the same port, since it doesn't unbind - // immediately on shutdown - System.clearProperty("spark.driver.port") - } -} - http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/engine/BinaryVectorizerTest.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/org/apache/predictionio/e2/engine/BinaryVectorizerTest.scala b/e2/src/test/scala/org/apache/predictionio/e2/engine/BinaryVectorizerTest.scala new file mode 100644 index 0000000..576b8c6 --- /dev/null +++ b/e2/src/test/scala/org/apache/predictionio/e2/engine/BinaryVectorizerTest.scala @@ -0,0 +1,56 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.e2.engine + +import org.apache.predictionio.e2.fixture.BinaryVectorizerFixture +import org.apache.predictionio.e2.fixture.SharedSparkContext +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.rdd.RDD +import org.scalatest.FlatSpec +import org.scalatest.Matchers +import scala.collection.immutable.HashMap + + +import scala.language.reflectiveCalls + +class BinaryVectorizerTest extends FlatSpec with Matchers with SharedSparkContext +with BinaryVectorizerFixture{ + + "toBinary" should "produce the following summed values:" in { + val testCase = BinaryVectorizer(sc.parallelize(base.maps), base.properties) + val vectorTwoA = testCase.toBinary(testArrays.twoA) + val vectorTwoB = testCase.toBinary(testArrays.twoB) + + + // Make sure vectors produced are the same size. + vectorTwoA.size should be (vectorTwoB.size) + + // Test case for checking food value not listed in base.maps. + testCase.toBinary(testArrays.one).toArray.sum should be (1.0) + + // Test cases for making sure indices are preserved. 
+ val sumOne = vecSum(vectorTwoA, vectorTwoB) + + exactly (1, sumOne) should be (2.0) + exactly (2,sumOne) should be (0.0) + exactly (2, sumOne) should be (1.0) + + val sumTwo = vecSum(Vectors.dense(sumOne), testCase.toBinary(testArrays.twoC)) + + exactly (3, sumTwo) should be (1.0) + } + +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayesTest.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayesTest.scala b/e2/src/test/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayesTest.scala new file mode 100644 index 0000000..4373d7d --- /dev/null +++ b/e2/src/test/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayesTest.scala @@ -0,0 +1,132 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.predictionio.e2.engine + +import org.apache.predictionio.e2.fixture.{NaiveBayesFixture, SharedSparkContext} +import org.scalatest.{Matchers, FlatSpec} + +import scala.language.reflectiveCalls + +class CategoricalNaiveBayesTest extends FlatSpec with Matchers +with SharedSparkContext with NaiveBayesFixture { + val Tolerance = .0001 + val labeledPoints = fruit.labeledPoints + + "Model" should "have log priors and log likelihoods" in { + val labeledPointsRdd = sc.parallelize(labeledPoints) + val model = CategoricalNaiveBayes.train(labeledPointsRdd) + + model.priors(fruit.Banana) should be(-.7885 +- Tolerance) + model.priors(fruit.Orange) should be(-1.7047 +- Tolerance) + model.priors(fruit.OtherFruit) should be(-1.0116 +- Tolerance) + + model.likelihoods(fruit.Banana)(0)(fruit.Long) should + be(-.2231 +- Tolerance) + model.likelihoods(fruit.Banana)(0)(fruit.NotLong) should + be(-1.6094 +- Tolerance) + model.likelihoods(fruit.Banana)(1)(fruit.Sweet) should + be(-.2231 +- Tolerance) + model.likelihoods(fruit.Banana)(1)(fruit.NotSweet) should + be(-1.6094 +- Tolerance) + model.likelihoods(fruit.Banana)(2)(fruit.Yellow) should + be(-.2231 +- Tolerance) + model.likelihoods(fruit.Banana)(2)(fruit.NotYellow) should + be(-1.6094 +- Tolerance) + + model.likelihoods(fruit.Orange)(0) should not contain key(fruit.Long) + model.likelihoods(fruit.Orange)(0)(fruit.NotLong) should be(0.0) + model.likelihoods(fruit.Orange)(1)(fruit.Sweet) should + be(-.6931 +- Tolerance) + model.likelihoods(fruit.Orange)(1)(fruit.NotSweet) should + be(-.6931 +- Tolerance) + model.likelihoods(fruit.Orange)(2)(fruit.NotYellow) should be(0.0) + model.likelihoods(fruit.Orange)(2) should not contain key(fruit.Yellow) + + model.likelihoods(fruit.OtherFruit)(0)(fruit.Long) should + be(-.6931 +- Tolerance) + model.likelihoods(fruit.OtherFruit)(0)(fruit.NotLong) should + be(-.6931 +- Tolerance) + model.likelihoods(fruit.OtherFruit)(1)(fruit.Sweet) should + be(-.2877 +- Tolerance) + model.likelihoods(fruit.OtherFruit)(1)(fruit.NotSweet) should + be(-1.3863 +- 
Tolerance) + model.likelihoods(fruit.OtherFruit)(2)(fruit.Yellow) should + be(-1.3863 +- Tolerance) + model.likelihoods(fruit.OtherFruit)(2)(fruit.NotYellow) should + be(-.2877 +- Tolerance) + } + + "Model's log score" should "be the log score of the given point" in { + val labeledPointsRdd = sc.parallelize(labeledPoints) + val model = CategoricalNaiveBayes.train(labeledPointsRdd) + + val score = model.logScore(LabeledPoint( + fruit.Banana, + Array(fruit.Long, fruit.NotSweet, fruit.NotYellow)) + ) + + score should not be None + score.get should be(-4.2304 +- Tolerance) + } + + it should "be negative infinity for a point with a non-existing feature" in { + val labeledPointsRdd = sc.parallelize(labeledPoints) + val model = CategoricalNaiveBayes.train(labeledPointsRdd) + + val score = model.logScore(LabeledPoint( + fruit.Banana, + Array(fruit.Long, fruit.NotSweet, "Not Exist")) + ) + + score should not be None + score.get should be(Double.NegativeInfinity) + } + + it should "be none for a point with a non-existing label" in { + val labeledPointsRdd = sc.parallelize(labeledPoints) + val model = CategoricalNaiveBayes.train(labeledPointsRdd) + + val score = model.logScore(LabeledPoint( + "Not Exist", + Array(fruit.Long, fruit.NotSweet, fruit.Yellow)) + ) + + score should be(None) + } + + it should "use the provided default likelihood function" in { + val labeledPointsRdd = sc.parallelize(labeledPoints) + val model = CategoricalNaiveBayes.train(labeledPointsRdd) + + val score = model.logScore( + LabeledPoint( + fruit.Banana, + Array(fruit.Long, fruit.NotSweet, "Not Exist") + ), + ls => ls.min - math.log(2) + ) + + score should not be None + score.get should be(-4.9236 +- Tolerance) + } + + "Model predict" should "return the correct label" in { + val labeledPointsRdd = sc.parallelize(labeledPoints) + val model = CategoricalNaiveBayes.train(labeledPointsRdd) + + val label = model.predict(Array(fruit.Long, fruit.Sweet, fruit.Yellow)) + label should be(fruit.Banana) + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/engine/MarkovChainTest.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/org/apache/predictionio/e2/engine/MarkovChainTest.scala b/e2/src/test/scala/org/apache/predictionio/e2/engine/MarkovChainTest.scala new file mode 100644 index 0000000..137095a --- /dev/null +++ b/e2/src/test/scala/org/apache/predictionio/e2/engine/MarkovChainTest.scala @@ -0,0 +1,48 @@ +package org.apache.predictionio.e2.engine + +import org.apache.predictionio.e2.fixture.{MarkovChainFixture, SharedSparkContext} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix +import org.scalatest.{FlatSpec, Matchers} + +import scala.language.reflectiveCalls + +class MarkovChainTest extends FlatSpec with Matchers with SharedSparkContext +with MarkovChainFixture { + + "Markov chain training" should "produce a model" in { + val matrix = + new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries)) + val model = MarkovChain.train(matrix, 2) + + model.n should be(2) + model.transitionVectors.collect() should contain theSameElementsAs Seq( + (0, Vectors.sparse(2, Array(0, 1), Array(0.3, 0.7))), + (1, Vectors.sparse(2, Array(0, 1), Array(0.5, 0.5))) + ) + } + + it should "contain probabilities of the top N only" in { + val matrix = + new CoordinateMatrix(sc.parallelize(fiveByFiveMatrix.matrixEntries)) + val model = 
MarkovChain.train(matrix, 2) + + model.n should be(2) + model.transitionVectors.collect() should contain theSameElementsAs Seq( + (0, Vectors.sparse(5, Array(1, 2), Array(.6, .4))), + (1, Vectors.sparse(5, Array(2, 4), Array(9.0 / 25, 8.0 / 25))), + (2, Vectors.sparse(5, Array(1, 4), Array(10.0 / 28, 10.0 / 28))), + (3, Vectors.sparse(5, Array(3, 4), Array(3.0 / 9, 4.0 / 9))), + (4, Vectors.sparse(5, Array(3, 4), Array(8.0 / 25, 0.4))) + ) + } + + "Model predict" should "calculate the probabilities of new states" in { + val matrix = + new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries)) + val model = MarkovChain.train(matrix, 2) + val nextState = model.predict(Seq(0.4, 0.6)) + + nextState should contain theSameElementsInOrderAs Seq(0.42, 0.58) + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/evaluation/CrossValidationTest.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/org/apache/predictionio/e2/evaluation/CrossValidationTest.scala b/e2/src/test/scala/org/apache/predictionio/e2/evaluation/CrossValidationTest.scala new file mode 100644 index 0000000..d15b927 --- /dev/null +++ b/e2/src/test/scala/org/apache/predictionio/e2/evaluation/CrossValidationTest.scala @@ -0,0 +1,111 @@ +package org.apache.predictionio.e2.evaluation + +import org.scalatest.{Matchers, Inspectors, FlatSpec} +import org.apache.spark.rdd.RDD +import org.apache.predictionio.e2.fixture.SharedSparkContext +import org.apache.predictionio.e2.engine.LabeledPoint + +object CrossValidationTest { + case class TrainingData(labeledPoints: Seq[LabeledPoint]) + case class Query(features: Array[String]) + case class ActualResult(label: String) + + case class EmptyEvaluationParams() + + def toTrainingData(labeledPoints: RDD[LabeledPoint]) = TrainingData(labeledPoints.collect().toSeq) + def toQuery(labeledPoint: LabeledPoint) = Query(labeledPoint.features) + def toActualResult(labeledPoint: LabeledPoint) = ActualResult(labeledPoint.label) + +} + + +class CrossValidationTest extends FlatSpec with Matchers with Inspectors +with SharedSparkContext{ + + + val Label1 = "l1" + val Label2 = "l2" + val Label3 = "l3" + val Label4 = "l4" + val Attribute1 = "a1" + val NotAttribute1 = "na1" + val Attribute2 = "a2" + val NotAttribute2 = "na2" + + val labeledPoints = Seq( + LabeledPoint(Label1, Array(Attribute1, Attribute2)), + LabeledPoint(Label2, Array(NotAttribute1, Attribute2)), + LabeledPoint(Label3, Array(Attribute1, NotAttribute2)), + LabeledPoint(Label4, Array(NotAttribute1, NotAttribute2)) + ) + + val dataCount = labeledPoints.size + val evalKs = (1 to dataCount) + val emptyParams = new CrossValidationTest.EmptyEvaluationParams() + type Fold = ( + CrossValidationTest.TrainingData, + CrossValidationTest.EmptyEvaluationParams, + RDD[(CrossValidationTest.Query, CrossValidationTest.ActualResult)]) + + def toTestTrain(dataSplit: Fold): (Seq[LabeledPoint], Seq[LabeledPoint]) = { + val trainingData = dataSplit._1.labeledPoints + val queryActual = dataSplit._3 + val testingData = queryActual.map { case (query, actual) => + LabeledPoint(actual.label, query.features) + } + (trainingData, testingData.collect().toSeq) + } + + def splitData(k: Int, labeledPointsRDD: RDD[LabeledPoint]): Seq[Fold] = { + CommonHelperFunctions.splitData[ + LabeledPoint, + CrossValidationTest.TrainingData, + CrossValidationTest.EmptyEvaluationParams, 
CrossValidationTest.Query, + CrossValidationTest.ActualResult]( + k, + labeledPointsRDD, + emptyParams, + CrossValidationTest.toTrainingData, + CrossValidationTest.toQuery, + CrossValidationTest.toActualResult) + } + + "Fold count" should "equal evalK" in { + val labeledPointsRDD = sc.parallelize(labeledPoints) + val lengths = evalKs.map(k => splitData(k, labeledPointsRDD).length) + lengths should be(evalKs) + } + + "Testing data size" should "be within 1 of total / evalK" in { + val labeledPointsRDD = sc.parallelize(labeledPoints) + val splits = evalKs.map(k => k -> splitData(k, labeledPointsRDD)) + val diffs = splits.map { case (k, folds) => + folds.map(fold => fold._3.count() - dataCount / k) + } + forAll(diffs) {foldDiffs => foldDiffs.max should be <= 1L} + diffs.map(folds => folds.sum) should be(evalKs.map(k => dataCount % k)) + } + + "Training + testing" should "equal original dataset" in { + val labeledPointsRDD = sc.parallelize(labeledPoints) + forAll(evalKs) {k => + val split = splitData(k, labeledPointsRDD) + forAll(split) {fold => + val(training, testing) = toTestTrain(fold) + (training ++ testing).toSet should be(labeledPoints.toSet) + } + } + } + + "Training and testing" should "be disjoint" in { + val labeledPointsRDD = sc.parallelize(labeledPoints) + forAll(evalKs) { k => + val split = splitData(k, labeledPointsRDD) + forAll(split) { fold => + val (training, testing) = toTestTrain(fold) + training.toSet.intersect(testing.toSet) should be('empty) + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/fixture/BinaryVectorizerFixture.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/org/apache/predictionio/e2/fixture/BinaryVectorizerFixture.scala b/e2/src/test/scala/org/apache/predictionio/e2/fixture/BinaryVectorizerFixture.scala new file mode 100644 index 0000000..76d8db3 --- /dev/null +++ b/e2/src/test/scala/org/apache/predictionio/e2/fixture/BinaryVectorizerFixture.scala @@ -0,0 +1,59 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.e2.fixture + +import scala.collection.immutable.HashMap +import scala.collection.immutable.HashSet +import org.apache.spark.mllib.linalg.Vector + +trait BinaryVectorizerFixture { + + def base = { + new { + val maps : Seq[HashMap[String, String]] = Seq( + HashMap("food" -> "orange", "music" -> "rock", "hobby" -> "scala"), + HashMap("food" -> "orange", "music" -> "pop", "hobby" ->"running"), + HashMap("food" -> "banana", "music" -> "rock", "hobby" -> "guitar"), + HashMap("food" -> "banana", "music" -> "rock", "hobby" -> "guitar") + ) + + val properties = HashSet("food", "hobby") + } + } + + + def testArrays = { + new { + // Test case for checking food value not listed in base.maps, and + // property not in properties. 
+ val one = Array(("food", "burger"), ("music", "rock"), ("hobby", "scala")) + + // Test case for making sure indices are preserved. + val twoA = Array(("food", "orange"), ("hobby", "scala")) + val twoB = Array(("food", "banana"), ("hobby", "scala")) + val twoC = Array(("hobby", "guitar")) + } + } + + def vecSum (vec1 : Vector, vec2 : Vector) : Array[Double] = { + (0 until vec1.size).map( + k => vec1(k) + vec2(k) + ).toArray + } + +} + + http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/fixture/MarkovChainFixture.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/org/apache/predictionio/e2/fixture/MarkovChainFixture.scala b/e2/src/test/scala/org/apache/predictionio/e2/fixture/MarkovChainFixture.scala new file mode 100644 index 0000000..a214be0 --- /dev/null +++ b/e2/src/test/scala/org/apache/predictionio/e2/fixture/MarkovChainFixture.scala @@ -0,0 +1,39 @@ +package org.apache.predictionio.e2.fixture + +import org.apache.spark.mllib.linalg.distributed.MatrixEntry + +trait MarkovChainFixture { + def twoByTwoMatrix = { + new { + val matrixEntries = Seq( + MatrixEntry(0, 0, 3), + MatrixEntry(0, 1, 7), + MatrixEntry(1, 0, 10), + MatrixEntry(1, 1, 10) + ) + } + } + + def fiveByFiveMatrix = { + new { + val matrixEntries = Seq( + MatrixEntry(0, 1, 12), + MatrixEntry(0, 2, 8), + MatrixEntry(1, 0, 3), + MatrixEntry(1, 1, 3), + MatrixEntry(1, 2, 9), + MatrixEntry(1, 3, 2), + MatrixEntry(1, 4, 8), + MatrixEntry(2, 1, 10), + MatrixEntry(2, 2, 8), + MatrixEntry(2, 4, 10), + MatrixEntry(3, 0, 2), + MatrixEntry(3, 3, 3), + MatrixEntry(3, 4, 4), + MatrixEntry(4, 1, 7), + MatrixEntry(4, 3, 8), + MatrixEntry(4, 4, 10) + ) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/fixture/NaiveBayesFixture.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/org/apache/predictionio/e2/fixture/NaiveBayesFixture.scala b/e2/src/test/scala/org/apache/predictionio/e2/fixture/NaiveBayesFixture.scala new file mode 100644 index 0000000..483f366 --- /dev/null +++ b/e2/src/test/scala/org/apache/predictionio/e2/fixture/NaiveBayesFixture.scala @@ -0,0 +1,48 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.predictionio.e2.fixture + +import org.apache.predictionio.e2.engine.LabeledPoint + +trait NaiveBayesFixture { + + def fruit = { + new { + val Banana = "Banana" + val Orange = "Orange" + val OtherFruit = "Other Fruit" + val NotLong = "Not Long" + val Long = "Long" + val NotSweet = "Not Sweet" + val Sweet = "Sweet" + val NotYellow = "Not Yellow" + val Yellow = "Yellow" + + val labeledPoints = Seq( + LabeledPoint(Banana, Array(Long, Sweet, Yellow)), + LabeledPoint(Banana, Array(Long, Sweet, Yellow)), + LabeledPoint(Banana, Array(Long, Sweet, Yellow)), + LabeledPoint(Banana, Array(Long, Sweet, Yellow)), + LabeledPoint(Banana, Array(NotLong, NotSweet, NotYellow)), + LabeledPoint(Orange, Array(NotLong, Sweet, NotYellow)), + LabeledPoint(Orange, Array(NotLong, NotSweet, NotYellow)), + LabeledPoint(OtherFruit, Array(Long, Sweet, NotYellow)), + LabeledPoint(OtherFruit, Array(NotLong, Sweet, NotYellow)), + LabeledPoint(OtherFruit, Array(Long, Sweet, Yellow)), + LabeledPoint(OtherFruit, Array(NotLong, NotSweet, NotYellow)) + ) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/fixture/SharedSparkContext.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/org/apache/predictionio/e2/fixture/SharedSparkContext.scala b/e2/src/test/scala/org/apache/predictionio/e2/fixture/SharedSparkContext.scala new file mode 100644 index 0000000..d0d762e --- /dev/null +++ b/e2/src/test/scala/org/apache/predictionio/e2/fixture/SharedSparkContext.scala @@ -0,0 +1,51 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.predictionio.e2.fixture + +import org.apache.spark.{SparkConf, SparkContext} +import org.scalatest.{BeforeAndAfterAll, Suite} + +trait SharedSparkContext extends BeforeAndAfterAll { + self: Suite => + @transient private var _sc: SparkContext = _ + + def sc: SparkContext = _sc + + var conf = new SparkConf(false) + + override def beforeAll() { + _sc = new SparkContext("local", "test", conf) + super.beforeAll() + } + + override def afterAll() { + LocalSparkContext.stop(_sc) + + _sc = null + super.afterAll() + } +} + +object LocalSparkContext { + def stop(sc: SparkContext) { + if (sc != null) { + sc.stop() + } + // To avoid Akka rebinding to the same port, since it doesn't unbind + // immediately on shutdown + System.clearProperty("spark.driver.port") + } +} + http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/RegisterEngine.scala ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/io/prediction/tools/RegisterEngine.scala b/tools/src/main/scala/io/prediction/tools/RegisterEngine.scala deleted file mode 100644 index 74324c9..0000000 --- a/tools/src/main/scala/io/prediction/tools/RegisterEngine.scala +++ /dev/null @@ -1,84 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.prediction.tools - -import java.io.File - -import grizzled.slf4j.Logging -import io.prediction.data.storage.EngineManifest -import io.prediction.data.storage.EngineManifestSerializer -import io.prediction.data.storage.Storage -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path -import org.json4s._ -import org.json4s.native.Serialization.read - -import scala.io.Source - -object RegisterEngine extends Logging { - val engineManifests = Storage.getMetaDataEngineManifests - implicit val formats = DefaultFormats + new EngineManifestSerializer - - def registerEngine( - jsonManifest: File, - engineFiles: Seq[File], - copyLocal: Boolean = false): Unit = { - val jsonString = try { - Source.fromFile(jsonManifest).mkString - } catch { - case e: java.io.FileNotFoundException => - error(s"Engine manifest file not found: ${e.getMessage}. Aborting.") - sys.exit(1) - } - val engineManifest = read[EngineManifest](jsonString) - - info(s"Registering engine ${engineManifest.id} ${engineManifest.version}") - engineManifests.update( - engineManifest.copy(files = engineFiles.map(_.toURI.toString)), true) - } - - def unregisterEngine(jsonManifest: File): Unit = { - val jsonString = try { - Source.fromFile(jsonManifest).mkString - } catch { - case e: java.io.FileNotFoundException => - error(s"Engine manifest file not found: ${e.getMessage}. Aborting.") - sys.exit(1) - } - val fileEngineManifest = read[EngineManifest](jsonString) - val engineManifest = engineManifests.get( - fileEngineManifest.id, - fileEngineManifest.version) - - engineManifest map { em => - val conf = new Configuration - val fs = FileSystem.get(conf) - - em.files foreach { f => - val path = new Path(f) - info(s"Removing ${f}") - fs.delete(path, false) - } - - engineManifests.delete(em.id, em.version) - info(s"Unregistered engine ${em.id} ${em.version}") - } getOrElse { - error(s"${fileEngineManifest.id} ${fileEngineManifest.version} is not " + - "registered.") - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/RunServer.scala ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/io/prediction/tools/RunServer.scala b/tools/src/main/scala/io/prediction/tools/RunServer.scala deleted file mode 100644 index eb65e87..0000000 --- a/tools/src/main/scala/io/prediction/tools/RunServer.scala +++ /dev/null @@ -1,178 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.prediction.tools - -import java.io.File -import java.net.URI - -import grizzled.slf4j.Logging -import io.prediction.data.storage.EngineManifest -import io.prediction.tools.console.ConsoleArgs -import io.prediction.workflow.WorkflowUtils - -import scala.sys.process._ - -object RunServer extends Logging { - def runServer( - ca: ConsoleArgs, - core: File, - em: EngineManifest, - engineInstanceId: String): Int = { - val pioEnvVars = sys.env.filter(kv => kv._1.startsWith("PIO_")).map(kv => - s"${kv._1}=${kv._2}" - ).mkString(",") - - val sparkHome = ca.common.sparkHome.getOrElse( - sys.env.getOrElse("SPARK_HOME", ".")) - - val extraFiles = WorkflowUtils.thirdPartyConfFiles - - val driverClassPathIndex = - ca.common.sparkPassThrough.indexOf("--driver-class-path") - val driverClassPathPrefix = - if (driverClassPathIndex != -1) { - Seq(ca.common.sparkPassThrough(driverClassPathIndex + 1)) - } else { - Seq() - } - val extraClasspaths = - driverClassPathPrefix ++ WorkflowUtils.thirdPartyClasspaths - - val deployModeIndex = - ca.common.sparkPassThrough.indexOf("--deploy-mode") - val deployMode = if (deployModeIndex != -1) { - ca.common.sparkPassThrough(deployModeIndex + 1) - } else { - "client" - } - - val mainJar = - if (ca.build.uberJar) { - if (deployMode == "cluster") { - em.files.filter(_.startsWith("hdfs")).head - } else { - em.files.filterNot(_.startsWith("hdfs")).head - } - } else { - if (deployMode == "cluster") { - em.files.filter(_.contains("pio-assembly")).head - } else { - core.getCanonicalPath - } - } - - val jarFiles = (em.files ++ Option(new File(ca.common.pioHome.get, "plugins") - .listFiles()).getOrElse(Array.empty[File]).map(_.getAbsolutePath)).mkString(",") - - val sparkSubmit = - Seq(Seq(sparkHome, "bin", "spark-submit").mkString(File.separator)) ++ - ca.common.sparkPassThrough ++ - Seq( - "--class", - "io.prediction.workflow.CreateServer", - "--name", - s"PredictionIO Engine Instance: ${engineInstanceId}") ++ - (if (!ca.build.uberJar) { - Seq("--jars", jarFiles) - } else Seq()) ++ - (if (extraFiles.size > 0) { - Seq("--files", extraFiles.mkString(",")) - } else { - Seq() - }) ++ - (if (extraClasspaths.size > 0) { - Seq("--driver-class-path", extraClasspaths.mkString(":")) - } else { - Seq() - }) ++ - (if (ca.common.sparkKryo) { - Seq( - "--conf", - "spark.serializer=org.apache.spark.serializer.KryoSerializer") - } else { - Seq() - }) ++ - Seq( - mainJar, - "--engineInstanceId", - engineInstanceId, - "--ip", - ca.deploy.ip, - "--port", - ca.deploy.port.toString, - "--event-server-ip", - ca.eventServer.ip, - "--event-server-port", - ca.eventServer.port.toString) ++ - (if (ca.accessKey.accessKey != "") { - Seq("--accesskey", ca.accessKey.accessKey) - } else { - Seq() - }) ++ - (if (ca.eventServer.enabled) Seq("--feedback") else Seq()) ++ - (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Seq()) ++ - (if (ca.common.verbose) Seq("--verbose") else Seq()) ++ - ca.deploy.logUrl.map(x => Seq("--log-url", x)).getOrElse(Seq()) ++ - ca.deploy.logPrefix.map(x => Seq("--log-prefix", x)).getOrElse(Seq()) ++ - Seq("--json-extractor", 
ca.common.jsonExtractor.toString) - - info(s"Submission command: ${sparkSubmit.mkString(" ")}") - - val proc = - Process(sparkSubmit, None, "CLASSPATH" -> "", "SPARK_YARN_USER_ENV" -> pioEnvVars).run() - Runtime.getRuntime.addShutdownHook(new Thread(new Runnable { - def run(): Unit = { - proc.destroy() - } - })) - proc.exitValue() - } - - def newRunServer( - ca: ConsoleArgs, - em: EngineManifest, - engineInstanceId: String): Int = { - val jarFiles = em.files.map(new URI(_)) ++ - Option(new File(ca.common.pioHome.get, "plugins").listFiles()) - .getOrElse(Array.empty[File]).map(_.toURI) - val args = Seq( - "--engineInstanceId", - engineInstanceId, - "--engine-variant", - ca.common.variantJson.toURI.toString, - "--ip", - ca.deploy.ip, - "--port", - ca.deploy.port.toString, - "--event-server-ip", - ca.eventServer.ip, - "--event-server-port", - ca.eventServer.port.toString) ++ - (if (ca.accessKey.accessKey != "") { - Seq("--accesskey", ca.accessKey.accessKey) - } else { - Nil - }) ++ - (if (ca.eventServer.enabled) Seq("--feedback") else Nil) ++ - (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Nil) ++ - (if (ca.common.verbose) Seq("--verbose") else Nil) ++ - ca.deploy.logUrl.map(x => Seq("--log-url", x)).getOrElse(Nil) ++ - ca.deploy.logPrefix.map(x => Seq("--log-prefix", x)).getOrElse(Nil) ++ - Seq("--json-extractor", ca.common.jsonExtractor.toString) - - Runner.runOnSpark("io.prediction.workflow.CreateServer", args, ca, jarFiles) - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/RunWorkflow.scala ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/io/prediction/tools/RunWorkflow.scala b/tools/src/main/scala/io/prediction/tools/RunWorkflow.scala deleted file mode 100644 index b18690e..0000000 --- a/tools/src/main/scala/io/prediction/tools/RunWorkflow.scala +++ /dev/null @@ -1,212 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.prediction.tools - -import java.io.File -import java.net.URI - -import grizzled.slf4j.Logging -import io.prediction.data.storage.EngineManifest -import io.prediction.tools.console.ConsoleArgs -import io.prediction.workflow.WorkflowUtils -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path - -import scala.sys.process._ - -object RunWorkflow extends Logging { - def runWorkflow( - ca: ConsoleArgs, - core: File, - em: EngineManifest, - variantJson: File): Int = { - // Collect and serialize PIO_* environmental variables - val pioEnvVars = sys.env.filter(kv => kv._1.startsWith("PIO_")).map(kv => - s"${kv._1}=${kv._2}" - ).mkString(",") - - val sparkHome = ca.common.sparkHome.getOrElse( - sys.env.getOrElse("SPARK_HOME", ".")) - - val hadoopConf = new Configuration - val hdfs = FileSystem.get(hadoopConf) - - val driverClassPathIndex = - ca.common.sparkPassThrough.indexOf("--driver-class-path") - val driverClassPathPrefix = - if (driverClassPathIndex != -1) { - Seq(ca.common.sparkPassThrough(driverClassPathIndex + 1)) - } else { - Seq() - } - val extraClasspaths = - driverClassPathPrefix ++ WorkflowUtils.thirdPartyClasspaths - - val deployModeIndex = - ca.common.sparkPassThrough.indexOf("--deploy-mode") - val deployMode = if (deployModeIndex != -1) { - ca.common.sparkPassThrough(deployModeIndex + 1) - } else { - "client" - } - - val extraFiles = WorkflowUtils.thirdPartyConfFiles - - val mainJar = - if (ca.build.uberJar) { - if (deployMode == "cluster") { - em.files.filter(_.startsWith("hdfs")).head - } else { - em.files.filterNot(_.startsWith("hdfs")).head - } - } else { - if (deployMode == "cluster") { - em.files.filter(_.contains("pio-assembly")).head - } else { - core.getCanonicalPath - } - } - - val workMode = - ca.common.evaluation.map(_ => "Evaluation").getOrElse("Training") - - val engineLocation = Seq( - sys.env("PIO_FS_ENGINESDIR"), - em.id, - em.version) - - if (deployMode == "cluster") { - val dstPath = new Path(engineLocation.mkString(Path.SEPARATOR)) - info("Cluster deploy mode detected. Trying to copy " + - s"${variantJson.getCanonicalPath} to " + - s"${hdfs.makeQualified(dstPath).toString}.") - hdfs.copyFromLocalFile(new Path(variantJson.toURI), dstPath) - } - - val sparkSubmit = - Seq(Seq(sparkHome, "bin", "spark-submit").mkString(File.separator)) ++ - ca.common.sparkPassThrough ++ - Seq( - "--class", - "io.prediction.workflow.CreateWorkflow", - "--name", - s"PredictionIO $workMode: ${em.id} ${em.version} (${ca.common.batch})") ++ - (if (!ca.build.uberJar) { - Seq("--jars", em.files.mkString(",")) - } else Seq()) ++ - (if (extraFiles.size > 0) { - Seq("--files", extraFiles.mkString(",")) - } else { - Seq() - }) ++ - (if (extraClasspaths.size > 0) { - Seq("--driver-class-path", extraClasspaths.mkString(":")) - } else { - Seq() - }) ++ - (if (ca.common.sparkKryo) { - Seq( - "--conf", - "spark.serializer=org.apache.spark.serializer.KryoSerializer") - } else { - Seq() - }) ++ - Seq( - mainJar, - "--env", - pioEnvVars, - "--engine-id", - em.id, - "--engine-version", - em.version, - "--engine-variant", - if (deployMode == "cluster") { - hdfs.makeQualified(new Path( - (engineLocation :+ variantJson.getName).mkString(Path.SEPARATOR))). 
- toString - } else { - variantJson.getCanonicalPath - }, - "--verbosity", - ca.common.verbosity.toString) ++ - ca.common.engineFactory.map( - x => Seq("--engine-factory", x)).getOrElse(Seq()) ++ - ca.common.engineParamsKey.map( - x => Seq("--engine-params-key", x)).getOrElse(Seq()) ++ - (if (deployMode == "cluster") Seq("--deploy-mode", "cluster") else Seq()) ++ - (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Seq()) ++ - (if (ca.common.verbose) Seq("--verbose") else Seq()) ++ - (if (ca.common.skipSanityCheck) Seq("--skip-sanity-check") else Seq()) ++ - (if (ca.common.stopAfterRead) Seq("--stop-after-read") else Seq()) ++ - (if (ca.common.stopAfterPrepare) { - Seq("--stop-after-prepare") - } else { - Seq() - }) ++ - ca.common.evaluation.map(x => Seq("--evaluation-class", x)). - getOrElse(Seq()) ++ - // If engineParamsGenerator is specified, it overrides the evaluation. - ca.common.engineParamsGenerator.orElse(ca.common.evaluation) - .map(x => Seq("--engine-params-generator-class", x)) - .getOrElse(Seq()) ++ - (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Seq()) ++ - Seq("--json-extractor", ca.common.jsonExtractor.toString) - - info(s"Submission command: ${sparkSubmit.mkString(" ")}") - Process(sparkSubmit, None, "CLASSPATH" -> "", "SPARK_YARN_USER_ENV" -> pioEnvVars).! - } - - def newRunWorkflow(ca: ConsoleArgs, em: EngineManifest): Int = { - val jarFiles = em.files.map(new URI(_)) - val args = Seq( - "--engine-id", - em.id, - "--engine-version", - em.version, - "--engine-variant", - ca.common.variantJson.toURI.toString, - "--verbosity", - ca.common.verbosity.toString) ++ - ca.common.engineFactory.map( - x => Seq("--engine-factory", x)).getOrElse(Seq()) ++ - ca.common.engineParamsKey.map( - x => Seq("--engine-params-key", x)).getOrElse(Seq()) ++ - (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Seq()) ++ - (if (ca.common.verbose) Seq("--verbose") else Seq()) ++ - (if (ca.common.skipSanityCheck) Seq("--skip-sanity-check") else Seq()) ++ - (if (ca.common.stopAfterRead) Seq("--stop-after-read") else Seq()) ++ - (if (ca.common.stopAfterPrepare) { - Seq("--stop-after-prepare") - } else { - Seq() - }) ++ - ca.common.evaluation.map(x => Seq("--evaluation-class", x)). - getOrElse(Seq()) ++ - // If engineParamsGenerator is specified, it overrides the evaluation. - ca.common.engineParamsGenerator.orElse(ca.common.evaluation) - .map(x => Seq("--engine-params-generator-class", x)) - .getOrElse(Seq()) ++ - (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Seq()) ++ - Seq("--json-extractor", ca.common.jsonExtractor.toString) - - Runner.runOnSpark( - "io.prediction.workflow.CreateWorkflow", - args, - ca, - jarFiles) - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/Runner.scala ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/io/prediction/tools/Runner.scala b/tools/src/main/scala/io/prediction/tools/Runner.scala deleted file mode 100644 index 3156660..0000000 --- a/tools/src/main/scala/io/prediction/tools/Runner.scala +++ /dev/null @@ -1,211 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.prediction.tools - -import java.io.File -import java.net.URI - -import grizzled.slf4j.Logging -import io.prediction.tools.console.ConsoleArgs -import io.prediction.workflow.WorkflowUtils -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path - -import scala.sys.process._ - -object Runner extends Logging { - def envStringToMap(env: String): Map[String, String] = - env.split(',').flatMap(p => - p.split('=') match { - case Array(k, v) => List(k -> v) - case _ => Nil - } - ).toMap - - def argumentValue(arguments: Seq[String], argumentName: String): Option[String] = { - val argumentIndex = arguments.indexOf(argumentName) - try { - arguments(argumentIndex) // just to make it error out if index is -1 - Some(arguments(argumentIndex + 1)) - } catch { - case e: IndexOutOfBoundsException => None - } - } - - def handleScratchFile( - fileSystem: Option[FileSystem], - uri: Option[URI], - localFile: File): String = { - val localFilePath = localFile.getCanonicalPath - (fileSystem, uri) match { - case (Some(fs), Some(u)) => - val dest = fs.makeQualified(Path.mergePaths( - new Path(u), - new Path(localFilePath))) - info(s"Copying $localFile to ${dest.toString}") - fs.copyFromLocalFile(new Path(localFilePath), dest) - dest.toUri.toString - case _ => localFile.toURI.toString - } - } - - def cleanup(fs: Option[FileSystem], uri: Option[URI]): Unit = { - (fs, uri) match { - case (Some(f), Some(u)) => - f.close() - case _ => Unit - } - } - - def detectFilePaths( - fileSystem: Option[FileSystem], - uri: Option[URI], - args: Seq[String]): Seq[String] = { - args map { arg => - val f = try { - new File(new URI(arg)) - } catch { - case e: Throwable => new File(arg) - } - if (f.exists()) { - handleScratchFile(fileSystem, uri, f) - } else { - arg - } - } - } - - def runOnSpark( - className: String, - classArgs: Seq[String], - ca: ConsoleArgs, - extraJars: Seq[URI]): Int = { - // Return error for unsupported cases - val deployMode = - argumentValue(ca.common.sparkPassThrough, "--deploy-mode").getOrElse("client") - val master = - argumentValue(ca.common.sparkPassThrough, "--master").getOrElse("local") - - (ca.common.scratchUri, deployMode, master) match { - case (Some(u), "client", m) if m != "yarn-cluster" => - error("--scratch-uri cannot be set when deploy mode is client") - return 1 - case (_, "cluster", m) if m.startsWith("spark://") => - error("Using cluster deploy mode with Spark standalone cluster is not supported") - return 1 - case _ => Unit - } - - // Initialize HDFS API for scratch URI - val fs = ca.common.scratchUri map { uri => - FileSystem.get(uri, new Configuration()) - } - - // Collect and serialize PIO_* environmental variables - val pioEnvVars = sys.env.filter(kv => kv._1.startsWith("PIO_")).map(kv => - s"${kv._1}=${kv._2}" - ).mkString(",") - - // Location of Spark - val sparkHome = ca.common.sparkHome.getOrElse( - sys.env.getOrElse("SPARK_HOME", ".")) - - // Local path to PredictionIO assembly JAR - val mainJar = handleScratchFile( - fs, - ca.common.scratchUri, - 
console.Console.coreAssembly(ca.common.pioHome.get)) - - // Extra JARs that are needed by the driver - val driverClassPathPrefix = - argumentValue(ca.common.sparkPassThrough, "--driver-class-path") map { v => - Seq(v) - } getOrElse { - Nil - } - - val extraClasspaths = - driverClassPathPrefix ++ WorkflowUtils.thirdPartyClasspaths - - // Extra files that are needed to be passed to --files - val extraFiles = WorkflowUtils.thirdPartyConfFiles map { f => - handleScratchFile(fs, ca.common.scratchUri, new File(f)) - } - - val deployedJars = extraJars map { j => - handleScratchFile(fs, ca.common.scratchUri, new File(j)) - } - - val sparkSubmitCommand = - Seq(Seq(sparkHome, "bin", "spark-submit").mkString(File.separator)) - - val sparkSubmitJars = if (extraJars.nonEmpty) { - Seq("--jars", deployedJars.map(_.toString).mkString(",")) - } else { - Nil - } - - val sparkSubmitFiles = if (extraFiles.nonEmpty) { - Seq("--files", extraFiles.mkString(",")) - } else { - Nil - } - - val sparkSubmitExtraClasspaths = if (extraClasspaths.nonEmpty) { - Seq("--driver-class-path", extraClasspaths.mkString(":")) - } else { - Nil - } - - val sparkSubmitKryo = if (ca.common.sparkKryo) { - Seq( - "--conf", - "spark.serializer=org.apache.spark.serializer.KryoSerializer") - } else { - Nil - } - - val verbose = if (ca.common.verbose) Seq("--verbose") else Nil - - val sparkSubmit = Seq( - sparkSubmitCommand, - ca.common.sparkPassThrough, - Seq("--class", className), - sparkSubmitJars, - sparkSubmitFiles, - sparkSubmitExtraClasspaths, - sparkSubmitKryo, - Seq(mainJar), - detectFilePaths(fs, ca.common.scratchUri, classArgs), - Seq("--env", pioEnvVars), - verbose).flatten.filter(_ != "") - info(s"Submission command: ${sparkSubmit.mkString(" ")}") - val proc = Process( - sparkSubmit, - None, - "CLASSPATH" -> "", - "SPARK_YARN_USER_ENV" -> pioEnvVars).run() - Runtime.getRuntime.addShutdownHook(new Thread(new Runnable { - def run(): Unit = { - cleanup(fs, ca.common.scratchUri) - proc.destroy() - } - })) - cleanup(fs, ca.common.scratchUri) - proc.exitValue() - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/admin/AdminAPI.scala ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/io/prediction/tools/admin/AdminAPI.scala b/tools/src/main/scala/io/prediction/tools/admin/AdminAPI.scala deleted file mode 100644 index c5ec913..0000000 --- a/tools/src/main/scala/io/prediction/tools/admin/AdminAPI.scala +++ /dev/null @@ -1,156 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.prediction.tools.admin - -import akka.actor.{Actor, ActorSystem, Props} -import akka.event.Logging -import akka.io.IO -import akka.util.Timeout -import io.prediction.data.api.StartServer -import io.prediction.data.storage.Storage -import org.json4s.{Formats, DefaultFormats} - -import java.util.concurrent.TimeUnit - -import spray.can.Http -import spray.http.{MediaTypes, StatusCodes} -import spray.httpx.Json4sSupport -import spray.routing._ - -import scala.concurrent.ExecutionContext - -class AdminServiceActor(val commandClient: CommandClient) - extends HttpServiceActor { - - object Json4sProtocol extends Json4sSupport { - implicit def json4sFormats: Formats = DefaultFormats - } - - import Json4sProtocol._ - - val log = Logging(context.system, this) - - // we use the enclosing ActorContext's or ActorSystem's dispatcher for our - // Futures - implicit def executionContext: ExecutionContext = actorRefFactory.dispatcher - implicit val timeout: Timeout = Timeout(5, TimeUnit.SECONDS) - - // for better message response - val rejectionHandler = RejectionHandler { - case MalformedRequestContentRejection(msg, _) :: _ => - complete(StatusCodes.BadRequest, Map("message" -> msg)) - case MissingQueryParamRejection(msg) :: _ => - complete(StatusCodes.NotFound, - Map("message" -> s"missing required query parameter ${msg}.")) - case AuthenticationFailedRejection(cause, challengeHeaders) :: _ => - complete(StatusCodes.Unauthorized, challengeHeaders, - Map("message" -> s"Invalid accessKey.")) - } - - val jsonPath = """(.+)\.json$""".r - - val route: Route = - pathSingleSlash { - get { - respondWithMediaType(MediaTypes.`application/json`) { - complete(Map("status" -> "alive")) - } - } - } ~ - path("cmd" / "app" / Segment / "data") { - appName => { - delete { - respondWithMediaType(MediaTypes.`application/json`) { - complete(commandClient.futureAppDataDelete(appName)) - } - } - } - } ~ - path("cmd" / "app" / Segment) { - appName => { - delete { - respondWithMediaType(MediaTypes.`application/json`) { - complete(commandClient.futureAppDelete(appName)) - } - } - } - } ~ - path("cmd" / "app") { - get { - respondWithMediaType(MediaTypes.`application/json`) { - complete(commandClient.futureAppList()) - } - } ~ - post { - entity(as[AppRequest]) { - appArgs => respondWithMediaType(MediaTypes.`application/json`) { - complete(commandClient.futureAppNew(appArgs)) - } - } - } - } - def receive: Actor.Receive = runRoute(route) -} - -class AdminServerActor(val commandClient: CommandClient) extends Actor { - val log = Logging(context.system, this) - val child = context.actorOf( - Props(classOf[AdminServiceActor], commandClient), - "AdminServiceActor") - - implicit val system = context.system - - def receive: PartialFunction[Any, Unit] = { - case StartServer(host, portNum) => { - IO(Http) ! Http.Bind(child, interface = host, port = portNum) - - } - case m: Http.Bound => log.info("Bound received. 
AdminServer is ready.") - case m: Http.CommandFailed => log.error("Command failed.") - case _ => log.error("Unknown message.") - } -} - -case class AdminServerConfig( - ip: String = "localhost", - port: Int = 7071 -) - -object AdminServer { - def createAdminServer(config: AdminServerConfig): Unit = { - implicit val system = ActorSystem("AdminServerSystem") - - val commandClient = new CommandClient( - appClient = Storage.getMetaDataApps, - accessKeyClient = Storage.getMetaDataAccessKeys, - eventClient = Storage.getLEvents() - ) - - val serverActor = system.actorOf( - Props(classOf[AdminServerActor], commandClient), - "AdminServerActor") - serverActor ! StartServer(config.ip, config.port) - system.awaitTermination - } -} - -object AdminRun { - def main (args: Array[String]) { - AdminServer.createAdminServer(AdminServerConfig( - ip = "localhost", - port = 7071)) - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/admin/CommandClient.scala ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/io/prediction/tools/admin/CommandClient.scala b/tools/src/main/scala/io/prediction/tools/admin/CommandClient.scala deleted file mode 100644 index 924b6f0..0000000 --- a/tools/src/main/scala/io/prediction/tools/admin/CommandClient.scala +++ /dev/null @@ -1,160 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.prediction.tools.admin - -import io.prediction.data.storage._ - -import scala.concurrent.{ExecutionContext, Future} - -abstract class BaseResponse() - -case class GeneralResponse( - status: Int = 0, - message: String = "" -) extends BaseResponse() - -case class AppRequest( - id: Int = 0, - name: String = "", - description: String = "" -) - -case class TrainRequest( - enginePath: String = "" -) -case class AppResponse( - id: Int = 0, - name: String = "", - keys: Seq[AccessKey] -) extends BaseResponse() - -case class AppNewResponse( - status: Int = 0, - message: String = "", - id: Int = 0, - name: String = "", - key: String -) extends BaseResponse() - -case class AppListResponse( - status: Int = 0, - message: String = "", - apps: Seq[AppResponse] -) extends BaseResponse() - -class CommandClient( - val appClient: Apps, - val accessKeyClient: AccessKeys, - val eventClient: LEvents -) { - - def futureAppNew(req: AppRequest)(implicit ec: ExecutionContext): Future[BaseResponse] = Future { - val response = appClient.getByName(req.name) map { app => - GeneralResponse(0, s"App ${req.name} already exists. Aborting.") - } getOrElse { - appClient.get(req.id) map { - app2 => - GeneralResponse(0, - s"App ID ${app2.id} already exists and maps to the app '${app2.name}'. 
" + - "Aborting.") - } getOrElse { - val appid = appClient.insert(App( - id = Option(req.id).getOrElse(0), - name = req.name, - description = Option(req.description))) - appid map { id => - val dbInit = eventClient.init(id) - val r = if (dbInit) { - val accessKey = AccessKey( - key = "", - appid = id, - events = Seq()) - val accessKey2 = accessKeyClient.insert(AccessKey( - key = "", - appid = id, - events = Seq())) - accessKey2 map { k => - new AppNewResponse(1,"App created successfully.",id, req.name, k) - } getOrElse { - GeneralResponse(0, s"Unable to create new access key.") - } - } else { - GeneralResponse(0, s"Unable to initialize Event Store for this app ID: ${id}.") - } - r - } getOrElse { - GeneralResponse(0, s"Unable to create new app.") - } - } - } - response - } - - def futureAppList()(implicit ec: ExecutionContext): Future[AppListResponse] = Future { - val apps = appClient.getAll().sortBy(_.name) - val appsRes = apps.map { - app => { - new AppResponse(app.id, app.name, accessKeyClient.getByAppid(app.id)) - } - } - new AppListResponse(1, "Successful retrieved app list.", appsRes) - } - - def futureAppDataDelete(appName: String) - (implicit ec: ExecutionContext): Future[GeneralResponse] = Future { - val response = appClient.getByName(appName) map { app => - val data = if (eventClient.remove(app.id)) { - GeneralResponse(1, s"Removed Event Store for this app ID: ${app.id}") - } else { - GeneralResponse(0, s"Error removing Event Store for this app.") - } - - val dbInit = eventClient.init(app.id) - val data2 = if (dbInit) { - GeneralResponse(1, s"Initialized Event Store for this app ID: ${app.id}.") - } else { - GeneralResponse(0, s"Unable to initialize Event Store for this appId:" + - s" ${app.id}.") - } - GeneralResponse(data.status * data2.status, data.message + data2.message) - } getOrElse { - GeneralResponse(0, s"App ${appName} does not exist.") - } - response - } - - def futureAppDelete(appName: String) - (implicit ec: ExecutionContext): Future[GeneralResponse] = Future { - - val response = appClient.getByName(appName) map { app => - val data = if (eventClient.remove(app.id)) { - Storage.getMetaDataApps.delete(app.id) - GeneralResponse(1, s"App successfully deleted") - } else { - GeneralResponse(0, s"Error removing Event Store for app ${app.name}."); - } - data - } getOrElse { - GeneralResponse(0, s"App ${appName} does not exist.") - } - response - } - - def futureTrain(req: TrainRequest) - (implicit ec: ExecutionContext): Future[GeneralResponse] = Future { - null - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/admin/README.md ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/io/prediction/tools/admin/README.md b/tools/src/main/scala/io/prediction/tools/admin/README.md deleted file mode 100644 index 475a3de..0000000 --- a/tools/src/main/scala/io/prediction/tools/admin/README.md +++ /dev/null @@ -1,161 +0,0 @@ -## Admin API (under development) - -### Start Admin HTTP Server without bin/pio (for development) - -NOTE: elasticsearch and hbase should be running first. 
-
-```
-$ sbt/sbt "tools/compile"
-$ set -a
-$ source conf/pio-env.sh
-$ set +a
-$ sbt/sbt "tools/run-main io.prediction.tools.admin.AdminRun"
-```
-
-### Unit test (Very minimal)
-
-```
-$ set -a
-$ source conf/pio-env.sh
-$ set +a
-$ sbt/sbt "tools/test-only io.prediction.tools.admin.AdminAPISpec"
-```
-
-### Start with pio command adminserver
-
-```
-$ pio adminserver
-```
-
-The admin server URL defaults to `http://localhost:7071`.
-
-The host and port can be specified by using the `--ip` and `--port` parameters:
-
-```
-$ pio adminserver --ip 127.0.0.1 --port 7080
-```
-
-### Currently Supported Commands
-
-#### Check status
-
-```
-$ curl -i http://localhost:7071/
-
-{"status":"alive"}
-```
-
-#### Get list of apps
-
-```
-$ curl -i -X GET http://localhost:7071/cmd/app
-
-{"status":1,"message":"Successfully retrieved app list.","apps":[{"id":12,"name":"scratch","keys":[{"key":"gtPgVMIr3uthus1QJWFBcIjNf6d1SNuhaOWQAgdLbOBP1eRWMNIJWl6SkHgI1OoN","appid":12,"events":[]}]},{"id":17,"name":"test-ecommercerec","keys":[{"key":"zPkr6sBwQoBwBjVHK2hsF9u26L38ARSe19QzkdYentuomCtYSuH0vXP5fq7advo4","appid":17,"events":[]}]}]}
-```
-
-#### Create a new app
-
-```
-$ curl -i -X POST http://localhost:7071/cmd/app \
--H "Content-Type: application/json" \
--d '{ "name" : "my_new_app" }'
-
-{"status":1,"message":"App created successfully.","id":19,"name":"my_new_app","keys":[{"key":"","appid":19,"events":[]}]}
-```
-
-#### Delete an app's data
-
-```
-$ curl -i -X DELETE http://localhost:7071/cmd/app/my_new_app/data
-```
-
-#### Delete app
-
-```
-$ curl -i -X DELETE http://localhost:7071/cmd/app/my_new_app
-
-{"status":1,"message":"App successfully deleted."}
-```
-
-
-## API Doc (To be updated)
-
-### app list
-GET http://localhost:7071/cmd/app
-
-OK Response:
-{
- "status": <STATUS>,
- "message": <MESSAGE>,
- "apps": [
- { "name": "<APP_NAME>",
- "id": <APP_ID>,
- "accessKey": "<ACCESS_KEY>" },
- { "name": "<APP_NAME>",
- "id": <APP_ID>,
- "accessKey": "<ACCESS_KEY>" }, ...
- ]
-}
-
-Error Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-### app new
-POST http://localhost:7071/cmd/app
-Request Body:
-{ "name": "<APP_NAME>", // required
- "id": <APP_ID>, // optional
- "description": "<DESCRIPTION>" } // optional
-
-OK Response:
-{ "status": <STATUS>,
- "message": <MESSAGE>,
- "app": {
- "name": "<APP_NAME>",
- "id": <APP_ID>,
- "accessKey": "<ACCESS_KEY>" }
-}
-
-Error Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-### app delete
-DELETE http://localhost:7071/cmd/app/{appName}
-
-OK Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-Error Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-### app data-delete
-DELETE http://localhost:7071/cmd/app/{appName}/data
-
-OK Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-Error Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-
-### train TBD
-
-#### Training request:
-POST http://localhost:7071/cmd/train
-Request body: TBD
-
-OK Response: TBD
-
-Error Response: TBD
-
-#### Get training status:
-GET http://localhost:7071/cmd/train/{engineInstanceId}
-
-OK Response: TBD
-INIT
-TRAINING
-DONE
-ERROR
-
-Error Response: TBD
-
-### deploy TBD http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/console/AccessKey.scala ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/io/prediction/tools/console/AccessKey.scala b/tools/src/main/scala/io/prediction/tools/console/AccessKey.scala deleted file mode 100644 index 85955e8..0000000 --- a/tools/src/main/scala/io/prediction/tools/console/AccessKey.scala +++ /dev/null @@ -1,83 +0,0 @@ -/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.tools.console
-
-import io.prediction.data.storage
-
-import grizzled.slf4j.Logging
-
-case class AccessKeyArgs(
- accessKey: String = "",
- events: Seq[String] = Seq())
-
-object AccessKey extends Logging {
- def create(ca: ConsoleArgs): Int = {
- val apps = storage.Storage.getMetaDataApps
- apps.getByName(ca.app.name) map { app =>
- val accessKeys = storage.Storage.getMetaDataAccessKeys
- val accessKey = accessKeys.insert(storage.AccessKey(
- key = ca.accessKey.accessKey,
- appid = app.id,
- events = ca.accessKey.events))
- accessKey map { k =>
- info(s"Created new access key: ${k}")
- 0
- } getOrElse {
- error(s"Unable to create new access key.")
- 1
- }
- } getOrElse {
- error(s"App ${ca.app.name} does not exist. Aborting.")
- 1
- }
- }
-
- def list(ca: ConsoleArgs): Int = {
- val keys =
- if (ca.app.name == "") {
- storage.Storage.getMetaDataAccessKeys.getAll
- } else {
- val apps = storage.Storage.getMetaDataApps
- apps.getByName(ca.app.name) map { app =>
- storage.Storage.getMetaDataAccessKeys.getByAppid(app.id)
- } getOrElse {
- error(s"App ${ca.app.name} does not exist. 
Aborting.") - return 1 - } - } - val title = "Access Key(s)" - info(f"$title%64s | App ID | Allowed Event(s)") - keys.sortBy(k => k.appid) foreach { k => - val events = - if (k.events.size > 0) k.events.sorted.mkString(",") else "(all)" - info(f"${k.key}%64s | ${k.appid}%6d | $events%s") - } - info(s"Finished listing ${keys.size} access key(s).") - 0 - } - - def delete(ca: ConsoleArgs): Int = { - try { - storage.Storage.getMetaDataAccessKeys.delete(ca.accessKey.accessKey) - info(s"Deleted access key ${ca.accessKey.accessKey}.") - 0 - } catch { - case e: Exception => - error(s"Error deleting access key ${ca.accessKey.accessKey}.", e) - 1 - } - } -}

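Similarly, `Runner.argumentValue` above extracts the token following a flag in the Spark pass-through arguments by indexing into the sequence and catching `IndexOutOfBoundsException`, and `runOnSpark` then defaults `--deploy-mode` to `client` and `--master` to `local`. Below is a behaviorally equivalent sketch using `Seq.lift` instead of exception handling; `SparkArgsDemo` is an illustrative name, not part of the codebase.

```
object SparkArgsDemo {
  // Same contract as Runner.argumentValue, but total: Seq.lift returns None
  // instead of relying on a caught IndexOutOfBoundsException.
  def argumentValue(arguments: Seq[String], argumentName: String): Option[String] = {
    val i = arguments.indexOf(argumentName)
    if (i == -1) None else arguments.lift(i + 1)
  }

  def main(args: Array[String]): Unit = {
    val passThrough = Seq("--master", "yarn", "--deploy-mode", "cluster")
    // The defaults runOnSpark applies when the flags are absent:
    val deployMode = argumentValue(passThrough, "--deploy-mode").getOrElse("client")
    val master = argumentValue(passThrough, "--master").getOrElse("local")
    println(s"deployMode=$deployMode, master=$master") // cluster, yarn
    // Flag present but dangling at the end of the argument list:
    assert(argumentValue(Seq("--deploy-mode"), "--deploy-mode").isEmpty)
  }
}
```

`lift` returns `None` both when the flag is absent and when it is the last token, which matches the original try/catch behavior without using exceptions for control flow.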