http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/data/src/test/scala/org/apache/predictionio/data/storage/PEventsSpec.scala
----------------------------------------------------------------------
diff --git a/data/src/test/scala/org/apache/predictionio/data/storage/PEventsSpec.scala b/data/src/test/scala/org/apache/predictionio/data/storage/PEventsSpec.scala
new file mode 100644
index 0000000..93cbe6e
--- /dev/null
+++ b/data/src/test/scala/org/apache/predictionio/data/storage/PEventsSpec.scala
@@ -0,0 +1,210 @@
/** Copyright 2015 TappingStone, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

package org.apache.predictionio.data.storage

import org.specs2._
import org.specs2.specification.Step

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

/** Acceptance specification for `PEvents` implementations.
  *
  * The same example group (`events`) is run against the HBase-backed and the
  * JDBC-backed event stores. Test data is inserted through the matching
  * `LEvents` client and read back, aggregated and written through the
  * `PEvents` client using a local Spark context.
  */
class PEventsSpec extends Specification with TestEvents {

  // Clear the Spark network properties before creating the context for this spec.
  System.clearProperty("spark.driver.port")
  System.clearProperty("spark.hostPort")
  // NOTE(review): the application name mentions "PEventAggregatorSpec" although
  // this is PEventsSpec -- looks like a copy/paste leftover; confirm before renaming.
  val sc = new SparkContext("local[4]", "PEventAggregatorSpec test")

  val appId = 1
  val channelId = 6
  // Storage name used for every data object of this spec, suffixed with this instance's hashCode.
  val dbName = "test_pio_storage_events_" + hashCode

  /** HBase-backed local event client. */
  def hbLocal = Storage.getDataObject[LEvents](
    StorageTestUtils.hbaseSourceName,
    dbName
  )

  /** HBase-backed parallel event client. */
  def hbPar = Storage.getDataObject[PEvents](
    StorageTestUtils.hbaseSourceName,
    dbName
  )

  /** JDBC-backed local event client. */
  def jdbcLocal = Storage.getDataObject[LEvents](
    StorageTestUtils.jdbcSourceName,
    dbName
  )

  /** JDBC-backed parallel event client. */
  def jdbcPar = Storage.getDataObject[PEvents](
    StorageTestUtils.jdbcSourceName,
    dbName
  )

  /** Stops the Spark context shared by all examples of this spec. */
  def stopSpark = {
    sc.stop()
  }

  def is = s2"""

  PredictionIO Storage PEvents Specification

    PEvents can be implemented by:
    - HBPEvents ${hbPEvents}
    - JDBCPEvents ${jdbcPEvents}
    - (stop Spark) ${Step(stopSpark)}

  """

  /** Runs the shared examples against HBase, then drops the HBase namespace. */
  def hbPEvents = sequential ^ s2"""

    HBPEvents should
    - behave like any PEvents implementation ${events(hbLocal, hbPar)}
    - (table cleanup) ${Step(StorageTestUtils.dropHBaseNamespace(dbName))}

  """

  /** Runs the shared examples against JDBC, then drops the default and channel tables. */
  def jdbcPEvents = sequential ^ s2"""

    JDBCPEvents should
    - behave like any PEvents implementation ${events(jdbcLocal, jdbcPar)}
    - (table cleanup) ${Step(StorageTestUtils.dropJDBCTable(s"${dbName}_$appId"))}
    - (table cleanup) ${Step(StorageTestUtils.dropJDBCTable(s"${dbName}_${appId}_$channelId"))}

  """

  /** Examples shared by every backend.
    *
    * @param localEventClient client used to initialise storage and insert the fixtures
    * @param parEventClient   client under test
    */
  def events(localEventClient: LEvents, parEventClient: PEvents) = sequential ^ s2"""

    - (init test) ${initTest(localEventClient)}
    - (insert test events) ${insertTestEvents(localEventClient)}
    find in default ${find(parEventClient)}
    find in channel ${findChannel(parEventClient)}
    aggregate user properties in default ${aggregateUserProperties(parEventClient)}
    aggregate user properties in channel ${aggregateUserPropertiesChannel(parEventClient)}
    write to default ${write(parEventClient)}
    write to channel ${writeChannel(parEventClient)}

  """

  /* setup */

  // events from TestEvents trait
  val listOfEvents = List(u1e5, u2e2, u1e3, u1e1, u2e3, u2e1, u1e4, u1e2, r1, r2)
  val listOfEventsChannel = List(u3e1, u3e2, u3e3, r3, r4)

  /** Initialises the default storage and the channel storage of `appId`. */
  def initTest(localEventClient: LEvents) = {
    localEventClient.init(appId)
    localEventClient.init(appId, Some(channelId))
  }

  /** Inserts the fixtures into the default storage and into the channel.
    *
    * The values returned by `insert` are not needed, so the lists are
    * traversed with `foreach` rather than mapped into throwaway collections.
    */
  def insertTestEvents(localEventClient: LEvents) = {
    listOfEvents.foreach(localEventClient.insert(_, appId))
    // insert to channel
    listOfEventsChannel.foreach(localEventClient.insert(_, appId, Some(channelId)))
    success
  }

  /* following are tests */

  /** All events of the default storage are found (event ids are ignored). */
  def find(parEventClient: PEvents) = {
    val resultRDD: RDD[Event] = parEventClient.find(
      appId = appId
    )(sc)

    val results = resultRDD.collect.toList
      .map {_.copy(eventId = None)} // ignore eventId

    results must containTheSameElementsAs(listOfEvents)
  }

  /** All events of the channel are found (event ids are ignored). */
  def findChannel(parEventClient: PEvents) = {
    val resultRDD: RDD[Event] = parEventClient.find(
      appId = appId,
      channelId = Some(channelId)
    )(sc)

    val results = resultRDD.collect.toList
      .map {_.copy(eventId = None)} // ignore eventId

    results must containTheSameElementsAs(listOfEventsChannel)
  }

  /** Aggregating "user" entities of the default storage yields u1 and u2. */
  def aggregateUserProperties(parEventClient: PEvents) = {
    val resultRDD: RDD[(String, PropertyMap)] = parEventClient.aggregateProperties(
      appId = appId,
      entityType = "user"
    )(sc)
    val result: Map[String, PropertyMap] = resultRDD.collectAsMap.toMap

    val expected = Map(
      "u1" -> PropertyMap(u1, u1BaseTime, u1LastTime),
      "u2" -> PropertyMap(u2, u2BaseTime, u2LastTime)
    )

    result must beEqualTo(expected)
  }

  /** Aggregating "user" entities of the channel yields u3 only. */
  def aggregateUserPropertiesChannel(parEventClient: PEvents) = {
    val resultRDD: RDD[(String, PropertyMap)] = parEventClient.aggregateProperties(
      appId = appId,
      channelId = Some(channelId),
      entityType = "user"
    )(sc)
    val result: Map[String, PropertyMap] = resultRDD.collectAsMap.toMap

    val expected = Map(
      "u3" -> PropertyMap(u3, u3BaseTime, u3LastTime)
    )

    result must beEqualTo(expected)
  }

  /** Events written through `PEvents.write` are appended to the default storage. */
  def write(parEventClient: PEvents) = {
    val written = List(r5, r6)
    val writtenRDD = sc.parallelize(written)
    parEventClient.write(writtenRDD, appId)(sc)

    // read back
    val resultRDD = parEventClient.find(
      appId = appId
    )(sc)

    val results = resultRDD.collect.toList
      .map { _.copy(eventId = None)} // ignore eventId

    val expected = listOfEvents ++ written

    results must containTheSameElementsAs(expected)
  }

  /** Events written through `PEvents.write` are appended to the channel. */
  def writeChannel(parEventClient: PEvents) = {
    val written = List(r1, r5, r6)
    val writtenRDD = sc.parallelize(written)
    parEventClient.write(writtenRDD, appId, Some(channelId))(sc)

    // read back
    val resultRDD = parEventClient.find(
      appId = appId,
      channelId = Some(channelId)
    )(sc)

    val results = resultRDD.collect.toList
      .map { _.copy(eventId = None)} // ignore eventId

    val expected = listOfEventsChannel ++ written

    results must containTheSameElementsAs(expected)
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/data/src/test/scala/org/apache/predictionio/data/storage/StorageTestUtils.scala
----------------------------------------------------------------------
diff --git a/data/src/test/scala/org/apache/predictionio/data/storage/StorageTestUtils.scala b/data/src/test/scala/org/apache/predictionio/data/storage/StorageTestUtils.scala
new file mode 100644
index 0000000..6068f4c
--- /dev/null
+++ b/data/src/test/scala/org/apache/predictionio/data/storage/StorageTestUtils.scala
@@ -0,0 +1,42 @@
/** Copyright 2015 TappingStone, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

package org.apache.predictionio.data.storage

import org.apache.predictionio.data.storage.hbase.HBLEvents
import scalikejdbc._

/** Clean-up helpers shared by the storage specifications. */
object StorageTestUtils {
  // Names of the storage sources the specs look up via Storage.getDataObject.
  val hbaseSourceName = "HBASE"
  val jdbcSourceName = "PGSQL"

  /** Disables and deletes every table of `namespace`, then deletes the namespace.
    *
    * @param namespace HBase namespace to remove
    */
  def dropHBaseNamespace(namespace: String): Unit = {
    val hbEvents =
      Storage.getDataObject[LEvents](hbaseSourceName, namespace).asInstanceOf[HBLEvents]
    val hbAdmin = hbEvents.client.admin

    // Only empty namespaces (no tables) can be removed, so drop the tables first.
    for (table <- hbAdmin.listTableNamesByNamespace(namespace)) {
      hbAdmin.disableTable(table)
      hbAdmin.deleteTable(table)
    }

    hbAdmin.deleteNamespace(namespace)
  }

  /** Issues `drop table` for `table` inside an auto-commit session.
    *
    * @param table name of the table; it is interpolated into the statement as-is
    */
  def dropJDBCTable(table: String): Unit = DB.autoCommit { implicit session =>
    SQL(s"drop table $table").execute().apply()
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/data/src/test/scala/org/apache/predictionio/data/storage/TestEvents.scala
----------------------------------------------------------------------
diff --git a/data/src/test/scala/org/apache/predictionio/data/storage/TestEvents.scala b/data/src/test/scala/org/apache/predictionio/data/storage/TestEvents.scala
new file mode 100644
index 0000000..f1c327b
--- /dev/null
+++ b/data/src/test/scala/org/apache/predictionio/data/storage/TestEvents.scala
@@ -0,0 +1,263 @@
/** Copyright 2015 TappingStone, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
+ */ + +package org.apache.predictionio.data.storage + +import org.joda.time.DateTime +import org.joda.time.DateTimeZone + +trait TestEvents { + + val u1BaseTime = new DateTime(654321) + val u2BaseTime = new DateTime(6543210) + val u3BaseTime = new DateTime(6543410) + + // u1 events + val u1e1 = Event( + event = "$set", + entityType = "user", + entityId = "u1", + properties = DataMap( + """{ + "a" : 1, + "b" : "value2", + "d" : [1, 2, 3], + }"""), + eventTime = u1BaseTime + ) + + val u1e2 = u1e1.copy( + event = "$set", + properties = DataMap("""{"a" : 2}"""), + eventTime = u1BaseTime.plusDays(1) + ) + + val u1e3 = u1e1.copy( + event = "$set", + properties = DataMap("""{"b" : "value4"}"""), + eventTime = u1BaseTime.plusDays(2) + ) + + val u1e4 = u1e1.copy( + event = "$unset", + properties = DataMap("""{"b" : null}"""), + eventTime = u1BaseTime.plusDays(3) + ) + + val u1e5 = u1e1.copy( + event = "$set", + properties = DataMap("""{"e" : "new"}"""), + eventTime = u1BaseTime.plusDays(4) + ) + + val u1LastTime = u1BaseTime.plusDays(4) + val u1 = """{"a": 2, "d": [1, 2, 3], "e": "new"}""" + + // delete event for u1 + val u1ed = u1e1.copy( + event = "$delete", + properties = DataMap(), + eventTime = u1BaseTime.plusDays(5) + ) + + // u2 events + val u2e1 = Event( + event = "$set", + entityType = "user", + entityId = "u2", + properties = DataMap( + """{ + "a" : 21, + "b" : "value12", + "d" : [7, 5, 6], + }"""), + eventTime = u2BaseTime + ) + + val u2e2 = u2e1.copy( + event = "$unset", + properties = DataMap("""{"a" : null}"""), + eventTime = u2BaseTime.plusDays(1) + ) + + val u2e3 = u2e1.copy( + event = "$set", + properties = DataMap("""{"b" : "value9", "g": "new11"}"""), + eventTime = u2BaseTime.plusDays(2) + ) + + val u2LastTime = u2BaseTime.plusDays(2) + val u2 = """{"b": "value9", "d": [7, 5, 6], "g": "new11"}""" + + // u3 events + val u3e1 = Event( + event = "$set", + entityType = "user", + entityId = "u3", + properties = DataMap( + """{ + "a" : 22, + "b" : "value13", + 
"d" : [5, 6, 1], + }"""), + eventTime = u3BaseTime + ) + + val u3e2 = u3e1.copy( + event = "$unset", + properties = DataMap("""{"a" : null}"""), + eventTime = u3BaseTime.plusDays(1) + ) + + val u3e3 = u3e1.copy( + event = "$set", + properties = DataMap("""{"b" : "value10", "f": "new12", "d" : [1, 3, 2]}"""), + eventTime = u3BaseTime.plusDays(2) + ) + + val u3LastTime = u3BaseTime.plusDays(2) + val u3 = """{"b": "value10", "d": [1, 3, 2], "f": "new12"}""" + + // some random events + val r1 = Event( + event = "my_event", + entityType = "my_entity_type", + entityId = "my_entity_id", + targetEntityType = Some("my_target_entity_type"), + targetEntityId = Some("my_target_entity_id"), + properties = DataMap( + """{ + "prop1" : 1, + "prop2" : "value2", + "prop3" : [1, 2, 3], + "prop4" : true, + "prop5" : ["a", "b", "c"], + "prop6" : 4.56 + }""" + ), + eventTime = DateTime.now, + prId = Some("my_prid") + ) + val r2 = Event( + event = "my_event2", + entityType = "my_entity_type2", + entityId = "my_entity_id2" + ) + val r3 = Event( + event = "my_event3", + entityType = "my_entity_type", + entityId = "my_entity_id", + targetEntityType = Some("my_target_entity_type"), + targetEntityId = Some("my_target_entity_id"), + properties = DataMap( + """{ + "propA" : 1.2345, + "propB" : "valueB", + }""" + ), + prId = Some("my_prid") + ) + val r4 = Event( + event = "my_event4", + entityType = "my_entity_type4", + entityId = "my_entity_id4", + targetEntityType = Some("my_target_entity_type4"), + targetEntityId = Some("my_target_entity_id4"), + properties = DataMap( + """{ + "prop1" : 1, + "prop2" : "value2", + "prop3" : [1, 2, 3], + "prop4" : true, + "prop5" : ["a", "b", "c"], + "prop6" : 4.56 + }"""), + eventTime = DateTime.now + ) + val r5 = Event( + event = "my_event5", + entityType = "my_entity_type5", + entityId = "my_entity_id5", + targetEntityType = Some("my_target_entity_type5"), + targetEntityId = Some("my_target_entity_id5"), + properties = DataMap( + """{ + "prop1" : 1, + 
"prop2" : "value2", + "prop3" : [1, 2, 3], + "prop4" : true, + "prop5" : ["a", "b", "c"], + "prop6" : 4.56 + }""" + ), + eventTime = DateTime.now + ) + val r6 = Event( + event = "my_event6", + entityType = "my_entity_type6", + entityId = "my_entity_id6", + targetEntityType = Some("my_target_entity_type6"), + targetEntityId = Some("my_target_entity_id6"), + properties = DataMap( + """{ + "prop1" : 6, + "prop2" : "value2", + "prop3" : [6, 7, 8], + "prop4" : true, + "prop5" : ["a", "b", "c"], + "prop6" : 4.56 + }""" + ), + eventTime = DateTime.now + ) + + // timezone + val tz1 = Event( + event = "my_event", + entityType = "my_entity_type", + entityId = "my_entity_id0", + targetEntityType = Some("my_target_entity_type"), + targetEntityId = Some("my_target_entity_id"), + properties = DataMap( + """{ + "prop1" : 1, + "prop2" : "value2", + "prop3" : [1, 2, 3], + "prop4" : true, + "prop5" : ["a", "b", "c"], + "prop6" : 4.56 + }""" + ), + eventTime = new DateTime(12345678, DateTimeZone.forID("-08:00")), + prId = Some("my_prid") + ) + + val tz2 = Event( + event = "my_event", + entityType = "my_entity_type", + entityId = "my_entity_id1", + eventTime = new DateTime(12345678, DateTimeZone.forID("+02:00")), + prId = Some("my_prid") + ) + + val tz3 = Event( + event = "my_event", + entityType = "my_entity_type", + entityId = "my_entity_id2", + eventTime = new DateTime(12345678, DateTimeZone.forID("+08:00")), + prId = Some("my_prid") + ) + +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/data/src/test/scala/org/apache/predictionio/data/webhooks/ConnectorTestUtil.scala ---------------------------------------------------------------------- diff --git a/data/src/test/scala/org/apache/predictionio/data/webhooks/ConnectorTestUtil.scala b/data/src/test/scala/org/apache/predictionio/data/webhooks/ConnectorTestUtil.scala new file mode 100644 index 0000000..0998c52 --- /dev/null +++ 
b/data/src/test/scala/org/apache/predictionio/data/webhooks/ConnectorTestUtil.scala
@@ -0,0 +1,47 @@
/** Copyright 2015 TappingStone, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

package org.apache.predictionio.data.webhooks

import org.specs2.execute.Result
import org.specs2.mutable._

import org.json4s.JObject
import org.json4s.DefaultFormats
import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization.write

/** Test utilities for webhook connectors.
  *
  * Provides `check` overloads that run a connector on a sample payload and
  * compare the top-level fields of the produced event JSON with an expected
  * JSON document.
  */
trait ConnectorTestUtil extends Specification {

  implicit val formats = DefaultFormats

  /** Parses `json` and returns it as a JSON object.
    *
    * Replaces an unchecked `asInstanceOf[JObject]` so that a fixture which is
    * not a JSON object fails with a message naming the offending value
    * instead of a bare `ClassCastException`.
    */
  private def parseJsonObject(json: String): JObject = parse(json) match {
    case obj: JObject => obj
    case other =>
      throw new IllegalArgumentException(s"Expected a JSON object but got: $other")
  }

  /** Checks a [[JsonConnector]] conversion.
    *
    * @param connector connector under test
    * @param original  webhook payload as a JSON string
    * @param event     expected event JSON
    * @return the result of comparing the top-level fields of both objects
    */
  def check(connector: JsonConnector, original: String, event: String): Result = {
    val originalJson = parseJsonObject(original)
    val eventJson = parseJsonObject(event)
    // write and parse back to discard any JNothing field
    val result = parseJsonObject(write(connector.toEventJson(originalJson)))
    result.obj must containTheSameElementsAs(eventJson.obj)
  }

  /** Checks a [[FormConnector]] conversion.
    *
    * @param connector connector under test
    * @param original  webhook payload as form fields
    * @param event     expected event JSON
    * @return the result of comparing the top-level fields of both objects
    */
  def check(connector: FormConnector, original: Map[String, String], event: String): Result = {
    val eventJson = parseJsonObject(event)
    // write and parse back to discard any JNothing field
    val result = parseJsonObject(write(connector.toEventJson(original)))

    result.obj must containTheSameElementsAs(eventJson.obj)
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/data/src/test/scala/org/apache/predictionio/data/webhooks/exampleform/ExampleFormConnectorSpec.scala
----------------------------------------------------------------------
diff --git a/data/src/test/scala/org/apache/predictionio/data/webhooks/exampleform/ExampleFormConnectorSpec.scala b/data/src/test/scala/org/apache/predictionio/data/webhooks/exampleform/ExampleFormConnectorSpec.scala
new file mode 100644
index 0000000..d99e2ca
--- /dev/null
+++ b/data/src/test/scala/org/apache/predictionio/data/webhooks/exampleform/ExampleFormConnectorSpec.scala
@@ -0,0 +1,164 @@
/** Copyright 2015 TappingStone, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

package org.apache.predictionio.data.webhooks.exampleform

import org.apache.predictionio.data.webhooks.ConnectorTestUtil

import org.specs2.mutable._

/** Test the ExampleFormConnector.
  *
  * Each example feeds a form payload to the connector through `check` and
  * compares the result with an expected event document. The expected
  * documents are strict JSON (every member separated by a comma, no trailing
  * commas), so they do not rely on the parser tolerating malformed input.
  */
class ExampleFormConnectorSpec extends Specification with ConnectorTestUtil {

  "ExampleFormConnector" should {

    "convert userAction to Event JSON" in {
      // webhooks input
      val userAction = Map(
        "type" -> "userAction",
        "userId" -> "as34smg4",
        "event" -> "do_something",
        "context[ip]" -> "24.5.68.47", // optional
        "context[prop1]" -> "2.345", // optional
        "context[prop2]" -> "value1", // optional
        "anotherProperty1" -> "100",
        "anotherProperty2"-> "optional1", // optional
        "timestamp" -> "2015-01-02T00:30:12.984Z"
      )

      // expected converted Event JSON
      val expected = """
        {
          "event": "do_something",
          "entityType": "user",
          "entityId": "as34smg4",
          "properties": {
            "context": {
              "ip": "24.5.68.47",
              "prop1": 2.345,
              "prop2": "value1"
            },
            "anotherProperty1": 100,
            "anotherProperty2": "optional1"
          },
          "eventTime": "2015-01-02T00:30:12.984Z"
        }
      """

      check(ExampleFormConnector, userAction, expected)
    }

    "convert userAction without optional fields to Event JSON" in {
      // webhooks input
      val userAction = Map(
        "type" -> "userAction",
        "userId" -> "as34smg4",
        "event" -> "do_something",
        "anotherProperty1" -> "100",
        "timestamp" -> "2015-01-02T00:30:12.984Z"
      )

      // expected converted Event JSON
      val expected = """
        {
          "event": "do_something",
          "entityType": "user",
          "entityId": "as34smg4",
          "properties": {
            "anotherProperty1": 100
          },
          "eventTime": "2015-01-02T00:30:12.984Z"
        }
      """

      check(ExampleFormConnector, userAction, expected)
    }

    "convert userActionItem to Event JSON" in {
      // webhooks input
      val userActionItem = Map(
        "type" -> "userActionItem",
        "userId" -> "as34smg4",
        "event" -> "do_something_on",
        "itemId" -> "kfjd312bc",
        "context[ip]" -> "1.23.4.56",
        "context[prop1]" -> "2.345",
        "context[prop2]" -> "value1",
        "anotherPropertyA" -> "4.567", // optional
        "anotherPropertyB" -> "false", // optional
        "timestamp" -> "2015-01-15T04:20:23.567Z"
      )

      // expected converted Event JSON
      val expected = """
        {
          "event": "do_something_on",
          "entityType": "user",
          "entityId": "as34smg4",
          "targetEntityType": "item",
          "targetEntityId": "kfjd312bc",
          "properties": {
            "context": {
              "ip": "1.23.4.56",
              "prop1": 2.345,
              "prop2": "value1"
            },
            "anotherPropertyA": 4.567,
            "anotherPropertyB": false
          },
          "eventTime": "2015-01-15T04:20:23.567Z"
        }
      """

      check(ExampleFormConnector, userActionItem, expected)
    }

    "convert userActionItem without optional fields to Event JSON" in {
      // webhooks input
      val userActionItem = Map(
        "type" -> "userActionItem",
        "userId" -> "as34smg4",
        "event" -> "do_something_on",
        "itemId" -> "kfjd312bc",
        "context[ip]" -> "1.23.4.56",
        "context[prop1]" -> "2.345",
        "context[prop2]" -> "value1",
        "timestamp" -> "2015-01-15T04:20:23.567Z"
      )

      // expected converted Event JSON
      val expected = """
        {
          "event": "do_something_on",
          "entityType": "user",
          "entityId": "as34smg4",
          "targetEntityType": "item",
          "targetEntityId": "kfjd312bc",
          "properties": {
            "context": {
              "ip": "1.23.4.56",
              "prop1": 2.345,
              "prop2": "value1"
            }
          },
          "eventTime": "2015-01-15T04:20:23.567Z"
        }
      """

      check(ExampleFormConnector, userActionItem, expected)
    }

  }
}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/data/src/test/scala/org/apache/predictionio/data/webhooks/examplejson/ExampleJsonConnectorSpec.scala
----------------------------------------------------------------------
diff --git a/data/src/test/scala/org/apache/predictionio/data/webhooks/examplejson/ExampleJsonConnectorSpec.scala b/data/src/test/scala/org/apache/predictionio/data/webhooks/examplejson/ExampleJsonConnectorSpec.scala
new file mode 100644
index 0000000..069d52e
--- /dev/null
+++
b/data/src/test/scala/org/apache/predictionio/data/webhooks/examplejson/ExampleJsonConnectorSpec.scala
@@ -0,0 +1,179 @@
/** Copyright 2015 TappingStone, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

package org.apache.predictionio.data.webhooks.examplejson

import org.apache.predictionio.data.webhooks.ConnectorTestUtil

import org.specs2.mutable._

/** Test the ExampleJsonConnector.
  *
  * Each example feeds a JSON payload to the connector through `check` and
  * compares the result with an expected event document. Both the payloads
  * and the expected documents are strict JSON (every member separated by a
  * comma, no trailing commas), so they do not rely on the parser tolerating
  * malformed input.
  */
class ExampleJsonConnectorSpec extends Specification with ConnectorTestUtil {

  "ExampleJsonConnector" should {

    "convert userAction to Event JSON" in {
      // webhooks input
      val userAction = """
        {
          "type": "userAction",
          "userId": "as34smg4",
          "event": "do_something",
          "context": {
            "ip": "24.5.68.47",
            "prop1": 2.345,
            "prop2": "value1"
          },
          "anotherProperty1": 100,
          "anotherProperty2": "optional1",
          "timestamp": "2015-01-02T00:30:12.984Z"
        }
      """

      // expected converted Event JSON
      val expected = """
        {
          "event": "do_something",
          "entityType": "user",
          "entityId": "as34smg4",
          "properties": {
            "context": {
              "ip": "24.5.68.47",
              "prop1": 2.345,
              "prop2": "value1"
            },
            "anotherProperty1": 100,
            "anotherProperty2": "optional1"
          },
          "eventTime": "2015-01-02T00:30:12.984Z"
        }
      """

      check(ExampleJsonConnector, userAction, expected)
    }

    "convert userAction without optional field to Event JSON" in {
      // webhooks input
      val userAction = """
        {
          "type": "userAction",
          "userId": "as34smg4",
          "event": "do_something",
          "anotherProperty1": 100,
          "timestamp": "2015-01-02T00:30:12.984Z"
        }
      """

      // expected converted Event JSON
      val expected = """
        {
          "event": "do_something",
          "entityType": "user",
          "entityId": "as34smg4",
          "properties": {
            "anotherProperty1": 100
          },
          "eventTime": "2015-01-02T00:30:12.984Z"
        }
      """

      check(ExampleJsonConnector, userAction, expected)
    }

    "convert userActionItem to Event JSON" in {
      // webhooks input
      val userActionItem = """
        {
          "type": "userActionItem",
          "userId": "as34smg4",
          "event": "do_something_on",
          "itemId": "kfjd312bc",
          "context": {
            "ip": "1.23.4.56",
            "prop1": 2.345,
            "prop2": "value1"
          },
          "anotherPropertyA": 4.567,
          "anotherPropertyB": false,
          "timestamp": "2015-01-15T04:20:23.567Z"
        }
      """

      // expected converted Event JSON
      val expected = """
        {
          "event": "do_something_on",
          "entityType": "user",
          "entityId": "as34smg4",
          "targetEntityType": "item",
          "targetEntityId": "kfjd312bc",
          "properties": {
            "context": {
              "ip": "1.23.4.56",
              "prop1": 2.345,
              "prop2": "value1"
            },
            "anotherPropertyA": 4.567,
            "anotherPropertyB": false
          },
          "eventTime": "2015-01-15T04:20:23.567Z"
        }
      """

      check(ExampleJsonConnector, userActionItem, expected)
    }

    "convert userActionItem without optional fields to Event JSON" in {
      // webhooks input
      val userActionItem = """
        {
          "type": "userActionItem",
          "userId": "as34smg4",
          "event": "do_something_on",
          "itemId": "kfjd312bc",
          "context": {
            "ip": "1.23.4.56",
            "prop1": 2.345,
            "prop2": "value1"
          },
          "timestamp": "2015-01-15T04:20:23.567Z"
        }
      """

      // expected converted Event JSON
      val expected = """
        {
          "event": "do_something_on",
          "entityType": "user",
          "entityId": "as34smg4",
          "targetEntityType": "item",
          "targetEntityId": "kfjd312bc",
          "properties": {
            "context": {
              "ip": "1.23.4.56",
              "prop1": 2.345,
              "prop2": "value1"
            }
          },
          "eventTime": "2015-01-15T04:20:23.567Z"
        }
      """

      check(ExampleJsonConnector, userActionItem, expected)
    }

  }

}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/data/src/test/scala/org/apache/predictionio/data/webhooks/mailchimp/MailChimpConnectorSpec.scala
----------------------------------------------------------------------
diff --git a/data/src/test/scala/org/apache/predictionio/data/webhooks/mailchimp/MailChimpConnectorSpec.scala b/data/src/test/scala/org/apache/predictionio/data/webhooks/mailchimp/MailChimpConnectorSpec.scala
new file mode 100644
index 0000000..854c9dd
--- /dev/null
+++ b/data/src/test/scala/org/apache/predictionio/data/webhooks/mailchimp/MailChimpConnectorSpec.scala
@@ -0,0 +1,254 @@
/** Copyright 2015 TappingStone, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
+ */ + +package org.apache.predictionio.data.webhooks.mailchimp + +import org.apache.predictionio.data.webhooks.ConnectorTestUtil + +import org.specs2.mutable._ + +class MailChimpConnectorSpec extends Specification with ConnectorTestUtil { + + // TOOD: test other events + // TODO: test different optional fields + + "MailChimpConnector" should { + + "convert subscribe to event JSON" in { + + val subscribe = Map( + "type" -> "subscribe", + "fired_at" -> "2009-03-26 21:35:57", + "data[id]" -> "8a25ff1d98", + "data[list_id]" -> "a6b5da1054", + "data[email]" -> "[email protected]", + "data[email_type]" -> "html", + "data[merges][EMAIL]" -> "[email protected]", + "data[merges][FNAME]" -> "MailChimp", + "data[merges][LNAME]" -> "API", + "data[merges][INTERESTS]" -> "Group1,Group2", //optional + "data[ip_opt]" -> "10.20.10.30", + "data[ip_signup]" -> "10.20.10.30" + ) + + val expected = """ + { + "event" : "subscribe", + "entityType" : "user", + "entityId" : "8a25ff1d98", + "targetEntityType" : "list", + "targetEntityId" : "a6b5da1054", + "properties" : { + "email" : "[email protected]", + "email_type" : "html", + "merges" : { + "EMAIL" : "[email protected]", + "FNAME" : "MailChimp", + "LNAME" : "API" + "INTERESTS" : "Group1,Group2" + }, + "ip_opt" : "10.20.10.30", + "ip_signup" : "10.20.10.30" + }, + "eventTime" : "2009-03-26T21:35:57.000Z" + } + """ + + check(MailChimpConnector, subscribe, expected) + } + + //check unsubscribe to event Json + "convert unsubscribe to event JSON" in { + + val unsubscribe = Map( + "type" -> "unsubscribe", + "fired_at" -> "2009-03-26 21:40:57", + "data[action]" -> "unsub", + "data[reason]" -> "manual", + "data[id]" -> "8a25ff1d98", + "data[list_id]" -> "a6b5da1054", + "data[email]" -> "[email protected]", + "data[email_type]" -> "html", + "data[merges][EMAIL]" -> "[email protected]", + "data[merges][FNAME]" -> "MailChimp", + "data[merges][LNAME]" -> "API", + "data[merges][INTERESTS]" -> "Group1,Group2", //optional + "data[ip_opt]" -> 
"10.20.10.30", + "data[campaign_id]" -> "cb398d21d2" + ) + + val expected = """ + { + "event" : "unsubscribe", + "entityType" : "user", + "entityId" : "8a25ff1d98", + "targetEntityType" : "list", + "targetEntityId" : "a6b5da1054", + "properties" : { + "action" : "unsub", + "reason" : "manual", + "email" : "[email protected]", + "email_type" : "html", + "merges" : { + "EMAIL" : "[email protected]", + "FNAME" : "MailChimp", + "LNAME" : "API" + "INTERESTS" : "Group1,Group2" + }, + "ip_opt" : "10.20.10.30", + "campaign_id" : "cb398d21d2" + }, + "eventTime" : "2009-03-26T21:40:57.000Z" + } + """ + + check(MailChimpConnector, unsubscribe, expected) + } + + //check profile update to event Json + "convert profile update to event JSON" in { + + val profileUpdate = Map( + "type" -> "profile", + "fired_at" -> "2009-03-26 21:31:21", + "data[id]" -> "8a25ff1d98", + "data[list_id]" -> "a6b5da1054", + "data[email]" -> "[email protected]", + "data[email_type]" -> "html", + "data[merges][EMAIL]" -> "[email protected]", + "data[merges][FNAME]" -> "MailChimp", + "data[merges][LNAME]" -> "API", + "data[merges][INTERESTS]" -> "Group1,Group2", //optional + "data[ip_opt]" -> "10.20.10.30" + ) + + val expected = """ + { + "event" : "profile", + "entityType" : "user", + "entityId" : "8a25ff1d98", + "targetEntityType" : "list", + "targetEntityId" : "a6b5da1054", + "properties" : { + "email" : "[email protected]", + "email_type" : "html", + "merges" : { + "EMAIL" : "[email protected]", + "FNAME" : "MailChimp", + "LNAME" : "API" + "INTERESTS" : "Group1,Group2" + }, + "ip_opt" : "10.20.10.30" + }, + "eventTime" : "2009-03-26T21:31:21.000Z" + } + """ + + check(MailChimpConnector, profileUpdate, expected) + } + + //check email update to event Json + "convert email update to event JSON" in { + + val emailUpdate = Map( + "type" -> "upemail", + "fired_at" -> "2009-03-26 22:15:09", + "data[list_id]" -> "a6b5da1054", + "data[new_id]" -> "51da8c3259", + "data[new_email]" -> "[email protected]", + 
"data[old_email]" -> "[email protected]" + ) + + val expected = """ + { + "event" : "upemail", + "entityType" : "user", + "entityId" : "51da8c3259", + "targetEntityType" : "list", + "targetEntityId" : "a6b5da1054", + "properties" : { + "new_email" : "[email protected]", + "old_email" : "[email protected]" + }, + "eventTime" : "2009-03-26T22:15:09.000Z" + } + """ + + check(MailChimpConnector, emailUpdate, expected) + } + + //check cleaned email to event Json + "convert cleaned email to event JSON" in { + + val cleanedEmail = Map( + "type" -> "cleaned", + "fired_at" -> "2009-03-26 22:01:00", + "data[list_id]" -> "a6b5da1054", + "data[campaign_id]" -> "4fjk2ma9xd", + "data[reason]" -> "hard", + "data[email]" -> "[email protected]" + ) + + val expected = """ + { + "event" : "cleaned", + "entityType" : "list", + "entityId" : "a6b5da1054", + "properties" : { + "campaignId" : "4fjk2ma9xd", + "reason" : "hard", + "email" : "[email protected]" + }, + "eventTime" : "2009-03-26T22:01:00.000Z" + } + """ + + check(MailChimpConnector, cleanedEmail, expected) + } + + //check campaign sending status to event Json + "convert campaign sending status to event JSON" in { + + val campaign = Map( + "type" -> "campaign", + "fired_at" -> "2009-03-26 22:15:09", + "data[id]" -> "5aa2102003", + "data[subject]" -> "Test Campaign Subject", + "data[status]" -> "sent", + "data[reason]" -> "", + "data[list_id]" -> "a6b5da1054" + ) + + val expected = """ + { + "event" : "campaign", + "entityType" : "campaign", + "entityId" : "5aa2102003", + "targetEntityType" : "list", + "targetEntityId" : "a6b5da1054", + "properties" : { + "subject" : "Test Campaign Subject", + "status" : "sent", + "reason" : "" + }, + "eventTime" : "2009-03-26T22:15:09.000Z" + } + """ + + check(MailChimpConnector, campaign, expected) + } + + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/data/src/test/scala/org/apache/predictionio/data/webhooks/segmentio/SegmentIOConnectorSpec.scala 
---------------------------------------------------------------------- diff --git a/data/src/test/scala/org/apache/predictionio/data/webhooks/segmentio/SegmentIOConnectorSpec.scala b/data/src/test/scala/org/apache/predictionio/data/webhooks/segmentio/SegmentIOConnectorSpec.scala new file mode 100644 index 0000000..de92ecd --- /dev/null +++ b/data/src/test/scala/org/apache/predictionio/data/webhooks/segmentio/SegmentIOConnectorSpec.scala @@ -0,0 +1,335 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.predictionio.data.webhooks.segmentio + +import org.apache.predictionio.data.webhooks.ConnectorTestUtil + +import org.specs2.mutable._ + +class SegmentIOConnectorSpec extends Specification with ConnectorTestUtil { + + // TODO: test different optional fields + + val commonFields = + s""" + | "anonymous_id": "id", + | "sent_at": "sendAt", + | "version": "2", + """.stripMargin + + "SegmentIOConnector" should { + + "convert group with context to event JSON" in { + val context = + """ + | "context": { + | "app": { + | "name": "InitechGlobal", + | "version": "545", + | "build": "3.0.1.545" + | }, + | "campaign": { + | "name": "TPS Innovation Newsletter", + | "source": "Newsletter", + | "medium": "email", + | "term": "tps reports", + | "content": "image link" + | }, + | "device": { + | "id": "B5372DB0-C21E-11E4-8DFC-AA07A5B093DB", + | "advertising_id": "7A3CBEA0-BDF5-11E4-8DFC-AA07A5B093DB", + | "ad_tracking_enabled": true, + | "manufacturer": "Apple", + | "model": "iPhone7,2", + | "name": "maguro", + | "type": "ios", + | "token": "ff15bc0c20c4aa6cd50854ff165fd265c838e5405bfeb9571066395b8c9da449" + | }, + | "ip": "8.8.8.8", + | "library": { + | "name": "analytics-ios", + | "version": "1.8.0" + | }, + | "network": { + | "bluetooth": false, + | "carrier": "T-Mobile NL", + | "cellular": true, + | "wifi": false + | }, + | "location": { + | "city": "San Francisco", + | "country": "United States", + | "latitude": 40.2964197, + | "longitude": -76.9411617, + | "speed": 0 + | }, + | "os": { + | "name": "iPhone OS", + | "version": "8.1.3" + | }, + | "referrer": { + | "id": "ABCD582CDEFFFF01919", + | "type": "dataxu" + | }, + | "screen": { + | "width": 320, + | "height": 568, + | "density": 2 + | }, + | "timezone": "Europe/Amsterdam", + | "user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_5)" + | } + """.stripMargin + + val group = + s""" + |{ $commonFields + | "type": "group", + | "group_id": "groupId", + | "user_id": "userIdValue", + | "timestamp" 
: "2012-12-02T00:30:08.276Z", + | "traits": { + | "name": "groupName", + | "employees": 329, + | }, + | $context + |} + """.stripMargin + + val expected = + s""" + |{ + | "event": "group", + | "entityType": "user", + | "entityId": "userIdValue", + | "properties": { + | $context, + | "group_id": "groupId", + | "traits": { + | "name": "groupName", + | "employees": 329 + | }, + | }, + | "eventTime" : "2012-12-02T00:30:08.276Z" + |} + """.stripMargin + + check(SegmentIOConnector, group, expected) + } + + "convert group to event JSON" in { + val group = + s""" + |{ $commonFields + | "type": "group", + | "group_id": "groupId", + | "user_id": "userIdValue", + | "timestamp" : "2012-12-02T00:30:08.276Z", + | "traits": { + | "name": "groupName", + | "employees": 329, + | } + |} + """.stripMargin + + val expected = + """ + |{ + | "event": "group", + | "entityType": "user", + | "entityId": "userIdValue", + | "properties": { + | "group_id": "groupId", + | "traits": { + | "name": "groupName", + | "employees": 329 + | } + | }, + | "eventTime" : "2012-12-02T00:30:08.276Z" + |} + """.stripMargin + + check(SegmentIOConnector, group, expected) + } + + "convert screen to event JSON" in { + val screen = + s""" + |{ $commonFields + | "type": "screen", + | "name": "screenName", + | "user_id": "userIdValue", + | "timestamp" : "2012-12-02T00:30:08.276Z", + | "properties": { + | "variation": "screenVariation" + | } + |} + """.stripMargin + + val expected = + """ + |{ + | "event": "screen", + | "entityType": "user", + | "entityId": "userIdValue", + | "properties": { + | "properties": { + | "variation": "screenVariation" + | }, + | "name": "screenName" + | }, + | "eventTime" : "2012-12-02T00:30:08.276Z" + |} + """.stripMargin + + check(SegmentIOConnector, screen, expected) + } + + "convert page to event JSON" in { + val page = + s""" + |{ $commonFields + | "type": "page", + | "name": "pageName", + | "user_id": "userIdValue", + | "timestamp" : "2012-12-02T00:30:08.276Z", + | "properties": { + 
| "title": "pageTitle", + | "url": "pageUrl" + | } + |} + """.stripMargin + + val expected = + """ + |{ + | "event": "page", + | "entityType": "user", + | "entityId": "userIdValue", + | "properties": { + | "properties": { + | "title": "pageTitle", + | "url": "pageUrl" + | }, + | "name": "pageName" + | }, + | "eventTime" : "2012-12-02T00:30:08.276Z" + |} + """.stripMargin + + check(SegmentIOConnector, page, expected) + } + + "convert alias to event JSON" in { + val alias = + s""" + |{ $commonFields + | "type": "alias", + | "previous_id": "previousIdValue", + | "user_id": "userIdValue", + | "timestamp" : "2012-12-02T00:30:08.276Z" + |} + """.stripMargin + + val expected = + """ + |{ + | "event": "alias", + | "entityType": "user", + | "entityId": "userIdValue", + | "properties": { + | "previous_id" : "previousIdValue" + | }, + | "eventTime" : "2012-12-02T00:30:08.276Z" + |} + """.stripMargin + + check(SegmentIOConnector, alias, expected) + } + + "convert track to event JSON" in { + val track = + s""" + |{ $commonFields + | "user_id": "some_user_id", + | "type": "track", + | "event": "Registered", + | "timestamp" : "2012-12-02T00:30:08.276Z", + | "properties": { + | "plan": "Pro Annual", + | "accountType" : "Facebook" + | } + |} + """.stripMargin + + val expected = + """ + |{ + | "event": "track", + | "entityType": "user", + | "entityId": "some_user_id", + | "properties": { + | "event": "Registered", + | "properties": { + | "plan": "Pro Annual", + | "accountType": "Facebook" + | } + | }, + | "eventTime" : "2012-12-02T00:30:08.276Z" + |} + """.stripMargin + + check(SegmentIOConnector, track, expected) + } + + "convert identify to event JSON" in { + val identify = s""" + { $commonFields + "type" : "identify", + "user_id" : "019mr8mf4r", + "traits" : { + "email" : "[email protected]", + "name" : "Achilles", + "subscription_plan" : "Premium", + "friendCount" : 29 + }, + "timestamp" : "2012-12-02T00:30:08.276Z" + } + """ + + val expected = """ + { + "event" : "identify", + 
"entityType": "user", + "entityId" : "019mr8mf4r", + "properties" : { + "traits" : { + "email" : "[email protected]", + "name" : "Achilles", + "subscription_plan" : "Premium", + "friendCount" : 29 + } + }, + "eventTime" : "2012-12-02T00:30:08.276Z" + } + """ + + check(SegmentIOConnector, identify, expected) + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/io/prediction/e2/engine/BinaryVectorizer.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/io/prediction/e2/engine/BinaryVectorizer.scala b/e2/src/main/scala/io/prediction/e2/engine/BinaryVectorizer.scala deleted file mode 100644 index 6c0d5d3..0000000 --- a/e2/src/main/scala/io/prediction/e2/engine/BinaryVectorizer.scala +++ /dev/null @@ -1,61 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.prediction.e2.engine - -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.linalg.Vector -import scala.collection.immutable.HashMap -import scala.collection.immutable.HashSet - -class BinaryVectorizer(propertyMap : HashMap[(String, String), Int]) -extends Serializable { - - val properties: Array[(String, String)] = propertyMap.toArray.sortBy(_._2).map(_._1) - val numFeatures = propertyMap.size - - override def toString: String = { - s"BinaryVectorizer($numFeatures): " + properties.map(e => s"(${e._1}, ${e._2})").mkString(",") - } - - def toBinary(map : Array[(String, String)]) : Vector = { - val mapArr : Seq[(Int, Double)] = map.flatMap( - e => propertyMap.get(e).map(idx => (idx, 1.0)) - ) - - Vectors.sparse(numFeatures, mapArr) - } -} - - -object BinaryVectorizer { - def apply (input : RDD[HashMap[String, String]], properties : HashSet[String]) - : BinaryVectorizer = { - new BinaryVectorizer(HashMap( - input.flatMap(identity) - .filter(e => properties.contains(e._1)) - .distinct - .collect - .zipWithIndex : _* - )) - } - - def apply(input: Seq[(String, String)]): BinaryVectorizer = { - val indexed: Seq[((String, String), Int)] = input.zipWithIndex - new BinaryVectorizer(HashMap(indexed:_*)) - } -} - http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/io/prediction/e2/engine/CategoricalNaiveBayes.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/io/prediction/e2/engine/CategoricalNaiveBayes.scala b/e2/src/main/scala/io/prediction/e2/engine/CategoricalNaiveBayes.scala deleted file mode 100644 index c598519..0000000 --- a/e2/src/main/scala/io/prediction/e2/engine/CategoricalNaiveBayes.scala +++ /dev/null @@ -1,176 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prediction.e2.engine - -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -/** - * Class for training a naive Bayes model with categorical variables - */ -object CategoricalNaiveBayes { - /** - * Train with data points and return the model - * - * @param points training data points - */ - def train(points: RDD[LabeledPoint]): CategoricalNaiveBayesModel = { - val labelCountFeatureLikelihoods = points.map { p => - (p.label, p.features) - }.combineByKey[(Long, Array[Map[String, Long]])]( - createCombiner = - (features: Array[String]) => { - val featureCounts = features.map { feature => - Map[String, Long]().withDefaultValue(0L).updated(feature, 1L) - } - - (1L, featureCounts) - }, - mergeValue = - (c: (Long, Array[Map[String, Long]]), features: Array[String]) => { - (c._1 + 1L, c._2.zip(features).map { case (m, feature) => - m.updated(feature, m(feature) + 1L) - }) - }, - mergeCombiners = - ( - c1: (Long, Array[Map[String, Long]]), - c2: (Long, Array[Map[String, Long]])) => { - val labelCount1 = c1._1 - val labelCount2 = c2._1 - val featureCounts1 = c1._2 - val featureCounts2 = c2._2 - - (labelCount1 + labelCount2, - featureCounts1.zip(featureCounts2).map { case (m1, m2) => - m2 ++ m2.map { case (k, v) => k -> (v + m2(k))} - }) - } - ).mapValues { case (labelCount, featureCounts) => - val featureLikelihoods = featureCounts.map { featureCount => - // mapValues does not 
return a serializable map - featureCount.mapValues(count => math.log(count.toDouble / labelCount)) - .map(identity) - } - - (labelCount, featureLikelihoods) - }.collect().toMap - - val noOfPoints = labelCountFeatureLikelihoods.map(_._2._1).sum - val priors = - labelCountFeatureLikelihoods.mapValues { countFeatureLikelihoods => - math.log(countFeatureLikelihoods._1 / noOfPoints.toDouble) - } - val likelihoods = labelCountFeatureLikelihoods.mapValues(_._2) - - CategoricalNaiveBayesModel(priors, likelihoods) - } -} - -/** - * Model for naive Bayes classifiers with categorical variables. - * - * @param priors log prior probabilities - * @param likelihoods log likelihood probabilities - */ -case class CategoricalNaiveBayesModel( - priors: Map[String, Double], - likelihoods: Map[String, Array[Map[String, Double]]]) extends Serializable { - - val featureCount = likelihoods.head._2.size - - /** - * Calculate the log score of having the given features and label - * - * @param point label and features - * @param defaultLikelihood a function that calculates the likelihood when a - * feature value is not present. The input to the - * function is the other feature value likelihoods. - * @return log score when label is present. None otherwise. 
- */ - def logScore( - point: LabeledPoint, - defaultLikelihood: (Seq[Double]) => Double = ls => Double.NegativeInfinity - ): Option[Double] = { - val label = point.label - val features = point.features - - if (!priors.contains(label)) { - None - } else { - Some(logScoreInternal(label, features, defaultLikelihood)) - } - } - - private def logScoreInternal( - label: String, - features: Array[String], - defaultLikelihood: (Seq[Double]) => Double = ls => Double.NegativeInfinity - ): Double = { - - val prior = priors(label) - val likelihood = likelihoods(label) - - val likelihoodScores = features.zip(likelihood).map { - case (feature, featureLikelihoods) => - featureLikelihoods.getOrElse( - feature, - defaultLikelihood(featureLikelihoods.values.toSeq) - ) - } - - prior + likelihoodScores.sum - } - - /** - * Return the label that yields the highest score - * - * @param features features for classification - * - */ - def predict(features: Array[String]): String = { - priors.keySet.map { label => - (label, logScoreInternal(label, features)) - }.toSeq - .sortBy(_._2)(Ordering.Double.reverse) - .take(1) - .head - ._1 - } -} - -/** - * Class that represents the features and labels of a data point. 
- * - * @param label Label of this data point - * @param features Features of this data point - */ -case class LabeledPoint(label: String, features: Array[String]) { - override def toString: String = { - val featuresString = features.mkString("[", ",", "]") - - s"($label, $featuresString)" - } - - override def equals(other: Any): Boolean = other match { - case that: LabeledPoint => that.toString == this.toString - case _ => false - } - - override def hashCode(): Int = { - this.toString.hashCode - } - -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/io/prediction/e2/engine/MarkovChain.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/io/prediction/e2/engine/MarkovChain.scala b/e2/src/main/scala/io/prediction/e2/engine/MarkovChain.scala deleted file mode 100644 index 4c992f5..0000000 --- a/e2/src/main/scala/io/prediction/e2/engine/MarkovChain.scala +++ /dev/null @@ -1,89 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.prediction.e2.engine - -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix -import org.apache.spark.mllib.linalg.{SparseVector, Vectors} -import org.apache.spark.rdd.RDD - -/** - * Class for training a Markov Chain model - */ -object MarkovChain { - /** - * Train a Markov Chain model - * - * @param matrix Tally of all state transitions - * @param topN Use the top-N tally for each state - */ - def train(matrix: CoordinateMatrix, topN: Int): MarkovChainModel = { - val noOfStates = matrix.numCols().toInt - val transitionVectors = matrix.entries - .keyBy(_.i.toInt) - .groupByKey() - .mapValues { rowEntries => - val total = rowEntries.map(_.value).sum - val sortedTopN = rowEntries.toSeq - .sortBy(_.value)(Ordering.Double.reverse) - .take(topN) - .map(me => (me.j.toInt, me.value / total)) - .sortBy(_._1) - - new SparseVector( - noOfStates, - sortedTopN.map(_._1).toArray, - sortedTopN.map(_._2).toArray) - } - - new MarkovChainModel( - transitionVectors, - topN) - } -} - -/** - * Markov Chain model - * - * @param transitionVectors transition vectors - * @param n top N used to construct the model - */ -case class MarkovChainModel( - transitionVectors: RDD[(Int, SparseVector)], - n: Int) { - - /** - * Calculate the probabilities of the next state - * - * @param currentState probabilities of the current state - */ - def predict(currentState: Seq[Double]): Seq[Double] = { - // multiply the input with transition matrix row by row - val nextStateVectors = transitionVectors.map { case (rowIndex, vector) => - val values = vector.indices.map { index => - vector(index) * currentState(rowIndex) - } - - Vectors.sparse(currentState.size, vector.indices, values) - }.collect() - - // sum up to get the total probabilities - (0 until currentState.size).map { index => - nextStateVectors.map { vector => - vector(index) - }.sum - } - } -} 
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/io/prediction/e2/evaluation/CrossValidation.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/io/prediction/e2/evaluation/CrossValidation.scala b/e2/src/main/scala/io/prediction/e2/evaluation/CrossValidation.scala deleted file mode 100644 index 8b482bd..0000000 --- a/e2/src/main/scala/io/prediction/e2/evaluation/CrossValidation.scala +++ /dev/null @@ -1,64 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prediction.e2.evaluation - -import scala.reflect.ClassTag -import org.apache.spark.rdd.RDD - -/** Common helper functions */ -object CommonHelperFunctions { - - /** Split a data set into evalK folds for crossvalidation. - * Apply to data sets supplied to evaluation. - * - * @tparam D Data point class. - * @tparam TD Training data class. - * @tparam EI Evaluation Info class. - * @tparam Q Input query class. - * @tparam A Actual value class. 
- */ - - def splitData[D: ClassTag, TD, EI, Q, A]( - - evalK: Int, - dataset: RDD[D], - evaluatorInfo: EI, - trainingDataCreator: RDD[D] => TD, - queryCreator: D => Q, - actualCreator: D => A): Seq[(TD, EI, RDD[(Q, A)])] = { - - val indexedPoints = dataset.zipWithIndex - - def selectPoint(foldIdx: Int, pt: D, idx: Long, k: Int, isTraining: Boolean): Option[D] = { - if ((idx % k == foldIdx) ^ isTraining) Some(pt) - else None - } - - (0 until evalK).map { foldIdx => - val trainingPoints = indexedPoints.flatMap { case(pt, idx) => - selectPoint(foldIdx, pt, idx, evalK, true) - } - val testingPoints = indexedPoints.flatMap { case(pt, idx) => - selectPoint(foldIdx, pt, idx, evalK, false) - } - - ( - trainingDataCreator(trainingPoints), - evaluatorInfo, - testingPoints.map { d => (queryCreator(d), actualCreator(d)) } - ) - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/io/prediction/e2/package.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/io/prediction/e2/package.scala b/e2/src/main/scala/io/prediction/e2/package.scala deleted file mode 100644 index 9f5491a..0000000 --- a/e2/src/main/scala/io/prediction/e2/package.scala +++ /dev/null @@ -1,22 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.prediction.e2 - -/** Collection of engine libraries that have no dependency on PredictionIO */ -package object engine {} - -/** Collection of evaluation libraries that have no dependency on PredictionIO */ -package object evaluation {} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/io/prediction/package.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/io/prediction/package.scala b/e2/src/main/scala/io/prediction/package.scala deleted file mode 100644 index 9628b5d..0000000 --- a/e2/src/main/scala/io/prediction/package.scala +++ /dev/null @@ -1,21 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.prediction - -/** Independent library of code that is useful for engine development and - * evaluation - */ -package object e2 {} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/org/apache/predictionio/e2/engine/BinaryVectorizer.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/org/apache/predictionio/e2/engine/BinaryVectorizer.scala b/e2/src/main/scala/org/apache/predictionio/e2/engine/BinaryVectorizer.scala new file mode 100644 index 0000000..d831718 --- /dev/null +++ b/e2/src/main/scala/org/apache/predictionio/e2/engine/BinaryVectorizer.scala @@ -0,0 +1,61 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.predictionio.e2.engine + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.Vector +import scala.collection.immutable.HashMap +import scala.collection.immutable.HashSet + +class BinaryVectorizer(propertyMap : HashMap[(String, String), Int]) +extends Serializable { + + val properties: Array[(String, String)] = propertyMap.toArray.sortBy(_._2).map(_._1) + val numFeatures = propertyMap.size + + override def toString: String = { + s"BinaryVectorizer($numFeatures): " + properties.map(e => s"(${e._1}, ${e._2})").mkString(",") + } + + def toBinary(map : Array[(String, String)]) : Vector = { + val mapArr : Seq[(Int, Double)] = map.flatMap( + e => propertyMap.get(e).map(idx => (idx, 1.0)) + ) + + Vectors.sparse(numFeatures, mapArr) + } +} + + +object BinaryVectorizer { + def apply (input : RDD[HashMap[String, String]], properties : HashSet[String]) + : BinaryVectorizer = { + new BinaryVectorizer(HashMap( + input.flatMap(identity) + .filter(e => properties.contains(e._1)) + .distinct + .collect + .zipWithIndex : _* + )) + } + + def apply(input: Seq[(String, String)]): BinaryVectorizer = { + val indexed: Seq[((String, String), Int)] = input.zipWithIndex + new BinaryVectorizer(HashMap(indexed:_*)) + } +} + http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayes.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayes.scala b/e2/src/main/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayes.scala new file mode 100644 index 0000000..7944bbc --- /dev/null +++ b/e2/src/main/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayes.scala @@ -0,0 +1,176 @@ +/** Copyright 2015 TappingStone, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.predictionio.e2.engine + +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD + +/** + * Class for training a naive Bayes model with categorical variables + */ +object CategoricalNaiveBayes { + /** + * Train with data points and return the model + * + * @param points training data points + */ + def train(points: RDD[LabeledPoint]): CategoricalNaiveBayesModel = { + val labelCountFeatureLikelihoods = points.map { p => + (p.label, p.features) + }.combineByKey[(Long, Array[Map[String, Long]])]( + createCombiner = + (features: Array[String]) => { + val featureCounts = features.map { feature => + Map[String, Long]().withDefaultValue(0L).updated(feature, 1L) + } + + (1L, featureCounts) + }, + mergeValue = + (c: (Long, Array[Map[String, Long]]), features: Array[String]) => { + (c._1 + 1L, c._2.zip(features).map { case (m, feature) => + m.updated(feature, m(feature) + 1L) + }) + }, + mergeCombiners = + ( + c1: (Long, Array[Map[String, Long]]), + c2: (Long, Array[Map[String, Long]])) => { + val labelCount1 = c1._1 + val labelCount2 = c2._1 + val featureCounts1 = c1._2 + val featureCounts2 = c2._2 + + (labelCount1 + labelCount2, + featureCounts1.zip(featureCounts2).map { case (m1, m2) => + m1 ++ m2.map { case (k, v) => k -> (v + m1.getOrElse(k, 0L)) } // add m2's counts onto m1's; values seen on one side only keep their count + }) + } + ).mapValues { case (labelCount, featureCounts) => + val featureLikelihoods = featureCounts.map { featureCount => + // mapValues
does not return a serializable map + featureCount.mapValues(count => math.log(count.toDouble / labelCount)) + .map(identity) + } + + (labelCount, featureLikelihoods) + }.collect().toMap + + val noOfPoints = labelCountFeatureLikelihoods.map(_._2._1).sum + val priors = + labelCountFeatureLikelihoods.mapValues { countFeatureLikelihoods => + math.log(countFeatureLikelihoods._1 / noOfPoints.toDouble) + } + val likelihoods = labelCountFeatureLikelihoods.mapValues(_._2) + + CategoricalNaiveBayesModel(priors, likelihoods) + } +} + +/** + * Model for naive Bayes classifiers with categorical variables. + * + * @param priors log prior probabilities + * @param likelihoods log likelihood probabilities + */ +case class CategoricalNaiveBayesModel( + priors: Map[String, Double], + likelihoods: Map[String, Array[Map[String, Double]]]) extends Serializable { + + val featureCount = likelihoods.head._2.size + + /** + * Calculate the log score of having the given features and label + * + * @param point label and features + * @param defaultLikelihood a function that calculates the likelihood when a + * feature value is not present. The input to the + * function is the other feature value likelihoods. + * @return log score when label is present. None otherwise.
+ */ + def logScore( + point: LabeledPoint, + defaultLikelihood: (Seq[Double]) => Double = ls => Double.NegativeInfinity + ): Option[Double] = { + val label = point.label + val features = point.features + + if (!priors.contains(label)) { + None + } else { + Some(logScoreInternal(label, features, defaultLikelihood)) + } + } + + private def logScoreInternal( + label: String, + features: Array[String], + defaultLikelihood: (Seq[Double]) => Double = ls => Double.NegativeInfinity + ): Double = { + + val prior = priors(label) + val likelihood = likelihoods(label) + + val likelihoodScores = features.zip(likelihood).map { + case (feature, featureLikelihoods) => + featureLikelihoods.getOrElse( + feature, + defaultLikelihood(featureLikelihoods.values.toSeq) + ) + } + + prior + likelihoodScores.sum + } + + /** + * Return the label that yields the highest score + * + * @param features features for classification + * + */ + def predict(features: Array[String]): String = { + priors.keySet.map { label => + (label, logScoreInternal(label, features)) + }.toSeq + .sortBy(_._2)(Ordering.Double.reverse) + .take(1) + .head + ._1 + } +} + +/** + * Class that represents the features and labels of a data point.
+ * + * @param label Label of this data point + * @param features Features of this data point + */ +case class LabeledPoint(label: String, features: Array[String]) { + override def toString: String = { + val featuresString = features.mkString("[", ",", "]") + + s"($label, $featuresString)" + } + + override def equals(other: Any): Boolean = other match { + case that: LabeledPoint => that.toString == this.toString + case _ => false + } + + override def hashCode(): Int = { + this.toString.hashCode + } + +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/org/apache/predictionio/e2/engine/MarkovChain.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/org/apache/predictionio/e2/engine/MarkovChain.scala b/e2/src/main/scala/org/apache/predictionio/e2/engine/MarkovChain.scala new file mode 100644 index 0000000..41a070d --- /dev/null +++ b/e2/src/main/scala/org/apache/predictionio/e2/engine/MarkovChain.scala @@ -0,0 +1,89 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+package org.apache.predictionio.e2.engine
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix
+import org.apache.spark.mllib.linalg.{SparseVector, Vectors}
+import org.apache.spark.rdd.RDD
+
+/**
+ * Factory for Markov Chain models built from a transition tally matrix
+ */
+object MarkovChain {
+  /**
+   * Build a Markov Chain model from a matrix of transition tallies.
+   *
+   * Entries are grouped by row index. Within each row only the `topN`
+   * largest tallies are kept, and each kept tally is divided by the sum of
+   * all tallies of that row (including the ones that were dropped).
+   *
+   * @param matrix Tally of all state transitions
+   * @param topN Use the top-N tally for each state
+   * @return model holding one sparse transition vector per row index
+   */
+  def train(matrix: CoordinateMatrix, topN: Int): MarkovChainModel = {
+    val noOfStates = matrix.numCols().toInt
+
+    val entriesByRow = matrix.entries.keyBy(_.i.toInt).groupByKey()
+
+    val transitionVectors = entriesByRow.mapValues { rowEntries =>
+      val total = rowEntries.map(_.value).sum
+
+      // Largest tallies first, cut to topN, then back to ascending column order
+      val strongest = rowEntries.toSeq
+        .sortBy(_.value)(Ordering.Double.reverse)
+        .take(topN)
+      val (columns, probabilities) = strongest
+        .map(entry => (entry.j.toInt, entry.value / total))
+        .sortBy(_._1)
+        .unzip
+
+      new SparseVector(noOfStates, columns.toArray, probabilities.toArray)
+    }
+
+    new MarkovChainModel(transitionVectors, topN)
+  }
+}
+
+/**
+ * Markov Chain model
+ *
+ * @param transitionVectors transition vectors
+ * @param n top N used to construct the model
+ */
+case class MarkovChainModel(
+  transitionVectors: RDD[(Int, SparseVector)],
+  n: Int) {
+
+  /**
+   * Compute the probabilities of the next state.
+   *
+   * Every transition vector is scaled by the probability of its own row
+   * index in `currentState`; the scaled vectors are collected and summed
+   * position by position.
+   *
+   * @param currentState probabilities of the current state
+   * @return one probability per position of `currentState`
+   */
+  def predict(currentState: Seq[Double]): Seq[Double] = {
+    val stateCount = currentState.size
+
+    // Scale each transition row by the probability of its source state
+    val weightedRows = transitionVectors.map { case (rowIndex, vector) =>
+      val weighted = vector.indices.map { column =>
+        vector(column) * currentState(rowIndex)
+      }
+
+      Vectors.sparse(stateCount, vector.indices, weighted)
+    }.collect()
+
+    // Add the scaled rows together, one target state at a time
+    for (column <- 0 until stateCount)
+      yield weightedRows.map(row => row(column)).sum
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/org/apache/predictionio/e2/evaluation/CrossValidation.scala
----------------------------------------------------------------------
diff --git a/e2/src/main/scala/org/apache/predictionio/e2/evaluation/CrossValidation.scala b/e2/src/main/scala/org/apache/predictionio/e2/evaluation/CrossValidation.scala
new file mode 100644
index 0000000..d2e1d6a
--- /dev/null
+++ b/e2/src/main/scala/org/apache/predictionio/e2/evaluation/CrossValidation.scala
@@ -0,0 +1,64 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.predictionio.e2.evaluation
+
+import scala.reflect.ClassTag
+import org.apache.spark.rdd.RDD
+
+/** Common helper functions */
+object CommonHelperFunctions {
+
+  /** Split a data set into evalK folds for cross-validation.
+   * Apply to data sets supplied to evaluation.
+   *
+   * Points are numbered with `zipWithIndex`. For fold `f`, the points whose
+   * index modulo `evalK` equals `f` form the test set and all remaining
+   * points form the training set.
+   *
+   * @tparam D Data point class.
+   * @tparam TD Training data class.
+   * @tparam EI Evaluation Info class.
+   * @tparam Q Input query class.
+   * @tparam A Actual value class.
+   * @param evalK Number of folds to produce.
+   * @param dataset Data points to split.
+   * @param evaluatorInfo Value placed unchanged into every fold's tuple.
+   * @param trainingDataCreator Builds the training data from a fold's training points.
+   * @param queryCreator Builds the query for a test point.
+   * @param actualCreator Builds the actual value for a test point.
+   * @return One (training data, evaluator info, test queries and actuals) tuple per fold.
+   */
+  def splitData[D: ClassTag, TD, EI, Q, A](
+    evalK: Int,
+    dataset: RDD[D],
+    evaluatorInfo: EI,
+    trainingDataCreator: RDD[D] => TD,
+    queryCreator: D => Q,
+    actualCreator: D => A): Seq[(TD, EI, RDD[(Q, A)])] = {
+
+    val indexedPoints = dataset.zipWithIndex
+
+    for (foldIdx <- 0 until evalK) yield {
+      // Everything outside the current fold is used for training
+      val trainingPoints = indexedPoints
+        .filter { case (_, idx) => idx % evalK != foldIdx }
+        .map(_._1)
+      // The current fold itself is held out for testing
+      val testingPoints = indexedPoints
+        .filter { case (_, idx) => idx % evalK == foldIdx }
+        .map(_._1)
+
+      val queriesWithActuals = testingPoints.map { d =>
+        (queryCreator(d), actualCreator(d))
+      }
+
+      (trainingDataCreator(trainingPoints), evaluatorInfo, queriesWithActuals)
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/org/apache/predictionio/e2/package.scala
----------------------------------------------------------------------
diff --git a/e2/src/main/scala/org/apache/predictionio/e2/package.scala b/e2/src/main/scala/org/apache/predictionio/e2/package.scala
new file mode 100644
index 0000000..c16e521
--- /dev/null
+++ b/e2/src/main/scala/org/apache/predictionio/e2/package.scala
@@ -0,0 +1,22 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.predictionio.e2 + +/** Collection of engine libraries that have no dependency on PredictionIO. + * + * The package object itself declares no members. + */ +package object engine {} + +/** Collection of evaluation libraries that have no dependency on PredictionIO. + * + * The package object itself declares no members. + */ +package object evaluation {} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/main/scala/org/apache/predictionio/package.scala ---------------------------------------------------------------------- diff --git a/e2/src/main/scala/org/apache/predictionio/package.scala b/e2/src/main/scala/org/apache/predictionio/package.scala new file mode 100644 index 0000000..b480779 --- /dev/null +++ b/e2/src/main/scala/org/apache/predictionio/package.scala @@ -0,0 +1,21 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.predictionio + +/** Independent library of code that is useful for engine development and + * evaluation. + * + * The package object itself declares no members. + */ +package object e2 {} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/engine/BinaryVectorizerTest.scala ---------------------------------------------------------------------- diff --git a/e2/src/test/scala/io/prediction/e2/engine/BinaryVectorizerTest.scala b/e2/src/test/scala/io/prediction/e2/engine/BinaryVectorizerTest.scala deleted file mode 100644 index 5e6bc16..0000000 --- a/e2/src/test/scala/io/prediction/e2/engine/BinaryVectorizerTest.scala +++ /dev/null @@ -1,56 +0,0 @@ -/** Copyright 2015 TappingStone, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.prediction.e2.engine - -import io.prediction.e2.fixture.BinaryVectorizerFixture -import io.prediction.e2.fixture.SharedSparkContext -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.rdd.RDD -import org.scalatest.FlatSpec -import org.scalatest.Matchers -import scala.collection.immutable.HashMap - - -import scala.language.reflectiveCalls - -class BinaryVectorizerTest extends FlatSpec with Matchers with SharedSparkContext -with BinaryVectorizerFixture{ - - "toBinary" should "produce the following summed values:" in { - val testCase = BinaryVectorizer(sc.parallelize(base.maps), base.properties) - val vectorTwoA = testCase.toBinary(testArrays.twoA) - val vectorTwoB = testCase.toBinary(testArrays.twoB) - - - // Make sure vectors produced are the same size. - vectorTwoA.size should be (vectorTwoB.size) - - // // Test case for checking food value not listed in base.maps. - testCase.toBinary(testArrays.one).toArray.sum should be (1.0) - - // Test cases for making sure indices are preserved. - val sumOne = vecSum(vectorTwoA, vectorTwoB) - - exactly (1, sumOne) should be (2.0) - exactly (2,sumOne) should be (0.0) - exactly (2, sumOne) should be (1.0) - - val sumTwo = vecSum(Vectors.dense(sumOne), testCase.toBinary(testArrays.twoC)) - - exactly (3, sumTwo) should be (1.0) - } - -}
