0.11.0-incubating release
Project: http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/commit/36995dfc Tree: http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/tree/36995dfc Diff: http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/diff/36995dfc Branch: refs/heads/master Commit: 36995dfce7581cb459456858651ebd2f846d62b6 Parents: ae23d8c Author: Donald Szeto <[email protected]> Authored: Thu May 4 11:48:55 2017 -0700 Committer: Donald Szeto <[email protected]> Committed: Thu May 4 11:48:55 2017 -0700 ---------------------------------------------------------------------- README.md | 19 +- build.sbt | 20 +- engine.json | 2 +- project/assembly.sbt | 2 +- project/build.properties | 1 + project/pio-build.sbt | 1 - .../org/example/recommendation/Algorithm.java | 409 +++++++++++++++++++ .../example/recommendation/AlgorithmParams.java | 74 ++++ .../org/example/recommendation/DataSource.java | 150 +++++++ .../recommendation/DataSourceParams.java | 15 + .../java/org/example/recommendation/Item.java | 31 ++ .../org/example/recommendation/ItemScore.java | 34 ++ .../java/org/example/recommendation/Model.java | 84 ++++ .../example/recommendation/PredictedResult.java | 23 ++ .../org/example/recommendation/Preparator.java | 12 + .../example/recommendation/PreparedData.java | 15 + .../java/org/example/recommendation/Query.java | 55 +++ .../recommendation/RecommendationEngine.java | 23 ++ .../org/example/recommendation/Serving.java | 12 + .../example/recommendation/TrainingData.java | 50 +++ .../java/org/example/recommendation/User.java | 30 ++ .../example/recommendation/UserItemEvent.java | 43 ++ .../recommendation/UserItemEventType.java | 5 + .../evaluation/EvaluationParameter.java | 28 ++ .../evaluation/EvaluationSpec.java | 28 ++ .../evaluation/PrecisionMetric.java | 62 +++ .../org/template/recommendation/Algorithm.java | 409 ------------------- .../recommendation/AlgorithmParams.java | 74 ---- .../org/template/recommendation/DataSource.java | 150 ------- .../recommendation/DataSourceParams.java | 15 - .../java/org/template/recommendation/Item.java | 31 -- .../org/template/recommendation/ItemScore.java | 34 -- .../java/org/template/recommendation/Model.java | 84 ---- .../recommendation/PredictedResult.java | 23 -- .../org/template/recommendation/Preparator.java | 12 - .../template/recommendation/PreparedData.java | 15 - .../java/org/template/recommendation/Query.java | 55 --- .../recommendation/RecommendationEngine.java | 23 -- .../org/template/recommendation/Serving.java | 12 - .../template/recommendation/TrainingData.java | 50 --- .../java/org/template/recommendation/User.java | 30 -- .../template/recommendation/UserItemEvent.java | 43 -- .../recommendation/UserItemEventType.java | 5 - .../evaluation/EvaluationParameter.java | 28 -- .../evaluation/EvaluationSpec.java | 28 -- .../evaluation/PrecisionMetric.java | 62 --- template.json | 2 +- 47 files changed, 1206 insertions(+), 1207 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/README.md ---------------------------------------------------------------------- diff --git a/README.md b/README.md index a4ac665..35df67d 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,18 @@ -# E-Commerce Recommendation Template +# E-Commerce Recommendation Template in Java ## Documentation -Please refer to http://docs.prediction.io/templates/javaecommercerecommendation/quickstart/ +Please refer to +http://predictionio.incubator.apache.org/templates/javaecommercerecommendation/quickstart/. ## Versions +### v0.11.0-incubating + +- Update to build with PredictionIO 0.11.0-incubating +- Rename Java package name +- Update SBT and plugin versions + ### v0.1.2 add "org.jblas" dependency in build.sbt @@ -19,13 +26,13 @@ Please refer to http://docs.prediction.io/templates/javaecommercerecommendation/ ## Development Notes -### import sample data +### Import Sample Data ``` $ python data/import_eventserver.py --access_key <your_access_key> ``` -### query +### Query normal: @@ -77,7 +84,7 @@ curl -H "Content-Type: application/json" \ http://localhost:8000/queries.json ``` -### handle new user +### Handle New User new user: @@ -120,7 +127,7 @@ curl -i -X POST http://localhost:7070/events.json?accessKey=$accessKey \ ``` -## handle unavailable items +### Handle Unavailable Items Set the following items as unavailable (need to specify complete list each time when this list is changed): http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/build.sbt ---------------------------------------------------------------------- diff --git a/build.sbt b/build.sbt index 20e1346..36da38f 100644 --- a/build.sbt +++ b/build.sbt @@ -1,16 +1,8 @@ -import AssemblyKeys._ - -assemblySettings - -name := "barebone-template" - -organization := "io.prediction" +name := "template-java-ecom-recommender" libraryDependencies ++= Seq( - "io.prediction" %% "core" % pioVersion.value % "provided", - "org.apache.spark" %% "spark-core" % "1.3.0" % "provided", - "org.apache.spark" %% "spark-mllib" % "1.3.0" % "provided", - "org.scalatest" % "scalatest_2.10" % "2.2.1" % "test", - "com.google.guava" % "guava" % "12.0", - "org.jblas" % "jblas" % "1.2.4" -) + "org.apache.predictionio" %% "apache-predictionio-core" % "0.11.0-incubating" % "provided", + "org.apache.spark" %% "spark-core" % "1.3.0" % "provided", + "org.apache.spark" %% "spark-mllib" % "1.3.0" % "provided", + "com.google.guava" % "guava" % "12.0", + "org.jblas" % "jblas" % "1.2.4") http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/engine.json ---------------------------------------------------------------------- diff --git a/engine.json b/engine.json index 1f8ed0c..0f44544 100644 --- a/engine.json +++ b/engine.json @@ -1,7 +1,7 @@ { "id": "default", "description": "Default settings", - "engineFactory": "org.template.recommendation.RecommendationEngine", + "engineFactory": "org.example.recommendation.RecommendationEngine", "datasource": { "params" : { "appName": "javadase" http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/project/assembly.sbt ---------------------------------------------------------------------- diff --git a/project/assembly.sbt b/project/assembly.sbt index 54c3252..e17409e 100644 --- a/project/assembly.sbt +++ b/project/assembly.sbt @@ -1 +1 @@ -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.4") http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/project/build.properties ---------------------------------------------------------------------- diff --git a/project/build.properties b/project/build.properties new file mode 100644 index 0000000..64317fd --- /dev/null +++ b/project/build.properties @@ -0,0 +1 @@ +sbt.version=0.13.15 http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/project/pio-build.sbt ---------------------------------------------------------------------- diff --git a/project/pio-build.sbt b/project/pio-build.sbt deleted file mode 100644 index 878fc0d..0000000 --- a/project/pio-build.sbt +++ /dev/null @@ -1 +0,0 @@ -addSbtPlugin("io.prediction" % "pio-build" % "0.9.0") \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Algorithm.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/Algorithm.java b/src/main/java/org/example/recommendation/Algorithm.java new file mode 100644 index 0000000..349e945 --- /dev/null +++ b/src/main/java/org/example/recommendation/Algorithm.java @@ -0,0 +1,409 @@ +package org.example.recommendation; + +import com.google.common.collect.Sets; +import org.apache.predictionio.controller.java.PJavaAlgorithm; +import org.apache.predictionio.data.storage.Event; +import org.apache.predictionio.data.store.java.LJavaEventStore; +import org.apache.predictionio.data.store.java.OptionHelper; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +import org.apache.spark.rdd.RDD; +import org.jblas.DoubleMatrix; +import org.joda.time.DateTime; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Option; +import scala.Tuple2; +import scala.concurrent.duration.Duration; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +public class Algorithm extends PJavaAlgorithm<PreparedData, Model, Query, PredictedResult> { + + private static final Logger logger = LoggerFactory.getLogger(Algorithm.class); + private final AlgorithmParams ap; + + public Algorithm(AlgorithmParams ap) { + this.ap = ap; + } + + @Override + public Model train(SparkContext sc, PreparedData preparedData) { + TrainingData data = preparedData.getTrainingData(); + + // user stuff + JavaPairRDD<String, Integer> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() { + @Override + public String call(Tuple2<String, User> idUser) throws Exception { + return idUser._1(); + } + }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() { + @Override + public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception { + return new Tuple2<>(element._1(), element._2().intValue()); + } + }); + final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap(); + + // item stuff + JavaPairRDD<String, Integer> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() { + @Override + public String call(Tuple2<String, Item> idItem) throws Exception { + return idItem._1(); + } + }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() { + @Override + public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception { + return new Tuple2<>(element._1(), element._2().intValue()); + } + }); + final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap(); + JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() { + @Override + public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception { + return element.swap(); + } + }); + final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap(); + + // ratings stuff + JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() { + @Override + public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception { + Integer userIndex = userIndexMap.get(viewEvent.getUser()); + Integer itemIndex = itemIndexMap.get(viewEvent.getItem()); + + return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1); + } + }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() { + @Override + public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { + return (element != null); + } + }).reduceByKey(new Function2<Integer, Integer, Integer>() { + @Override + public Integer call(Integer integer, Integer integer2) throws Exception { + return integer + integer2; + } + }).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() { + @Override + public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception { + return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue()); + } + }); + + if (ratings.isEmpty()) + throw new AssertionError("Please check if your events contain valid user and item ID."); + + // MLlib ALS stuff + MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed()); + JavaPairRDD<Integer, double[]> userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() { + @Override + public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception { + return new Tuple2<>((Integer) element._1(), element._2()); + } + }); + JavaPairRDD<Integer, double[]> productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() { + @Override + public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception { + return new Tuple2<>((Integer) element._1(), element._2()); + } + }); + + // popularity scores + JavaRDD<ItemScore> itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() { + @Override + public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent buyEvent) throws Exception { + Integer userIndex = userIndexMap.get(buyEvent.getUser()); + Integer itemIndex = itemIndexMap.get(buyEvent.getItem()); + + return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1); + } + }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() { + @Override + public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { + return (element != null); + } + }).mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() { + @Override + public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { + return new Tuple2<>(element._1()._2(), element._2()); + } + }).reduceByKey(new Function2<Integer, Integer, Integer>() { + @Override + public Integer call(Integer integer, Integer integer2) throws Exception { + return integer + integer2; + } + }).map(new Function<Tuple2<Integer, Integer>, ItemScore>() { + @Override + public ItemScore call(Tuple2<Integer, Integer> element) throws Exception { + return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue()); + } + }); + + JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD); + + return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap()); + } + + @Override + public PredictedResult predict(Model model, final Query query) { + final JavaPairRDD<String, Integer> matchedUser = model.getUserIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() { + @Override + public Boolean call(Tuple2<String, Integer> userIndex) throws Exception { + return userIndex._1().equals(query.getUserEntityId()); + } + }); + + double[] userFeature = null; + if (!matchedUser.isEmpty()) { + final Integer matchedUserIndex = matchedUser.first()._2(); + userFeature = model.getUserFeatures().filter(new Function<Tuple2<Integer, double[]>, Boolean>() { + @Override + public Boolean call(Tuple2<Integer, double[]> element) throws Exception { + return element._1().equals(matchedUserIndex); + } + }).first()._2(); + } + + if (userFeature != null) { + return new PredictedResult(topItemsForUser(userFeature, model, query)); + } else { + List<double[]> recentProductFeatures = getRecentProductFeatures(query, model); + if (recentProductFeatures.isEmpty()) { + return new PredictedResult(mostPopularItems(model, query)); + } else { + return new PredictedResult(similarItems(recentProductFeatures, model, query)); + } + } + } + + @Override + public RDD<Tuple2<Object, PredictedResult>> batchPredict(Model model, RDD<Tuple2<Object, Query>> qs) { + List<Tuple2<Object, Query>> indexQueries = qs.toJavaRDD().collect(); + List<Tuple2<Object, PredictedResult>> results = new ArrayList<>(); + + for (Tuple2<Object, Query> indexQuery : indexQueries) { + results.add(new Tuple2<>(indexQuery._1(), predict(model, indexQuery._2()))); + } + + return new JavaSparkContext(qs.sparkContext()).parallelize(results).rdd(); + } + + private List<double[]> getRecentProductFeatures(Query query, Model model) { + try { + List<double[]> result = new ArrayList<>(); + + List<Event> events = LJavaEventStore.findByEntity( + ap.getAppName(), + "user", + query.getUserEntityId(), + OptionHelper.<String>none(), + OptionHelper.some(ap.getSimilarItemEvents()), + OptionHelper.some(OptionHelper.some("item")), + OptionHelper.<Option<String>>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<DateTime>none(), + OptionHelper.some(10), + true, + Duration.apply(10, TimeUnit.SECONDS)); + + for (final Event event : events) { + if (event.targetEntityId().isDefined()) { + JavaPairRDD<String, Integer> filtered = model.getItemIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() { + @Override + public Boolean call(Tuple2<String, Integer> element) throws Exception { + return element._1().equals(event.targetEntityId().get()); + } + }); + + final Integer itemIndex = filtered.first()._2(); + + if (!filtered.isEmpty()) { + + JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = model.getIndexItemFeatures().filter(new Function<Tuple2<Integer, Tuple2<String, double[]>>, Boolean>() { + @Override + public Boolean call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { + return itemIndex.equals(element._1()); + } + }); + + List<Tuple2<Integer, Tuple2<String, double[]>>> oneIndexItemFeatures = indexItemFeatures.collect(); + if (oneIndexItemFeatures.size() > 0) { + result.add(oneIndexItemFeatures.get(0)._2()._2()); + } + } + } + } + + return result; + } catch (Exception e) { + logger.error("Error reading recent events for user " + query.getUserEntityId()); + throw new RuntimeException(e.getMessage(), e); + } + } + + private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) { + final DoubleMatrix userMatrix = new DoubleMatrix(userFeature); + + JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() { + @Override + public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { + return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2()))); + } + }); + + itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); + return sortAndTake(itemScores, query.getNumber()); + } + + private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) { + JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() { + @Override + public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { + double similarity = 0.0; + for (double[] recentFeature : recentProductFeatures) { + similarity += cosineSimilarity(element._2()._2(), recentFeature); + } + + return new ItemScore(element._2()._1(), similarity); + } + }); + + itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); + return sortAndTake(itemScores, query.getNumber()); + } + + private List<ItemScore> mostPopularItems(Model model, Query query) { + JavaRDD<ItemScore> itemScores = validScores(model.getItemPopularityScore(), query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); + return sortAndTake(itemScores, query.getNumber()); + } + + private double cosineSimilarity(double[] a, double[] b) { + DoubleMatrix matrixA = new DoubleMatrix(a); + DoubleMatrix matrixB = new DoubleMatrix(b); + + return matrixA.dot(matrixB) / (matrixA.norm2() * matrixB.norm2()); + } + + private List<ItemScore> sortAndTake(JavaRDD<ItemScore> all, int number) { + return all.sortBy(new Function<ItemScore, Double>() { + @Override + public Double call(ItemScore itemScore) throws Exception { + return itemScore.getScore(); + } + }, false, all.partitions().size()).take(number); + } + + private JavaRDD<ItemScore> validScores(JavaRDD<ItemScore> all, final Set<String> whitelist, final Set<String> blacklist, final Set<String> categories, final Map<String, Item> items, String userEntityId) { + final Set<String> seenItemEntityIds = seenItemEntityIds(userEntityId); + final Set<String> unavailableItemEntityIds = unavailableItemEntityIds(); + + return all.filter(new Function<ItemScore, Boolean>() { + @Override + public Boolean call(ItemScore itemScore) throws Exception { + Item item = items.get(itemScore.getItemEntityId()); + + return (item != null + && passWhitelistCriteria(whitelist, item.getEntityId()) + && passBlacklistCriteria(blacklist, item.getEntityId()) + && passCategoryCriteria(categories, item) + && passUnseenCriteria(seenItemEntityIds, item.getEntityId()) + && passAvailabilityCriteria(unavailableItemEntityIds, item.getEntityId())); + } + }); + } + + private boolean passWhitelistCriteria(Set<String> whitelist, String itemEntityId) { + return (whitelist.isEmpty() || whitelist.contains(itemEntityId)); + } + + private boolean passBlacklistCriteria(Set<String> blacklist, String itemEntityId) { + return !blacklist.contains(itemEntityId); + } + + private boolean passCategoryCriteria(Set<String> categories, Item item) { + return (categories.isEmpty() || Sets.intersection(categories, item.getCategories()).size() > 0); + } + + private boolean passUnseenCriteria(Set<String> seen, String itemEntityId) { + return !seen.contains(itemEntityId); + } + + private boolean passAvailabilityCriteria(Set<String> unavailableItemEntityIds, String entityId) { + return !unavailableItemEntityIds.contains(entityId); + } + + private Set<String> unavailableItemEntityIds() { + try { + List<Event> unavailableConstraintEvents = LJavaEventStore.findByEntity( + ap.getAppName(), + "constraint", + "unavailableItems", + OptionHelper.<String>none(), + OptionHelper.some(Collections.singletonList("$set")), + OptionHelper.<Option<String>>none(), + OptionHelper.<Option<String>>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<DateTime>none(), + OptionHelper.some(1), + true, + Duration.apply(10, TimeUnit.SECONDS)); + + if (unavailableConstraintEvents.isEmpty()) return Collections.emptySet(); + + Event unavailableConstraint = unavailableConstraintEvents.get(0); + + List<String> unavailableItems = unavailableConstraint.properties().getStringList("items"); + + return new HashSet<>(unavailableItems); + } catch (Exception e) { + logger.error("Error reading constraint events"); + throw new RuntimeException(e.getMessage(), e); + } + } + + private Set<String> seenItemEntityIds(String userEntityId) { + if (!ap.isUnseenOnly()) return Collections.emptySet(); + + try { + Set<String> result = new HashSet<>(); + List<Event> seenEvents = LJavaEventStore.findByEntity( + ap.getAppName(), + "user", + userEntityId, + OptionHelper.<String>none(), + OptionHelper.some(ap.getSeenItemEvents()), + OptionHelper.some(OptionHelper.some("item")), + OptionHelper.<Option<String>>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<Integer>none(), + true, + Duration.apply(10, TimeUnit.SECONDS)); + + for (Event event : seenEvents) { + result.add(event.targetEntityId().get()); + } + + return result; + } catch (Exception e) { + logger.error("Error reading seen events for user " + userEntityId); + throw new RuntimeException(e.getMessage(), e); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/AlgorithmParams.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/AlgorithmParams.java b/src/main/java/org/example/recommendation/AlgorithmParams.java new file mode 100644 index 0000000..4b9c7ed --- /dev/null +++ b/src/main/java/org/example/recommendation/AlgorithmParams.java @@ -0,0 +1,74 @@ +package org.example.recommendation; + +import org.apache.predictionio.controller.Params; + +import java.util.List; + +public class AlgorithmParams implements Params{ + private final long seed; + private final int rank; + private final int iteration; + private final double lambda; + private final String appName; + private final List<String> similarItemEvents; + private final boolean unseenOnly; + private final List<String> seenItemEvents; + + + public AlgorithmParams(long seed, int rank, int iteration, double lambda, String appName, List<String> similarItemEvents, boolean unseenOnly, List<String> seenItemEvents) { + this.seed = seed; + this.rank = rank; + this.iteration = iteration; + this.lambda = lambda; + this.appName = appName; + this.similarItemEvents = similarItemEvents; + this.unseenOnly = unseenOnly; + this.seenItemEvents = seenItemEvents; + } + + public long getSeed() { + return seed; + } + + public int getRank() { + return rank; + } + + public int getIteration() { + return iteration; + } + + public double getLambda() { + return lambda; + } + + public String getAppName() { + return appName; + } + + public List<String> getSimilarItemEvents() { + return similarItemEvents; + } + + public boolean isUnseenOnly() { + return unseenOnly; + } + + public List<String> getSeenItemEvents() { + return seenItemEvents; + } + + @Override + public String toString() { + return "AlgorithmParams{" + + "seed=" + seed + + ", rank=" + rank + + ", iteration=" + iteration + + ", lambda=" + lambda + + ", appName='" + appName + '\'' + + ", similarItemEvents=" + similarItemEvents + + ", unseenOnly=" + unseenOnly + + ", seenItemEvents=" + seenItemEvents + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/DataSource.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/DataSource.java b/src/main/java/org/example/recommendation/DataSource.java new file mode 100644 index 0000000..90ac975 --- /dev/null +++ b/src/main/java/org/example/recommendation/DataSource.java @@ -0,0 +1,150 @@ +package org.example.recommendation; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.predictionio.controller.EmptyParams; +import org.apache.predictionio.controller.java.PJavaDataSource; +import org.apache.predictionio.data.storage.Event; +import org.apache.predictionio.data.storage.PropertyMap; +import org.apache.predictionio.data.store.java.OptionHelper; +import org.apache.predictionio.data.store.java.PJavaEventStore; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.rdd.RDD; +import org.joda.time.DateTime; +import scala.Option; +import scala.Tuple2; +import scala.Tuple3; +import scala.collection.JavaConversions; +import scala.collection.JavaConversions$; +import scala.collection.Seq; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class DataSource extends PJavaDataSource<TrainingData, EmptyParams, Query, Set<String>> { + + private final DataSourceParams dsp; + + public DataSource(DataSourceParams dsp) { + this.dsp = dsp; + } + + @Override + public TrainingData readTraining(SparkContext sc) { + JavaPairRDD<String,User> usersRDD = PJavaEventStore.aggregateProperties( + dsp.getAppName(), + "user", + OptionHelper.<String>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<List<String>>none(), + sc) + .mapToPair(new PairFunction<Tuple2<String, PropertyMap>, String, User>() { + @Override + public Tuple2<String, User> call(Tuple2<String, PropertyMap> entityIdProperty) throws Exception { + Set<String> keys = JavaConversions$.MODULE$.setAsJavaSet(entityIdProperty._2().keySet()); + Map<String, String> properties = new HashMap<>(); + for (String key : keys) { + properties.put(key, entityIdProperty._2().get(key, String.class)); + } + + User user = new User(entityIdProperty._1(), ImmutableMap.copyOf(properties)); + + return new Tuple2<>(user.getEntityId(), user); + } + }); + + JavaPairRDD<String, Item> itemsRDD = PJavaEventStore.aggregateProperties( + dsp.getAppName(), + "item", + OptionHelper.<String>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<List<String>>none(), + sc) + .mapToPair(new PairFunction<Tuple2<String, PropertyMap>, String, Item>() { + @Override + public Tuple2<String, Item> call(Tuple2<String, PropertyMap> entityIdProperty) throws Exception { + List<String> categories = entityIdProperty._2().getStringList("categories"); + Item item = new Item(entityIdProperty._1(), ImmutableSet.copyOf(categories)); + + return new Tuple2<>(item.getEntityId(), item); + } + }); + + JavaRDD<UserItemEvent> viewEventsRDD = PJavaEventStore.find( + dsp.getAppName(), + OptionHelper.<String>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<DateTime>none(), + OptionHelper.some("user"), + OptionHelper.<String>none(), + OptionHelper.some(Collections.singletonList("view")), + OptionHelper.<Option<String>>none(), + OptionHelper.<Option<String>>none(), + sc) + .map(new Function<Event, UserItemEvent>() { + @Override + public UserItemEvent call(Event event) throws Exception { + return new UserItemEvent(event.entityId(), event.targetEntityId().get(), event.eventTime().getMillis(), UserItemEventType.VIEW); + } + }); + + JavaRDD<UserItemEvent> buyEventsRDD = PJavaEventStore.find( + dsp.getAppName(), + OptionHelper.<String>none(), + OptionHelper.<DateTime>none(), + OptionHelper.<DateTime>none(), + OptionHelper.some("user"), + OptionHelper.<String>none(), + OptionHelper.some(Collections.singletonList("buy")), + OptionHelper.<Option<String>>none(), + OptionHelper.<Option<String>>none(), + sc) + .map(new Function<Event, UserItemEvent>() { + @Override + public UserItemEvent call(Event event) throws Exception { + return new UserItemEvent(event.entityId(), event.targetEntityId().get(), event.eventTime().getMillis(), UserItemEventType.BUY); + } + }); + + return new TrainingData(usersRDD, itemsRDD, viewEventsRDD, buyEventsRDD); + } + + @Override + public Seq<Tuple3<TrainingData, EmptyParams, RDD<Tuple2<Query, Set<String>>>>> readEval(SparkContext sc) { + TrainingData all = readTraining(sc); + double[] split = {0.5, 0.5}; + JavaRDD<UserItemEvent>[] trainingAndTestingViews = all.getViewEvents().randomSplit(split, 1); + JavaRDD<UserItemEvent>[] trainingAndTestingBuys = all.getBuyEvents().randomSplit(split, 1); + + RDD<Tuple2<Query, Set<String>>> queryActual = JavaPairRDD.toRDD(trainingAndTestingViews[1].union(trainingAndTestingBuys[1]).groupBy(new Function<UserItemEvent, String>() { + @Override + public String call(UserItemEvent event) throws Exception { + return event.getUser(); + } + }).mapToPair(new PairFunction<Tuple2<String, Iterable<UserItemEvent>>, Query, Set<String>>() { + @Override + public Tuple2<Query, Set<String>> call(Tuple2<String, Iterable<UserItemEvent>> userEvents) throws Exception { + Query query = new Query(userEvents._1(), 10, Collections.<String>emptySet(), Collections.<String>emptySet(), Collections.<String>emptySet()); + Set<String> actualSet = new HashSet<>(); + for (UserItemEvent event : userEvents._2()) { + actualSet.add(event.getItem()); + } + return new Tuple2<>(query, actualSet); + } + })); + + Tuple3<TrainingData, EmptyParams, RDD<Tuple2<Query, Set<String>>>> setData = new Tuple3<>(new TrainingData(all.getUsers(), all.getItems(), trainingAndTestingViews[0], trainingAndTestingBuys[0]), new EmptyParams(), queryActual); + + return JavaConversions.asScalaIterable(Collections.singletonList(setData)).toSeq(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/DataSourceParams.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/DataSourceParams.java b/src/main/java/org/example/recommendation/DataSourceParams.java new file mode 100644 index 0000000..4651b92 --- /dev/null +++ b/src/main/java/org/example/recommendation/DataSourceParams.java @@ -0,0 +1,15 @@ +package org.example.recommendation; + +import org.apache.predictionio.controller.Params; + +public class DataSourceParams implements Params{ + private final String appName; + + public DataSourceParams(String appName) { + this.appName = appName; + } + + public String getAppName() { + return appName; + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Item.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/Item.java b/src/main/java/org/example/recommendation/Item.java new file mode 100644 index 0000000..2159beb --- /dev/null +++ b/src/main/java/org/example/recommendation/Item.java @@ -0,0 +1,31 @@ +package org.example.recommendation; + +import java.io.Serializable; +import java.util.Set; + +public class Item implements Serializable{ + private final Set<String> categories; + private final String entityId; + + public Item(String entityId, Set<String> categories) { + this.categories = categories; + this.entityId = entityId; + } + + public String getEntityId() { + return entityId; + } + + public Set<String> getCategories() { + return categories; + } + + @Override + public String toString() { + return "Item{" + + "categories=" + categories + + ", entityId='" + entityId + '\'' + + '}'; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/ItemScore.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/ItemScore.java b/src/main/java/org/example/recommendation/ItemScore.java new file mode 100644 index 0000000..23c3fdb --- /dev/null +++ b/src/main/java/org/example/recommendation/ItemScore.java @@ -0,0 +1,34 @@ +package org.example.recommendation; + +import java.io.Serializable; + +public class ItemScore implements Serializable, Comparable<ItemScore> { + private final String itemEntityId; + private final double score; + + public ItemScore(String itemEntityId, double score) { + this.itemEntityId = itemEntityId; + this.score = score; + } + + public String getItemEntityId() { + return itemEntityId; + } + + public double getScore() { + return score; + } + + @Override + public String toString() { + return "ItemScore{" + + "itemEntityId='" + itemEntityId + '\'' + + ", score=" + score + + '}'; + } + + @Override + public int compareTo(ItemScore o) { + return Double.valueOf(score).compareTo(o.score); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Model.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/Model.java b/src/main/java/org/example/recommendation/Model.java new file mode 100644 index 0000000..ebf42e5 --- /dev/null +++ b/src/main/java/org/example/recommendation/Model.java @@ -0,0 +1,84 @@ +package org.example.recommendation; + +import org.apache.predictionio.controller.Params; +import org.apache.predictionio.controller.PersistentModel; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Tuple2; + +import java.io.Serializable; +import java.util.Collections; +import java.util.Map; + +public class Model implements Serializable, PersistentModel<AlgorithmParams> { + private static final Logger logger = LoggerFactory.getLogger(Model.class); + private final JavaPairRDD<Integer, double[]> userFeatures; + private final JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures; + private final JavaPairRDD<String, Integer> userIndex; + private final JavaPairRDD<String, Integer> itemIndex; + private final JavaRDD<ItemScore> itemPopularityScore; + private final Map<String, Item> items; + + public Model(JavaPairRDD<Integer, double[]> userFeatures, JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures, JavaPairRDD<String, Integer> userIndex, JavaPairRDD<String, Integer> itemIndex, JavaRDD<ItemScore> itemPopularityScore, Map<String, Item> items) { + this.userFeatures = userFeatures; + this.indexItemFeatures = indexItemFeatures; + this.userIndex = userIndex; + this.itemIndex = itemIndex; + this.itemPopularityScore = itemPopularityScore; + this.items = items; + } + + public JavaPairRDD<Integer, double[]> getUserFeatures() { + return userFeatures; + } + + public JavaPairRDD<Integer, Tuple2<String, double[]>> getIndexItemFeatures() { + return indexItemFeatures; + } + + public JavaPairRDD<String, Integer> getUserIndex() { + return userIndex; + } + + public JavaPairRDD<String, Integer> getItemIndex() { + return itemIndex; + } + + public JavaRDD<ItemScore> getItemPopularityScore() { + return itemPopularityScore; + } + + public Map<String, Item> getItems() { + return items; + } + + @Override + public boolean save(String id, AlgorithmParams params, SparkContext sc) { + userFeatures.saveAsObjectFile("/tmp/" + id + "/userFeatures"); + indexItemFeatures.saveAsObjectFile("/tmp/" + id + "/indexItemFeatures"); + userIndex.saveAsObjectFile("/tmp/" + id + "/userIndex"); + itemIndex.saveAsObjectFile("/tmp/" + id + "/itemIndex"); + itemPopularityScore.saveAsObjectFile("/tmp/" + id + "/itemPopularityScore"); + new JavaSparkContext(sc).parallelize(Collections.singletonList(items)).saveAsObjectFile("/tmp/" + id + "/items"); + + logger.info("Saved model to /tmp/" + id); + return true; + } + + public static Model load(String id, Params params, SparkContext sc) { + JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); + JavaPairRDD<Integer, double[]> userFeatures = JavaPairRDD.<Integer, double[]>fromJavaRDD(jsc.<Tuple2<Integer, double[]>>objectFile("/tmp/" + id + "/userFeatures")); + JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = JavaPairRDD.<Integer, Tuple2<String, double[]>>fromJavaRDD(jsc.<Tuple2<Integer, Tuple2<String, double[]>>>objectFile("/tmp/" + id + "/indexItemFeatures")); + JavaPairRDD<String, Integer> userIndex = JavaPairRDD.<String, Integer>fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/userIndex")); + JavaPairRDD<String, Integer> itemIndex = JavaPairRDD.<String, Integer>fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/itemIndex")); + JavaRDD<ItemScore> itemPopularityScore = jsc.objectFile("/tmp/" + id + "/itemPopularityScore"); + Map<String, Item> items = jsc.<Map<String, Item>>objectFile("/tmp/" + id + "/items").collect().get(0); + + logger.info("loaded model"); + return new Model(userFeatures, indexItemFeatures, userIndex, itemIndex, itemPopularityScore, items); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/PredictedResult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/PredictedResult.java b/src/main/java/org/example/recommendation/PredictedResult.java new file mode 100644 index 0000000..54d7ade --- /dev/null +++ b/src/main/java/org/example/recommendation/PredictedResult.java @@ -0,0 +1,23 @@ +package org.example.recommendation; + +import java.io.Serializable; +import java.util.List; + +public class PredictedResult implements Serializable{ + private final List<ItemScore> itemScores; + + public PredictedResult(List<ItemScore> itemScores) { + this.itemScores = itemScores; + } + + public List<ItemScore> getItemScores() { + return itemScores; + } + + @Override + public String toString() { + return "PredictedResult{" + + "itemScores=" + itemScores + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Preparator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/Preparator.java b/src/main/java/org/example/recommendation/Preparator.java new file mode 100644 index 0000000..33beb50 --- /dev/null +++ b/src/main/java/org/example/recommendation/Preparator.java @@ -0,0 +1,12 @@ +package org.example.recommendation; + +import org.apache.predictionio.controller.java.PJavaPreparator; +import org.apache.spark.SparkContext; + +public class Preparator extends PJavaPreparator<TrainingData, PreparedData> { + + @Override + public PreparedData prepare(SparkContext sc, TrainingData trainingData) { + return new PreparedData(trainingData); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/PreparedData.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/PreparedData.java b/src/main/java/org/example/recommendation/PreparedData.java new file mode 100644 index 0000000..802b7f2 --- /dev/null +++ b/src/main/java/org/example/recommendation/PreparedData.java @@ -0,0 +1,15 @@ +package org.example.recommendation; + +import java.io.Serializable; + +public class PreparedData implements Serializable { + private final TrainingData trainingData; + + public PreparedData(TrainingData trainingData) { + this.trainingData = trainingData; + } + + public TrainingData getTrainingData() { + return trainingData; + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Query.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/Query.java b/src/main/java/org/example/recommendation/Query.java new file mode 100644 index 0000000..977f566 --- /dev/null +++ b/src/main/java/org/example/recommendation/Query.java @@ -0,0 +1,55 @@ +package org.example.recommendation; + +import java.io.Serializable; +import java.util.Collections; +import java.util.Set; + +public class Query implements Serializable{ + private final String userEntityId; + private final int number; + private final Set<String> categories; + private final Set<String> whitelist; + private final Set<String> blacklist; + + public Query(String userEntityId, int number, Set<String> categories, Set<String> whitelist, Set<String> blacklist) { + this.userEntityId = userEntityId; + this.number = number; + this.categories = categories; + this.whitelist = whitelist; + this.blacklist = blacklist; + } + + public String getUserEntityId() { + return userEntityId; + } + + public int getNumber() { + return number; + } + + public Set<String> getCategories() { + if (categories == null) return Collections.emptySet(); + return categories; + } + + public Set<String> getWhitelist() { + if (whitelist == null) return Collections.emptySet(); + return whitelist; + } + + public Set<String> getBlacklist() { + if (blacklist == null) return Collections.emptySet(); + return blacklist; + } + + @Override + public String toString() { + return "Query{" + + "userEntityId='" + userEntityId + '\'' + + ", number=" + number + + ", categories=" + categories + + ", whitelist=" + whitelist + + ", blacklist=" + blacklist + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/RecommendationEngine.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/RecommendationEngine.java b/src/main/java/org/example/recommendation/RecommendationEngine.java new file mode 100644 index 0000000..ead9aa7 --- /dev/null +++ b/src/main/java/org/example/recommendation/RecommendationEngine.java @@ -0,0 +1,23 @@ +package org.example.recommendation; + +import org.apache.predictionio.controller.EmptyParams; +import org.apache.predictionio.controller.Engine; +import org.apache.predictionio.controller.EngineFactory; +import org.apache.predictionio.core.BaseAlgorithm; +import org.apache.predictionio.core.BaseEngine; + +import java.util.Collections; +import java.util.Set; + +public class RecommendationEngine extends EngineFactory { + + @Override + public BaseEngine<EmptyParams, Query, PredictedResult, Set<String>> apply() { + return new Engine<>( + DataSource.class, + Preparator.class, + Collections.<String, Class<? extends BaseAlgorithm<PreparedData, ?, Query, PredictedResult>>>singletonMap("algo", Algorithm.class), + Serving.class + ); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Serving.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/Serving.java b/src/main/java/org/example/recommendation/Serving.java new file mode 100644 index 0000000..80d6c83 --- /dev/null +++ b/src/main/java/org/example/recommendation/Serving.java @@ -0,0 +1,12 @@ +package org.example.recommendation; + +import org.apache.predictionio.controller.java.LJavaServing; +import scala.collection.Seq; + +public class Serving extends LJavaServing<Query, PredictedResult> { + + @Override + public PredictedResult serve(Query query, Seq<PredictedResult> predictions) { + return predictions.head(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/TrainingData.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/TrainingData.java b/src/main/java/org/example/recommendation/TrainingData.java new file mode 100644 index 0000000..35af8a0 --- /dev/null +++ b/src/main/java/org/example/recommendation/TrainingData.java @@ -0,0 +1,50 @@ +package org.example.recommendation; + +import org.apache.predictionio.controller.SanityCheck; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; + +import java.io.Serializable; + +public class TrainingData implements Serializable, SanityCheck { + private final JavaPairRDD<String, User> users; + private final JavaPairRDD<String, Item> items; + private final JavaRDD<UserItemEvent> viewEvents; + private final JavaRDD<UserItemEvent> buyEvents; + + public TrainingData(JavaPairRDD<String, User> users, JavaPairRDD<String, Item> items, JavaRDD<UserItemEvent> viewEvents, JavaRDD<UserItemEvent> buyEvents) { + this.users = users; + this.items = items; + this.viewEvents = viewEvents; + this.buyEvents = buyEvents; + } + + public JavaPairRDD<String, User> getUsers() { + return users; + } + + public JavaPairRDD<String, Item> getItems() { + return items; + } + + public JavaRDD<UserItemEvent> getViewEvents() { + return viewEvents; + } + + public JavaRDD<UserItemEvent> getBuyEvents() { + return buyEvents; + } + + @Override + public void sanityCheck() { + if (users.isEmpty()) { + throw new AssertionError("User data is empty"); + } + if (items.isEmpty()) { + throw new AssertionError("Item data is empty"); + } + if (viewEvents.isEmpty()) { + throw new AssertionError("View Event data is empty"); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/User.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/User.java b/src/main/java/org/example/recommendation/User.java new file mode 100644 index 0000000..d187a20 --- /dev/null +++ b/src/main/java/org/example/recommendation/User.java @@ -0,0 +1,30 @@ +package org.example.recommendation; + +import java.io.Serializable; +import java.util.Map; + +public class User implements Serializable { + private final String entityId; + private final Map<String, String> properties; + + public User(String entityId, Map<String, String> properties) { + this.entityId = entityId; + this.properties = properties; + } + + public String getEntityId() { + return entityId; + } + + public Map<String, String> getProperties() { + return properties; + } + + @Override + public String toString() { + return "User{" + + "entityId='" + entityId + '\'' + + ", properties=" + properties + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/UserItemEvent.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/UserItemEvent.java b/src/main/java/org/example/recommendation/UserItemEvent.java new file mode 100644 index 0000000..d548a18 --- /dev/null +++ b/src/main/java/org/example/recommendation/UserItemEvent.java @@ -0,0 +1,43 @@ +package org.example.recommendation; + +import java.io.Serializable; + +public class UserItemEvent implements Serializable { + private final String user; + private final String item; + private final long time; + private final UserItemEventType type; + + public UserItemEvent(String user, String item, long time, UserItemEventType type) { + this.user = user; + this.item = item; + this.time = time; + this.type = type; + } + + public String getUser() { + return user; + } + + public String getItem() { + return item; + } + + public long getTime() { + return time; + } + + public UserItemEventType getType() { + return type; + } + + @Override + public String toString() { + return "UserItemEvent{" + + "user='" + user + '\'' + + ", item='" + item + '\'' + + ", time=" + time + + ", type=" + type + + '}'; + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/UserItemEventType.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/UserItemEventType.java b/src/main/java/org/example/recommendation/UserItemEventType.java new file mode 100644 index 0000000..f86b411 --- /dev/null +++ b/src/main/java/org/example/recommendation/UserItemEventType.java @@ -0,0 +1,5 @@ +package org.example.recommendation; + +public enum UserItemEventType { + VIEW, BUY +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java b/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java new file mode 100644 index 0000000..33028eb --- /dev/null +++ b/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java @@ -0,0 +1,28 @@ +package org.example.recommendation.evaluation; + +import org.apache.predictionio.controller.EmptyParams; +import org.apache.predictionio.controller.EngineParams; +import org.apache.predictionio.controller.java.JavaEngineParamsGenerator; +import org.example.recommendation.AlgorithmParams; +import org.example.recommendation.DataSourceParams; + +import java.util.Arrays; +import java.util.Collections; + +public class EvaluationParameter extends JavaEngineParamsGenerator { + public EvaluationParameter() { + this.setEngineParamsList( + Collections.singletonList( + new EngineParams( + "", + new DataSourceParams("javadase"), + "", + new EmptyParams(), + Collections.singletonMap("algo", new AlgorithmParams(1, 10, 10, 0.01, "javadase", Collections.singletonList("view"), true, Arrays.asList("buy", "view"))), + "", + new EmptyParams() + ) + ) + ); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java b/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java new file mode 100644 index 0000000..2bafc7b --- /dev/null +++ b/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java @@ -0,0 +1,28 @@ +package org.example.recommendation.evaluation; + +import org.apache.predictionio.controller.Engine; +import org.apache.predictionio.controller.java.JavaEvaluation; +import org.apache.predictionio.core.BaseAlgorithm; +import org.example.recommendation.Algorithm; +import org.example.recommendation.DataSource; +import org.example.recommendation.PredictedResult; +import org.example.recommendation.Preparator; +import org.example.recommendation.PreparedData; +import org.example.recommendation.Query; +import org.example.recommendation.Serving; + +import java.util.Collections; + +public class EvaluationSpec extends JavaEvaluation { + public EvaluationSpec() { + this.setEngineMetric( + new Engine<>( + DataSource.class, + Preparator.class, + Collections.<String, Class<? extends BaseAlgorithm<PreparedData, ?, Query, PredictedResult>>>singletonMap("algo", Algorithm.class), + Serving.class + ), + new PrecisionMetric() + ); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java b/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java new file mode 100644 index 0000000..e412fd5 --- /dev/null +++ b/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java @@ -0,0 +1,62 @@ +package org.example.recommendation.evaluation; + +import org.apache.predictionio.controller.EmptyParams; +import org.apache.predictionio.controller.Metric; +import org.apache.predictionio.controller.java.SerializableComparator; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.rdd.RDD; +import org.example.recommendation.ItemScore; +import org.example.recommendation.PredictedResult; +import org.example.recommendation.Query; +import scala.Tuple2; +import scala.Tuple3; +import scala.collection.JavaConversions; +import scala.collection.Seq; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class PrecisionMetric extends Metric<EmptyParams, Query, PredictedResult, Set<String>, Double> { + + private static final class MetricComparator implements SerializableComparator<Double> { + @Override + public int compare(Double o1, Double o2) { + return o1.compareTo(o2); + } + } + + public PrecisionMetric() { + super(new MetricComparator()); + } + + @Override + public Double calculate(SparkContext sc, Seq<Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>>> qpas) { + List<Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>>> sets = JavaConversions.asJavaList(qpas); + List<Double> allSetResults = new ArrayList<>(); + + for (Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>> set : sets) { + List<Double> setResults = set._2().toJavaRDD().map(new Function<Tuple3<Query, PredictedResult, Set<String>>, Double>() { + @Override + public Double call(Tuple3<Query, PredictedResult, Set<String>> qpa) throws Exception { + Set<String> predicted = new HashSet<>(); + for (ItemScore itemScore : qpa._2().getItemScores()) { + predicted.add(itemScore.getItemEntityId()); + } + Set<String> intersection = new HashSet<>(predicted); + intersection.retainAll(qpa._3()); + + return 1.0 * intersection.size() / qpa._2().getItemScores().size(); + } + }).collect(); + + allSetResults.addAll(setResults); + } + double sum = 0.0; + for (Double value : allSetResults) sum += value; + + return sum / allSetResults.size(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/template/recommendation/Algorithm.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/template/recommendation/Algorithm.java b/src/main/java/org/template/recommendation/Algorithm.java deleted file mode 100644 index 24b4e5c..0000000 --- a/src/main/java/org/template/recommendation/Algorithm.java +++ /dev/null @@ -1,409 +0,0 @@ -package org.template.recommendation; - -import com.google.common.collect.Sets; -import io.prediction.controller.java.PJavaAlgorithm; -import io.prediction.data.storage.Event; -import io.prediction.data.store.java.LJavaEventStore; -import io.prediction.data.store.java.OptionHelper; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.mllib.recommendation.Rating; -import org.apache.spark.rdd.RDD; -import org.jblas.DoubleMatrix; -import org.joda.time.DateTime; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import scala.Option; -import scala.Tuple2; -import scala.concurrent.duration.Duration; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; - -public class Algorithm extends PJavaAlgorithm<PreparedData, Model, Query, PredictedResult> { - - private static final Logger logger = LoggerFactory.getLogger(Algorithm.class); - private final AlgorithmParams ap; - - public Algorithm(AlgorithmParams ap) { - this.ap = ap; - } - - @Override - public Model train(SparkContext sc, PreparedData preparedData) { - TrainingData data = preparedData.getTrainingData(); - - // user stuff - JavaPairRDD<String, Integer> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() { - @Override - public String call(Tuple2<String, User> idUser) throws Exception { - return idUser._1(); - } - }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() { - @Override - public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception { - return new Tuple2<>(element._1(), element._2().intValue()); - } - }); - final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap(); - - // item stuff - JavaPairRDD<String, Integer> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() { - @Override - public String call(Tuple2<String, Item> idItem) throws Exception { - return idItem._1(); - } - }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() { - @Override - public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception { - return new Tuple2<>(element._1(), element._2().intValue()); - } - }); - final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap(); - JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() { - @Override - public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception { - return element.swap(); - } - }); - final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap(); - - // ratings stuff - JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() { - @Override - public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception { - Integer userIndex = userIndexMap.get(viewEvent.getUser()); - Integer itemIndex = itemIndexMap.get(viewEvent.getItem()); - - return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1); - } - }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() { - @Override - public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { - return (element != null); - } - }).reduceByKey(new Function2<Integer, Integer, Integer>() { - @Override - public Integer call(Integer integer, Integer integer2) throws Exception { - return integer + integer2; - } - }).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() { - @Override - public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception { - return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue()); - } - }); - - if (ratings.isEmpty()) - throw new AssertionError("Please check if your events contain valid user and item ID."); - - // MLlib ALS stuff - MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed()); - JavaPairRDD<Integer, double[]> userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() { - @Override - public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception { - return new Tuple2<>((Integer) element._1(), element._2()); - } - }); - JavaPairRDD<Integer, double[]> productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() { - @Override - public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception { - return new Tuple2<>((Integer) element._1(), element._2()); - } - }); - - // popularity scores - JavaRDD<ItemScore> itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() { - @Override - public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent buyEvent) throws Exception { - Integer userIndex = userIndexMap.get(buyEvent.getUser()); - Integer itemIndex = itemIndexMap.get(buyEvent.getItem()); - - return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1); - } - }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() { - @Override - public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { - return (element != null); - } - }).mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() { - @Override - public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { - return new Tuple2<>(element._1()._2(), element._2()); - } - }).reduceByKey(new Function2<Integer, Integer, Integer>() { - @Override - public Integer call(Integer integer, Integer integer2) throws Exception { - return integer + integer2; - } - }).map(new Function<Tuple2<Integer, Integer>, ItemScore>() { - @Override - public ItemScore call(Tuple2<Integer, Integer> element) throws Exception { - return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue()); - } - }); - - JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD); - - return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap()); - } - - @Override - public PredictedResult predict(Model model, final Query query) { - final JavaPairRDD<String, Integer> matchedUser = model.getUserIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() { - @Override - public Boolean call(Tuple2<String, Integer> userIndex) throws Exception { - return userIndex._1().equals(query.getUserEntityId()); - } - }); - - double[] userFeature = null; - if (!matchedUser.isEmpty()) { - final Integer matchedUserIndex = matchedUser.first()._2(); - userFeature = model.getUserFeatures().filter(new Function<Tuple2<Integer, double[]>, Boolean>() { - @Override - public Boolean call(Tuple2<Integer, double[]> element) throws Exception { - return element._1().equals(matchedUserIndex); - } - }).first()._2(); - } - - if (userFeature != null) { - return new PredictedResult(topItemsForUser(userFeature, model, query)); - } else { - List<double[]> recentProductFeatures = getRecentProductFeatures(query, model); - if (recentProductFeatures.isEmpty()) { - return new PredictedResult(mostPopularItems(model, query)); - } else { - return new PredictedResult(similarItems(recentProductFeatures, model, query)); - } - } - } - - @Override - public RDD<Tuple2<Object, PredictedResult>> batchPredict(Model model, RDD<Tuple2<Object, Query>> qs) { - List<Tuple2<Object, Query>> indexQueries = qs.toJavaRDD().collect(); - List<Tuple2<Object, PredictedResult>> results = new ArrayList<>(); - - for (Tuple2<Object, Query> indexQuery : indexQueries) { - results.add(new Tuple2<>(indexQuery._1(), predict(model, indexQuery._2()))); - } - - return new JavaSparkContext(qs.sparkContext()).parallelize(results).rdd(); - } - - private List<double[]> getRecentProductFeatures(Query query, Model model) { - try { - List<double[]> result = new ArrayList<>(); - - List<Event> events = LJavaEventStore.findByEntity( - ap.getAppName(), - "user", - query.getUserEntityId(), - OptionHelper.<String>none(), - OptionHelper.some(ap.getSimilarItemEvents()), - OptionHelper.some(OptionHelper.some("item")), - OptionHelper.<Option<String>>none(), - OptionHelper.<DateTime>none(), - OptionHelper.<DateTime>none(), - OptionHelper.some(10), - true, - Duration.apply(10, TimeUnit.SECONDS)); - - for (final Event event : events) { - if (event.targetEntityId().isDefined()) { - JavaPairRDD<String, Integer> filtered = model.getItemIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() { - @Override - public Boolean call(Tuple2<String, Integer> element) throws Exception { - return element._1().equals(event.targetEntityId().get()); - } - }); - - final Integer itemIndex = filtered.first()._2(); - - if (!filtered.isEmpty()) { - - JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = model.getIndexItemFeatures().filter(new Function<Tuple2<Integer, Tuple2<String, double[]>>, Boolean>() { - @Override - public Boolean call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { - return itemIndex.equals(element._1()); - } - }); - - List<Tuple2<Integer, Tuple2<String, double[]>>> oneIndexItemFeatures = indexItemFeatures.collect(); - if (oneIndexItemFeatures.size() > 0) { - result.add(oneIndexItemFeatures.get(0)._2()._2()); - } - } - } - } - - return result; - } catch (Exception e) { - logger.error("Error reading recent events for user " + query.getUserEntityId()); - throw new RuntimeException(e.getMessage(), e); - } - } - - private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) { - final DoubleMatrix userMatrix = new DoubleMatrix(userFeature); - - JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() { - @Override - public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { - return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2()))); - } - }); - - itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); - return sortAndTake(itemScores, query.getNumber()); - } - - private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) { - JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() { - @Override - public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { - double similarity = 0.0; - for (double[] recentFeature : recentProductFeatures) { - similarity += cosineSimilarity(element._2()._2(), recentFeature); - } - - return new ItemScore(element._2()._1(), similarity); - } - }); - - itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); - return sortAndTake(itemScores, query.getNumber()); - } - - private List<ItemScore> mostPopularItems(Model model, Query query) { - JavaRDD<ItemScore> itemScores = validScores(model.getItemPopularityScore(), query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); - return sortAndTake(itemScores, query.getNumber()); - } - - private double cosineSimilarity(double[] a, double[] b) { - DoubleMatrix matrixA = new DoubleMatrix(a); - DoubleMatrix matrixB = new DoubleMatrix(b); - - return matrixA.dot(matrixB) / (matrixA.norm2() * matrixB.norm2()); - } - - private List<ItemScore> sortAndTake(JavaRDD<ItemScore> all, int number) { - return all.sortBy(new Function<ItemScore, Double>() { - @Override - public Double call(ItemScore itemScore) throws Exception { - return itemScore.getScore(); - } - }, false, all.partitions().size()).take(number); - } - - private JavaRDD<ItemScore> validScores(JavaRDD<ItemScore> all, final Set<String> whitelist, final Set<String> blacklist, final Set<String> categories, final Map<String, Item> items, String userEntityId) { - final Set<String> seenItemEntityIds = seenItemEntityIds(userEntityId); - final Set<String> unavailableItemEntityIds = unavailableItemEntityIds(); - - return all.filter(new Function<ItemScore, Boolean>() { - @Override - public Boolean call(ItemScore itemScore) throws Exception { - Item item = items.get(itemScore.getItemEntityId()); - - return (item != null - && passWhitelistCriteria(whitelist, item.getEntityId()) - && passBlacklistCriteria(blacklist, item.getEntityId()) - && passCategoryCriteria(categories, item) - && passUnseenCriteria(seenItemEntityIds, item.getEntityId()) - && passAvailabilityCriteria(unavailableItemEntityIds, item.getEntityId())); - } - }); - } - - private boolean passWhitelistCriteria(Set<String> whitelist, String itemEntityId) { - return (whitelist.isEmpty() || whitelist.contains(itemEntityId)); - } - - private boolean passBlacklistCriteria(Set<String> blacklist, String itemEntityId) { - return !blacklist.contains(itemEntityId); - } - - private boolean passCategoryCriteria(Set<String> categories, Item item) { - return (categories.isEmpty() || Sets.intersection(categories, item.getCategories()).size() > 0); - } - - private boolean passUnseenCriteria(Set<String> seen, String itemEntityId) { - return !seen.contains(itemEntityId); - } - - private boolean passAvailabilityCriteria(Set<String> unavailableItemEntityIds, String entityId) { - return !unavailableItemEntityIds.contains(entityId); - } - - private Set<String> unavailableItemEntityIds() { - try { - List<Event> unavailableConstraintEvents = LJavaEventStore.findByEntity( - ap.getAppName(), - "constraint", - "unavailableItems", - OptionHelper.<String>none(), - OptionHelper.some(Collections.singletonList("$set")), - OptionHelper.<Option<String>>none(), - OptionHelper.<Option<String>>none(), - OptionHelper.<DateTime>none(), - OptionHelper.<DateTime>none(), - OptionHelper.some(1), - true, - Duration.apply(10, TimeUnit.SECONDS)); - - if (unavailableConstraintEvents.isEmpty()) return Collections.emptySet(); - - Event unavailableConstraint = unavailableConstraintEvents.get(0); - - List<String> unavailableItems = unavailableConstraint.properties().getStringList("items"); - - return new HashSet<>(unavailableItems); - } catch (Exception e) { - logger.error("Error reading constraint events"); - throw new RuntimeException(e.getMessage(), e); - } - } - - private Set<String> seenItemEntityIds(String userEntityId) { - if (!ap.isUnseenOnly()) return Collections.emptySet(); - - try { - Set<String> result = new HashSet<>(); - List<Event> seenEvents = LJavaEventStore.findByEntity( - ap.getAppName(), - "user", - userEntityId, - OptionHelper.<String>none(), - OptionHelper.some(ap.getSeenItemEvents()), - OptionHelper.some(OptionHelper.some("item")), - OptionHelper.<Option<String>>none(), - OptionHelper.<DateTime>none(), - OptionHelper.<DateTime>none(), - OptionHelper.<Integer>none(), - true, - Duration.apply(10, TimeUnit.SECONDS)); - - for (Event event : seenEvents) { - result.add(event.targetEntityId().get()); - } - - return result; - } catch (Exception e) { - logger.error("Error reading seen events for user " + userEntityId); - throw new RuntimeException(e.getMessage(), e); - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/template/recommendation/AlgorithmParams.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/template/recommendation/AlgorithmParams.java b/src/main/java/org/template/recommendation/AlgorithmParams.java deleted file mode 100644 index 0466334..0000000 --- a/src/main/java/org/template/recommendation/AlgorithmParams.java +++ /dev/null @@ -1,74 +0,0 @@ -package org.template.recommendation; - -import io.prediction.controller.Params; - -import java.util.List; - -public class AlgorithmParams implements Params{ - private final long seed; - private final int rank; - private final int iteration; - private final double lambda; - private final String appName; - private final List<String> similarItemEvents; - private final boolean unseenOnly; - private final List<String> seenItemEvents; - - - public AlgorithmParams(long seed, int rank, int iteration, double lambda, String appName, List<String> similarItemEvents, boolean unseenOnly, List<String> seenItemEvents) { - this.seed = seed; - this.rank = rank; - this.iteration = iteration; - this.lambda = lambda; - this.appName = appName; - this.similarItemEvents = similarItemEvents; - this.unseenOnly = unseenOnly; - this.seenItemEvents = seenItemEvents; - } - - public long getSeed() { - return seed; - } - - public int getRank() { - return rank; - } - - public int getIteration() { - return iteration; - } - - public double getLambda() { - return lambda; - } - - public String getAppName() { - return appName; - } - - public List<String> getSimilarItemEvents() { - return similarItemEvents; - } - - public boolean isUnseenOnly() { - return unseenOnly; - } - - public List<String> getSeenItemEvents() { - return seenItemEvents; - } - - @Override - public String toString() { - return "AlgorithmParams{" + - "seed=" + seed + - ", rank=" + rank + - ", iteration=" + iteration + - ", lambda=" + lambda + - ", appName='" + appName + '\'' + - ", similarItemEvents=" + similarItemEvents + - ", unseenOnly=" + unseenOnly + - ", seenItemEvents=" + seenItemEvents + - '}'; - } -}
