You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@predictionio.apache.org by do...@apache.org on 2017/05/04 18:49:41 UTC
[2/2] incubator-predictionio-template-java-ecom-recommender git
commit: 0.11.0-incubating release
0.11.0-incubating release
Project: http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/commit/36995dfc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/tree/36995dfc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/diff/36995dfc
Branch: refs/heads/master
Commit: 36995dfce7581cb459456858651ebd2f846d62b6
Parents: ae23d8c
Author: Donald Szeto <do...@apache.org>
Authored: Thu May 4 11:48:55 2017 -0700
Committer: Donald Szeto <do...@apache.org>
Committed: Thu May 4 11:48:55 2017 -0700
----------------------------------------------------------------------
README.md | 19 +-
build.sbt | 20 +-
engine.json | 2 +-
project/assembly.sbt | 2 +-
project/build.properties | 1 +
project/pio-build.sbt | 1 -
.../org/example/recommendation/Algorithm.java | 409 +++++++++++++++++++
.../example/recommendation/AlgorithmParams.java | 74 ++++
.../org/example/recommendation/DataSource.java | 150 +++++++
.../recommendation/DataSourceParams.java | 15 +
.../java/org/example/recommendation/Item.java | 31 ++
.../org/example/recommendation/ItemScore.java | 34 ++
.../java/org/example/recommendation/Model.java | 84 ++++
.../example/recommendation/PredictedResult.java | 23 ++
.../org/example/recommendation/Preparator.java | 12 +
.../example/recommendation/PreparedData.java | 15 +
.../java/org/example/recommendation/Query.java | 55 +++
.../recommendation/RecommendationEngine.java | 23 ++
.../org/example/recommendation/Serving.java | 12 +
.../example/recommendation/TrainingData.java | 50 +++
.../java/org/example/recommendation/User.java | 30 ++
.../example/recommendation/UserItemEvent.java | 43 ++
.../recommendation/UserItemEventType.java | 5 +
.../evaluation/EvaluationParameter.java | 28 ++
.../evaluation/EvaluationSpec.java | 28 ++
.../evaluation/PrecisionMetric.java | 62 +++
.../org/template/recommendation/Algorithm.java | 409 -------------------
.../recommendation/AlgorithmParams.java | 74 ----
.../org/template/recommendation/DataSource.java | 150 -------
.../recommendation/DataSourceParams.java | 15 -
.../java/org/template/recommendation/Item.java | 31 --
.../org/template/recommendation/ItemScore.java | 34 --
.../java/org/template/recommendation/Model.java | 84 ----
.../recommendation/PredictedResult.java | 23 --
.../org/template/recommendation/Preparator.java | 12 -
.../template/recommendation/PreparedData.java | 15 -
.../java/org/template/recommendation/Query.java | 55 ---
.../recommendation/RecommendationEngine.java | 23 --
.../org/template/recommendation/Serving.java | 12 -
.../template/recommendation/TrainingData.java | 50 ---
.../java/org/template/recommendation/User.java | 30 --
.../template/recommendation/UserItemEvent.java | 43 --
.../recommendation/UserItemEventType.java | 5 -
.../evaluation/EvaluationParameter.java | 28 --
.../evaluation/EvaluationSpec.java | 28 --
.../evaluation/PrecisionMetric.java | 62 ---
template.json | 2 +-
47 files changed, 1206 insertions(+), 1207 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/README.md
----------------------------------------------------------------------
diff --git a/README.md b/README.md
index a4ac665..35df67d 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,18 @@
-# E-Commerce Recommendation Template
+# E-Commerce Recommendation Template in Java
## Documentation
-Please refer to http://docs.prediction.io/templates/javaecommercerecommendation/quickstart/
+Please refer to
+http://predictionio.incubator.apache.org/templates/javaecommercerecommendation/quickstart/.
## Versions
+### v0.11.0-incubating
+
+- Update to build with PredictionIO 0.11.0-incubating
+- Rename Java package name
+- Update SBT and plugin versions
+
### v0.1.2
add "org.jblas" dependency in build.sbt
@@ -19,13 +26,13 @@ Please refer to http://docs.prediction.io/templates/javaecommercerecommendation/
## Development Notes
-### import sample data
+### Import Sample Data
```
$ python data/import_eventserver.py --access_key <your_access_key>
```
-### query
+### Query
normal:
@@ -77,7 +84,7 @@ curl -H "Content-Type: application/json" \
http://localhost:8000/queries.json
```
-### handle new user
+### Handle New User
new user:
@@ -120,7 +127,7 @@ curl -i -X POST http://localhost:7070/events.json?accessKey=$accessKey \
```
-## handle unavailable items
+### Handle Unavailable Items
Set the following items as unavailable (need to specify complete list each time when this list is changed):
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/build.sbt
----------------------------------------------------------------------
diff --git a/build.sbt b/build.sbt
index 20e1346..36da38f 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,16 +1,8 @@
-import AssemblyKeys._
-
-assemblySettings
-
-name := "barebone-template"
-
-organization := "io.prediction"
+name := "template-java-ecom-recommender"
libraryDependencies ++= Seq(
- "io.prediction" %% "core" % pioVersion.value % "provided",
- "org.apache.spark" %% "spark-core" % "1.3.0" % "provided",
- "org.apache.spark" %% "spark-mllib" % "1.3.0" % "provided",
- "org.scalatest" % "scalatest_2.10" % "2.2.1" % "test",
- "com.google.guava" % "guava" % "12.0",
- "org.jblas" % "jblas" % "1.2.4"
-)
+ "org.apache.predictionio" %% "apache-predictionio-core" % "0.11.0-incubating" % "provided",
+ "org.apache.spark" %% "spark-core" % "1.3.0" % "provided",
+ "org.apache.spark" %% "spark-mllib" % "1.3.0" % "provided",
+ "com.google.guava" % "guava" % "12.0",
+ "org.jblas" % "jblas" % "1.2.4")
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/engine.json
----------------------------------------------------------------------
diff --git a/engine.json b/engine.json
index 1f8ed0c..0f44544 100644
--- a/engine.json
+++ b/engine.json
@@ -1,7 +1,7 @@
{
"id": "default",
"description": "Default settings",
- "engineFactory": "org.template.recommendation.RecommendationEngine",
+ "engineFactory": "org.example.recommendation.RecommendationEngine",
"datasource": {
"params" : {
"appName": "javadase"
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/project/assembly.sbt
----------------------------------------------------------------------
diff --git a/project/assembly.sbt b/project/assembly.sbt
index 54c3252..e17409e 100644
--- a/project/assembly.sbt
+++ b/project/assembly.sbt
@@ -1 +1 @@
-addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.4")
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/project/build.properties
----------------------------------------------------------------------
diff --git a/project/build.properties b/project/build.properties
new file mode 100644
index 0000000..64317fd
--- /dev/null
+++ b/project/build.properties
@@ -0,0 +1 @@
+sbt.version=0.13.15
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/project/pio-build.sbt
----------------------------------------------------------------------
diff --git a/project/pio-build.sbt b/project/pio-build.sbt
deleted file mode 100644
index 878fc0d..0000000
--- a/project/pio-build.sbt
+++ /dev/null
@@ -1 +0,0 @@
-addSbtPlugin("io.prediction" % "pio-build" % "0.9.0")
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Algorithm.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/Algorithm.java b/src/main/java/org/example/recommendation/Algorithm.java
new file mode 100644
index 0000000..349e945
--- /dev/null
+++ b/src/main/java/org/example/recommendation/Algorithm.java
@@ -0,0 +1,409 @@
+package org.example.recommendation;
+
+import com.google.common.collect.Sets;
+import org.apache.predictionio.controller.java.PJavaAlgorithm;
+import org.apache.predictionio.data.storage.Event;
+import org.apache.predictionio.data.store.java.LJavaEventStore;
+import org.apache.predictionio.data.store.java.OptionHelper;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.recommendation.ALS;
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
+import org.apache.spark.mllib.recommendation.Rating;
+import org.apache.spark.rdd.RDD;
+import org.jblas.DoubleMatrix;
+import org.joda.time.DateTime;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.Option;
+import scala.Tuple2;
+import scala.concurrent.duration.Duration;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+public class Algorithm extends PJavaAlgorithm<PreparedData, Model, Query, PredictedResult> {
+
+ // Class-wide logger used by the event-store helper methods when a lookup fails.
+ private static final Logger logger = LoggerFactory.getLogger(Algorithm.class);
+ // Algorithm parameters (ALS settings, app name, event-name lists) supplied at construction.
+ private final AlgorithmParams ap;
+
+ /**
+ * Creates the algorithm with the given parameters.
+ *
+ * @param ap parameters read by train() and by the serving-time event-store lookups
+ */
+ public Algorithm(AlgorithmParams ap) {
+ this.ap = ap;
+ }
+
+ /**
+ * Trains the recommendation model from the prepared training data.
+ *
+ * Builds string-id to integer-index maps for users and items, turns view events
+ * into per-(user, item) view counts, trains an implicit-feedback ALS model on those
+ * counts, and additionally computes a per-item buy count used as a popularity score.
+ *
+ * @param sc the Spark context (not referenced directly in this method)
+ * @param preparedData wrapper around the training data
+ * @return a Model holding user features, item features joined with item ids, the
+ * user/item index RDDs, the popularity scores and the collected item map
+ * @throws AssertionError if no view event maps to a known user and item
+ */
+ @Override
+ public Model train(SparkContext sc, PreparedData preparedData) {
+ TrainingData data = preparedData.getTrainingData();
+
+ // user stuff
+ // Assign every user id a position via zipWithIndex; the Long index is narrowed to int.
+ JavaPairRDD<String, Integer> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() {
+ @Override
+ public String call(Tuple2<String, User> idUser) throws Exception {
+ return idUser._1();
+ }
+ }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
+ return new Tuple2<>(element._1(), element._2().intValue());
+ }
+ });
+ // Collected to the driver so the closures below can look indices up by id.
+ final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap();
+
+ // item stuff
+ // Same indexing scheme for items, plus the reverse (index -> id) mapping.
+ JavaPairRDD<String, Integer> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() {
+ @Override
+ public String call(Tuple2<String, Item> idItem) throws Exception {
+ return idItem._1();
+ }
+ }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
+ return new Tuple2<>(element._1(), element._2().intValue());
+ }
+ });
+ final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap();
+ JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() {
+ @Override
+ public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception {
+ return element.swap();
+ }
+ });
+ final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap();
+
+ // ratings stuff
+ // Each view event becomes ((userIndex, itemIndex), 1); events whose user or item is
+ // not in the index maps become null and are filtered out before the counts are summed.
+ JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
+ @Override
+ public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception {
+ Integer userIndex = userIndexMap.get(viewEvent.getUser());
+ Integer itemIndex = itemIndexMap.get(viewEvent.getItem());
+
+ return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
+ }
+ }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
+ @Override
+ public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
+ return (element != null);
+ }
+ }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer integer, Integer integer2) throws Exception {
+ return integer + integer2;
+ }
+ }).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() {
+ @Override
+ public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception {
+ // The summed view count is used as the rating value.
+ return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue());
+ }
+ });
+
+ if (ratings.isEmpty())
+ throw new AssertionError("Please check if your events contain valid user and item ID.");
+
+ // MLlib ALS stuff
+ // NOTE(review): the literals -1 and 1.0 are presumably MLlib's "blocks" (auto) and
+ // "alpha" arguments of ALS.trainImplicit - confirm against the MLlib version in use.
+ MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed());
+ // The Scala RDDs expose their keys as Object; cast them back to Integer for Java use.
+ JavaPairRDD<Integer, double[]> userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
+ @Override
+ public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
+ return new Tuple2<>((Integer) element._1(), element._2());
+ }
+ });
+ JavaPairRDD<Integer, double[]> productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
+ @Override
+ public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
+ return new Tuple2<>((Integer) element._1(), element._2());
+ }
+ });
+
+ // popularity scores
+ // Buy events are mapped and filtered like view events, then re-keyed by item index
+ // only, so the summed value is the number of buy events per item.
+ JavaRDD<ItemScore> itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
+ @Override
+ public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent buyEvent) throws Exception {
+ Integer userIndex = userIndexMap.get(buyEvent.getUser());
+ Integer itemIndex = itemIndexMap.get(buyEvent.getItem());
+
+ return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
+ }
+ }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
+ @Override
+ public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
+ return (element != null);
+ }
+ }).mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
+ return new Tuple2<>(element._1()._2(), element._2());
+ }
+ }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer integer, Integer integer2) throws Exception {
+ return integer + integer2;
+ }
+ }).map(new Function<Tuple2<Integer, Integer>, ItemScore>() {
+ @Override
+ public ItemScore call(Tuple2<Integer, Integer> element) throws Exception {
+ // Translate the item index back to its string id for the score object.
+ return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue());
+ }
+ });
+
+ // Attach each item's string id to its ALS feature vector, keyed by item index.
+ JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD);
+
+ return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap());
+ }
+
+ @Override
+ public PredictedResult predict(Model model, final Query query) {
+ final JavaPairRDD<String, Integer> matchedUser = model.getUserIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
+ @Override
+ public Boolean call(Tuple2<String, Integer> userIndex) throws Exception {
+ return userIndex._1().equals(query.getUserEntityId());
+ }
+ });
+
+ double[] userFeature = null;
+ if (!matchedUser.isEmpty()) {
+ final Integer matchedUserIndex = matchedUser.first()._2();
+ userFeature = model.getUserFeatures().filter(new Function<Tuple2<Integer, double[]>, Boolean>() {
+ @Override
+ public Boolean call(Tuple2<Integer, double[]> element) throws Exception {
+ return element._1().equals(matchedUserIndex);
+ }
+ }).first()._2();
+ }
+
+ if (userFeature != null) {
+ return new PredictedResult(topItemsForUser(userFeature, model, query));
+ } else {
+ List<double[]> recentProductFeatures = getRecentProductFeatures(query, model);
+ if (recentProductFeatures.isEmpty()) {
+ return new PredictedResult(mostPopularItems(model, query));
+ } else {
+ return new PredictedResult(similarItems(recentProductFeatures, model, query));
+ }
+ }
+ }
+
+ /**
+  * Answers a batch of indexed queries by collecting them to the driver, calling
+  * predict() on each one in turn, and redistributing the answers as an RDD.
+  *
+  * @param model the trained model
+  * @param qs (index, query) pairs
+  * @return (index, predicted result) pairs, one per input pair
+  */
+ @Override
+ public RDD<Tuple2<Object, PredictedResult>> batchPredict(Model model, RDD<Tuple2<Object, Query>> qs) {
+     final List<Tuple2<Object, Query>> pending = qs.toJavaRDD().collect();
+     final List<Tuple2<Object, PredictedResult>> answered = new ArrayList<>();
+
+     for (Tuple2<Object, Query> entry : pending) {
+         Object index = entry._1();
+         PredictedResult prediction = predict(model, entry._2());
+         answered.add(new Tuple2<>(index, prediction));
+     }
+
+     JavaSparkContext javaContext = new JavaSparkContext(qs.sparkContext());
+     return javaContext.parallelize(answered).rdd();
+ }
+
+ /**
+  * Looks up the feature vectors of the items the querying user recently interacted with.
+  *
+  * Reads up to 10 of the user's events (names taken from the "similar item events"
+  * parameter, target entity type "item") from the event store and, for every target
+  * item the model knows, collects that item's feature vector. Events that point to an
+  * item unknown to the model, or to an item without features, are skipped.
+  *
+  * @param query the query whose user is looked up
+  * @param model the trained model
+  * @return the feature vectors found, possibly empty
+  * @throws RuntimeException wrapping any failure while reading events or features
+  */
+ private List<double[]> getRecentProductFeatures(Query query, Model model) {
+     try {
+         List<double[]> result = new ArrayList<>();
+
+         List<Event> events = LJavaEventStore.findByEntity(
+                 ap.getAppName(),
+                 "user",
+                 query.getUserEntityId(),
+                 OptionHelper.<String>none(),
+                 OptionHelper.some(ap.getSimilarItemEvents()),
+                 OptionHelper.some(OptionHelper.some("item")),
+                 OptionHelper.<Option<String>>none(),
+                 OptionHelper.<DateTime>none(),
+                 OptionHelper.<DateTime>none(),
+                 OptionHelper.some(10),
+                 true,
+                 Duration.apply(10, TimeUnit.SECONDS));
+
+         for (final Event event : events) {
+             if (!event.targetEntityId().isDefined()) {
+                 continue;
+             }
+
+             // Capture only the id so the closure does not have to carry the whole event.
+             final String targetItemId = event.targetEntityId().get();
+
+             // take(1) yields an empty list when the item is unknown; the previous code
+             // called first() before checking isEmpty(), which throws on an empty RDD.
+             List<Tuple2<String, Integer>> matchedItems = model.getItemIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
+                 @Override
+                 public Boolean call(Tuple2<String, Integer> element) throws Exception {
+                     return element._1().equals(targetItemId);
+                 }
+             }).take(1);
+
+             if (matchedItems.isEmpty()) {
+                 continue;
+             }
+
+             final Integer itemIndex = matchedItems.get(0)._2();
+
+             List<Tuple2<Integer, Tuple2<String, double[]>>> matchedFeatures = model.getIndexItemFeatures().filter(new Function<Tuple2<Integer, Tuple2<String, double[]>>, Boolean>() {
+                 @Override
+                 public Boolean call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
+                     return itemIndex.equals(element._1());
+                 }
+             }).take(1);
+
+             if (!matchedFeatures.isEmpty()) {
+                 result.add(matchedFeatures.get(0)._2()._2());
+             }
+         }
+
+         return result;
+     } catch (Exception e) {
+         // Pass the exception to the logger so the stack trace is not lost.
+         logger.error("Error reading recent events for user " + query.getUserEntityId(), e);
+         throw new RuntimeException(e.getMessage(), e);
+     }
+ }
+
+ private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) {
+ final DoubleMatrix userMatrix = new DoubleMatrix(userFeature);
+
+ JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
+ @Override
+ public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
+ return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2())));
+ }
+ });
+
+ itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
+ return sortAndTake(itemScores, query.getNumber());
+ }
+
+ private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) {
+ JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
+ @Override
+ public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
+ double similarity = 0.0;
+ for (double[] recentFeature : recentProductFeatures) {
+ similarity += cosineSimilarity(element._2()._2(), recentFeature);
+ }
+
+ return new ItemScore(element._2()._1(), similarity);
+ }
+ });
+
+ itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
+ return sortAndTake(itemScores, query.getNumber());
+ }
+
+ /**
+  * Returns the highest-scoring entries of the model's precomputed popularity scores
+  * after applying the query's filters through validScores().
+  *
+  * @param model the trained model
+  * @param query the query supplying the filters and the number of items wanted
+  * @return the top popularity scores, highest first
+  */
+ private List<ItemScore> mostPopularItems(Model model, Query query) {
+     JavaRDD<ItemScore> popularity = model.getItemPopularityScore();
+     JavaRDD<ItemScore> eligible = validScores(popularity, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
+     return sortAndTake(eligible, query.getNumber());
+ }
+
+ /**
+  * Computes the cosine similarity of two vectors.
+  *
+  * @param a first vector
+  * @param b second vector
+  * @return dot(a, b) / (|a| * |b|), or 0.0 when either vector has zero norm
+  */
+ private double cosineSimilarity(double[] a, double[] b) {
+     DoubleMatrix matrixA = new DoubleMatrix(a);
+     DoubleMatrix matrixB = new DoubleMatrix(b);
+
+     double normProduct = matrixA.norm2() * matrixB.norm2();
+
+     // A zero-norm vector would otherwise yield NaN (0/0), and NaN scores make the
+     // later sort by score unreliable. Treat such a pair as "not similar".
+     if (normProduct == 0.0) {
+         return 0.0;
+     }
+
+     return matrixA.dot(matrixB) / normProduct;
+ }
+
+ private List<ItemScore> sortAndTake(JavaRDD<ItemScore> all, int number) {
+ return all.sortBy(new Function<ItemScore, Double>() {
+ @Override
+ public Double call(ItemScore itemScore) throws Exception {
+ return itemScore.getScore();
+ }
+ }, false, all.partitions().size()).take(number);
+ }
+
+ /**
+  * Keeps only the scores whose item is known and passes every filter: whitelist,
+  * blacklist, category, unseen-by-user and availability.
+  *
+  * The seen-item and unavailable-item sets are read once, before the filter runs.
+  *
+  * @param all candidate scores
+  * @param whitelist item ids to restrict to (empty means no restriction)
+  * @param blacklist item ids to exclude
+  * @param categories categories to restrict to (empty means no restriction)
+  * @param items item id to item lookup
+  * @param userEntityId the querying user's id
+  * @return the scores that pass all filters
+  */
+ private JavaRDD<ItemScore> validScores(JavaRDD<ItemScore> all, final Set<String> whitelist, final Set<String> blacklist, final Set<String> categories, final Map<String, Item> items, String userEntityId) {
+     final Set<String> seenIds = seenItemEntityIds(userEntityId);
+     final Set<String> unavailableIds = unavailableItemEntityIds();
+
+     Function<ItemScore, Boolean> isEligible = new Function<ItemScore, Boolean>() {
+         @Override
+         public Boolean call(ItemScore candidate) throws Exception {
+             Item item = items.get(candidate.getItemEntityId());
+             if (item == null) {
+                 return false;
+             }
+
+             String itemId = item.getEntityId();
+             return passWhitelistCriteria(whitelist, itemId)
+                     && passBlacklistCriteria(blacklist, itemId)
+                     && passCategoryCriteria(categories, item)
+                     && passUnseenCriteria(seenIds, itemId)
+                     && passAvailabilityCriteria(unavailableIds, itemId);
+         }
+     };
+
+     return all.filter(isEligible);
+ }
+
+ /** Passes when the whitelist is empty or contains the item id. */
+ private boolean passWhitelistCriteria(Set<String> whitelist, String itemEntityId) {
+     if (whitelist.isEmpty()) {
+         return true;
+     }
+     return whitelist.contains(itemEntityId);
+ }
+
+ /** Passes when the blacklist does not contain the item id. */
+ private boolean passBlacklistCriteria(Set<String> blacklist, String itemEntityId) {
+     boolean blacklisted = blacklist.contains(itemEntityId);
+     return !blacklisted;
+ }
+
+ /** Passes when no categories are requested or the item shares at least one of them. */
+ private boolean passCategoryCriteria(Set<String> categories, Item item) {
+     if (categories.isEmpty()) {
+         return true;
+     }
+     int sharedCategories = Sets.intersection(categories, item.getCategories()).size();
+     return sharedCategories > 0;
+ }
+
+ /** Passes when the item id is not in the set of already seen item ids. */
+ private boolean passUnseenCriteria(Set<String> seen, String itemEntityId) {
+     boolean alreadySeen = seen.contains(itemEntityId);
+     return !alreadySeen;
+ }
+
+ /** Passes when the item id is not in the set of unavailable item ids. */
+ private boolean passAvailabilityCriteria(Set<String> unavailableItemEntityIds, String entityId) {
+     boolean unavailable = unavailableItemEntityIds.contains(entityId);
+     return !unavailable;
+ }
+
+ /**
+  * Reads the ids of unavailable items from the event store.
+  *
+  * Queries at most one "$set" event of the entity "constraint"/"unavailableItems" and
+  * returns the string list stored in its "items" property.
+  *
+  * @return the unavailable item ids, or an empty set when no such event exists
+  * @throws RuntimeException wrapping any failure while reading the event
+  */
+ private Set<String> unavailableItemEntityIds() {
+     try {
+         List<Event> constraintEvents = LJavaEventStore.findByEntity(
+                 ap.getAppName(),
+                 "constraint",
+                 "unavailableItems",
+                 OptionHelper.<String>none(),
+                 OptionHelper.some(Collections.singletonList("$set")),
+                 OptionHelper.<Option<String>>none(),
+                 OptionHelper.<Option<String>>none(),
+                 OptionHelper.<DateTime>none(),
+                 OptionHelper.<DateTime>none(),
+                 OptionHelper.some(1),
+                 true,
+                 Duration.apply(10, TimeUnit.SECONDS));
+
+         if (!constraintEvents.isEmpty()) {
+             Event constraintEvent = constraintEvents.get(0);
+             List<String> itemIds = constraintEvent.properties().getStringList("items");
+             return new HashSet<>(itemIds);
+         }
+
+         return Collections.emptySet();
+     } catch (Exception e) {
+         logger.error("Error reading constraint events");
+         throw new RuntimeException(e.getMessage(), e);
+     }
+ }
+
+ /**
+  * Reads the ids of the items the given user has already seen.
+  *
+  * Returns an empty set immediately when the "unseen only" parameter is off. Otherwise
+  * queries the user's events (names taken from the "seen item events" parameter, target
+  * entity type "item") and collects their target entity ids. Events without a target
+  * entity id are skipped, matching the guard used in getRecentProductFeatures().
+  *
+  * @param userEntityId the user whose events are read
+  * @return the ids of items the user has seen, possibly empty
+  * @throws RuntimeException wrapping any failure while reading the events
+  */
+ private Set<String> seenItemEntityIds(String userEntityId) {
+     if (!ap.isUnseenOnly()) return Collections.emptySet();
+
+     try {
+         Set<String> result = new HashSet<>();
+         List<Event> seenEvents = LJavaEventStore.findByEntity(
+                 ap.getAppName(),
+                 "user",
+                 userEntityId,
+                 OptionHelper.<String>none(),
+                 OptionHelper.some(ap.getSeenItemEvents()),
+                 OptionHelper.some(OptionHelper.some("item")),
+                 OptionHelper.<Option<String>>none(),
+                 OptionHelper.<DateTime>none(),
+                 OptionHelper.<DateTime>none(),
+                 OptionHelper.<Integer>none(),
+                 true,
+                 Duration.apply(10, TimeUnit.SECONDS));
+
+         for (Event event : seenEvents) {
+             // Skip malformed events instead of failing the whole query on Option.get().
+             if (event.targetEntityId().isDefined()) {
+                 result.add(event.targetEntityId().get());
+             }
+         }
+
+         return result;
+     } catch (Exception e) {
+         // Pass the exception to the logger so the stack trace is not lost.
+         logger.error("Error reading seen events for user " + userEntityId, e);
+         throw new RuntimeException(e.getMessage(), e);
+     }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/AlgorithmParams.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/AlgorithmParams.java b/src/main/java/org/example/recommendation/AlgorithmParams.java
new file mode 100644
index 0000000..4b9c7ed
--- /dev/null
+++ b/src/main/java/org/example/recommendation/AlgorithmParams.java
@@ -0,0 +1,74 @@
+package org.example.recommendation;
+
+import org.apache.predictionio.controller.Params;
+
+import java.util.List;
+
+public class AlgorithmParams implements Params{
+ private final long seed;
+ private final int rank;
+ private final int iteration;
+ private final double lambda;
+ private final String appName;
+ private final List<String> similarItemEvents;
+ private final boolean unseenOnly;
+ private final List<String> seenItemEvents;
+
+
+ public AlgorithmParams(long seed, int rank, int iteration, double lambda, String appName, List<String> similarItemEvents, boolean unseenOnly, List<String> seenItemEvents) {
+ this.seed = seed;
+ this.rank = rank;
+ this.iteration = iteration;
+ this.lambda = lambda;
+ this.appName = appName;
+ this.similarItemEvents = similarItemEvents;
+ this.unseenOnly = unseenOnly;
+ this.seenItemEvents = seenItemEvents;
+ }
+
+ public long getSeed() {
+ return seed;
+ }
+
+ public int getRank() {
+ return rank;
+ }
+
+ public int getIteration() {
+ return iteration;
+ }
+
+ public double getLambda() {
+ return lambda;
+ }
+
+ public String getAppName() {
+ return appName;
+ }
+
+ public List<String> getSimilarItemEvents() {
+ return similarItemEvents;
+ }
+
+ public boolean isUnseenOnly() {
+ return unseenOnly;
+ }
+
+ public List<String> getSeenItemEvents() {
+ return seenItemEvents;
+ }
+
+ @Override
+ public String toString() {
+ return "AlgorithmParams{" +
+ "seed=" + seed +
+ ", rank=" + rank +
+ ", iteration=" + iteration +
+ ", lambda=" + lambda +
+ ", appName='" + appName + '\'' +
+ ", similarItemEvents=" + similarItemEvents +
+ ", unseenOnly=" + unseenOnly +
+ ", seenItemEvents=" + seenItemEvents +
+ '}';
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/DataSource.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/DataSource.java b/src/main/java/org/example/recommendation/DataSource.java
new file mode 100644
index 0000000..90ac975
--- /dev/null
+++ b/src/main/java/org/example/recommendation/DataSource.java
@@ -0,0 +1,150 @@
+package org.example.recommendation;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.apache.predictionio.controller.EmptyParams;
+import org.apache.predictionio.controller.java.PJavaDataSource;
+import org.apache.predictionio.data.storage.Event;
+import org.apache.predictionio.data.storage.PropertyMap;
+import org.apache.predictionio.data.store.java.OptionHelper;
+import org.apache.predictionio.data.store.java.PJavaEventStore;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.rdd.RDD;
+import org.joda.time.DateTime;
+import scala.Option;
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.collection.JavaConversions;
+import scala.collection.JavaConversions$;
+import scala.collection.Seq;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class DataSource extends PJavaDataSource<TrainingData, EmptyParams, Query, Set<String>> {
+
+ // Data source parameters; only the application name is read from them in this class.
+ private final DataSourceParams dsp;
+
+ /**
+ * Creates the data source with the given parameters.
+ *
+ * @param dsp parameters supplying the application name used for event-store reads
+ */
+ public DataSource(DataSourceParams dsp) {
+ this.dsp = dsp;
+ }
+
+ /**
+ * Reads users, items, view events and buy events of the configured application from
+ * the event store and bundles them into a TrainingData object.
+ *
+ * @param sc the Spark context passed to every event-store call
+ * @return training data holding the four RDDs built below
+ */
+ @Override
+ public TrainingData readTraining(SparkContext sc) {
+ // Users: aggregated properties of all "user" entities, keyed by entity id.
+ JavaPairRDD<String,User> usersRDD = PJavaEventStore.aggregateProperties(
+ dsp.getAppName(),
+ "user",
+ OptionHelper.<String>none(),
+ OptionHelper.<DateTime>none(),
+ OptionHelper.<DateTime>none(),
+ OptionHelper.<List<String>>none(),
+ sc)
+ .mapToPair(new PairFunction<Tuple2<String, PropertyMap>, String, User>() {
+ @Override
+ public Tuple2<String, User> call(Tuple2<String, PropertyMap> entityIdProperty) throws Exception {
+ // Copy every property into a Java map, reading each value as a String.
+ // NOTE(review): presumably this fails for non-string property values - confirm.
+ Set<String> keys = JavaConversions$.MODULE$.setAsJavaSet(entityIdProperty._2().keySet());
+ Map<String, String> properties = new HashMap<>();
+ for (String key : keys) {
+ properties.put(key, entityIdProperty._2().get(key, String.class));
+ }
+
+ User user = new User(entityIdProperty._1(), ImmutableMap.copyOf(properties));
+
+ return new Tuple2<>(user.getEntityId(), user);
+ }
+ });
+
+ // Items: aggregated properties of all "item" entities; only "categories" is used.
+ JavaPairRDD<String, Item> itemsRDD = PJavaEventStore.aggregateProperties(
+ dsp.getAppName(),
+ "item",
+ OptionHelper.<String>none(),
+ OptionHelper.<DateTime>none(),
+ OptionHelper.<DateTime>none(),
+ OptionHelper.<List<String>>none(),
+ sc)
+ .mapToPair(new PairFunction<Tuple2<String, PropertyMap>, String, Item>() {
+ @Override
+ public Tuple2<String, Item> call(Tuple2<String, PropertyMap> entityIdProperty) throws Exception {
+ // NOTE(review): assumes every item has a "categories" property - confirm how
+ // getStringList behaves when the property is missing.
+ List<String> categories = entityIdProperty._2().getStringList("categories");
+ Item item = new Item(entityIdProperty._1(), ImmutableSet.copyOf(categories));
+
+ return new Tuple2<>(item.getEntityId(), item);
+ }
+ });
+
+ // View events: "view" events of "user" entities, converted to UserItemEvent.
+ JavaRDD<UserItemEvent> viewEventsRDD = PJavaEventStore.find(
+ dsp.getAppName(),
+ OptionHelper.<String>none(),
+ OptionHelper.<DateTime>none(),
+ OptionHelper.<DateTime>none(),
+ OptionHelper.some("user"),
+ OptionHelper.<String>none(),
+ OptionHelper.some(Collections.singletonList("view")),
+ OptionHelper.<Option<String>>none(),
+ OptionHelper.<Option<String>>none(),
+ sc)
+ .map(new Function<Event, UserItemEvent>() {
+ @Override
+ public UserItemEvent call(Event event) throws Exception {
+ // targetEntityId().get() is unguarded: an event without a target id fails here.
+ return new UserItemEvent(event.entityId(), event.targetEntityId().get(), event.eventTime().getMillis(), UserItemEventType.VIEW);
+ }
+ });
+
+ // Buy events: same query as above with the event name "buy".
+ JavaRDD<UserItemEvent> buyEventsRDD = PJavaEventStore.find(
+ dsp.getAppName(),
+ OptionHelper.<String>none(),
+ OptionHelper.<DateTime>none(),
+ OptionHelper.<DateTime>none(),
+ OptionHelper.some("user"),
+ OptionHelper.<String>none(),
+ OptionHelper.some(Collections.singletonList("buy")),
+ OptionHelper.<Option<String>>none(),
+ OptionHelper.<Option<String>>none(),
+ sc)
+ .map(new Function<Event, UserItemEvent>() {
+ @Override
+ public UserItemEvent call(Event event) throws Exception {
+ return new UserItemEvent(event.entityId(), event.targetEntityId().get(), event.eventTime().getMillis(), UserItemEventType.BUY);
+ }
+ });
+
+ return new TrainingData(usersRDD, itemsRDD, viewEventsRDD, buyEventsRDD);
+ }
+
+ /**
+ * Builds a single evaluation set by splitting the view and buy events in half.
+ *
+ * The first halves become the training events (together with all users and items);
+ * the second halves are merged, grouped by user, and turned into one query per user
+ * whose expected answer is the set of item ids in that user's held-out events.
+ *
+ * @param sc the Spark context used to read the full training data
+ * @return a one-element sequence of (training data, empty params, (query, actual items))
+ */
+ @Override
+ public Seq<Tuple3<TrainingData, EmptyParams, RDD<Tuple2<Query, Set<String>>>>> readEval(SparkContext sc) {
+ TrainingData all = readTraining(sc);
+ // 50/50 split with a fixed seed of 1, applied separately to views and buys.
+ double[] split = {0.5, 0.5};
+ JavaRDD<UserItemEvent>[] trainingAndTestingViews = all.getViewEvents().randomSplit(split, 1);
+ JavaRDD<UserItemEvent>[] trainingAndTestingBuys = all.getBuyEvents().randomSplit(split, 1);
+
+ // Index 1 of each split is the held-out half; merge both and group by user id.
+ RDD<Tuple2<Query, Set<String>>> queryActual = JavaPairRDD.toRDD(trainingAndTestingViews[1].union(trainingAndTestingBuys[1]).groupBy(new Function<UserItemEvent, String>() {
+ @Override
+ public String call(UserItemEvent event) throws Exception {
+ return event.getUser();
+ }
+ }).mapToPair(new PairFunction<Tuple2<String, Iterable<UserItemEvent>>, Query, Set<String>>() {
+ @Override
+ public Tuple2<Query, Set<String>> call(Tuple2<String, Iterable<UserItemEvent>> userEvents) throws Exception {
+ // NOTE(review): presumably 10 is the number of items requested and the three
+ // empty sets are the query's filter sets - confirm against the Query constructor.
+ Query query = new Query(userEvents._1(), 10, Collections.<String>emptySet(), Collections.<String>emptySet(), Collections.<String>emptySet());
+ Set<String> actualSet = new HashSet<>();
+ for (UserItemEvent event : userEvents._2()) {
+ actualSet.add(event.getItem());
+ }
+ return new Tuple2<>(query, actualSet);
+ }
+ }));
+
+ // Index 0 of each split is the training half; users and items are not split.
+ Tuple3<TrainingData, EmptyParams, RDD<Tuple2<Query, Set<String>>>> setData = new Tuple3<>(new TrainingData(all.getUsers(), all.getItems(), trainingAndTestingViews[0], trainingAndTestingBuys[0]), new EmptyParams(), queryActual);
+
+ return JavaConversions.asScalaIterable(Collections.singletonList(setData)).toSeq();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/DataSourceParams.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/DataSourceParams.java b/src/main/java/org/example/recommendation/DataSourceParams.java
new file mode 100644
index 0000000..4651b92
--- /dev/null
+++ b/src/main/java/org/example/recommendation/DataSourceParams.java
@@ -0,0 +1,15 @@
+package org.example.recommendation;
+
+import org.apache.predictionio.controller.Params;
+
+/** Parameters for the data source: holds the app name given at construction. */
+public class DataSourceParams implements Params{
+ private final String appName;
+
+ /**
+ * @param appName app name to store
+ */
+ public DataSourceParams(String appName) {
+ this.appName = appName;
+ }
+
+ /** @return the app name given at construction */
+ public String getAppName() {
+ return appName;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Item.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/Item.java b/src/main/java/org/example/recommendation/Item.java
new file mode 100644
index 0000000..2159beb
--- /dev/null
+++ b/src/main/java/org/example/recommendation/Item.java
@@ -0,0 +1,31 @@
+package org.example.recommendation;
+
+import java.io.Serializable;
+import java.util.Set;
+
+/** Serializable value holder for an item: its entity id and its set of categories. */
+public class Item implements Serializable {
+    private final Set<String> categories;
+    private final String entityId;
+
+    /**
+     * @param entityId   id of the item entity
+     * @param categories categories of the item; stored as given (no copy is made)
+     */
+    public Item(String entityId, Set<String> categories) {
+        this.entityId = entityId;
+        this.categories = categories;
+    }
+
+    /** @return the categories passed to the constructor */
+    public Set<String> getCategories() {
+        return categories;
+    }
+
+    /** @return the entity id passed to the constructor */
+    public String getEntityId() {
+        return entityId;
+    }
+
+    /** Renders both fields in the form {@code Item{categories=..., entityId='...'}}. */
+    @Override
+    public String toString() {
+        StringBuilder text = new StringBuilder("Item{");
+        text.append("categories=").append(categories);
+        text.append(", entityId='").append(entityId).append('\'');
+        text.append('}');
+        return text.toString();
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/ItemScore.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/ItemScore.java b/src/main/java/org/example/recommendation/ItemScore.java
new file mode 100644
index 0000000..23c3fdb
--- /dev/null
+++ b/src/main/java/org/example/recommendation/ItemScore.java
@@ -0,0 +1,34 @@
+package org.example.recommendation;
+
+import java.io.Serializable;
+
+/**
+ * An item id paired with its score.
+ *
+ * <p>Natural ordering is by ascending score only. The class does not override
+ * {@code equals}, so the ordering is not consistent with equals: two distinct
+ * instances with the same score compare as 0 without being equal.
+ */
+public class ItemScore implements Serializable, Comparable<ItemScore> {
+    private final String itemEntityId;
+    private final double score;
+
+    /**
+     * @param itemEntityId id of the scored item
+     * @param score        score assigned to the item
+     */
+    public ItemScore(String itemEntityId, double score) {
+        this.itemEntityId = itemEntityId;
+        this.score = score;
+    }
+
+    /** @return id of the scored item */
+    public String getItemEntityId() {
+        return itemEntityId;
+    }
+
+    /** @return score assigned to the item */
+    public double getScore() {
+        return score;
+    }
+
+    @Override
+    public String toString() {
+        return "ItemScore{" +
+                "itemEntityId='" + itemEntityId + '\'' +
+                ", score=" + score +
+                '}';
+    }
+
+    /**
+     * Compares by score in ascending order, using the same total order as
+     * {@link Double#compareTo(Double)} (NaN sorts last, -0.0 before 0.0).
+     *
+     * @param o the other item score
+     * @return negative, zero or positive as this score is less than, equal to or greater than the other
+     */
+    @Override
+    public int compareTo(ItemScore o) {
+        // Double.compare gives the same result as boxing and calling compareTo, without allocating.
+        return Double.compare(score, o.score);
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Model.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/Model.java b/src/main/java/org/example/recommendation/Model.java
new file mode 100644
index 0000000..ebf42e5
--- /dev/null
+++ b/src/main/java/org/example/recommendation/Model.java
@@ -0,0 +1,84 @@
+package org.example.recommendation;
+
+import org.apache.predictionio.controller.Params;
+import org.apache.predictionio.controller.PersistentModel;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.Tuple2;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.Map;
+
+/**
+ * Trained model: user and item feature RDDs, the user and item index RDDs, the
+ * item popularity scores and the item map. Persisted as Spark object files under
+ * {@code /tmp/<id>/}, one sub-directory per component.
+ */
+public class Model implements Serializable, PersistentModel<AlgorithmParams> {
+    private static final Logger logger = LoggerFactory.getLogger(Model.class);
+
+    // Sub-directory names shared by save() and load() so the two cannot drift apart.
+    private static final String USER_FEATURES = "/userFeatures";
+    private static final String INDEX_ITEM_FEATURES = "/indexItemFeatures";
+    private static final String USER_INDEX = "/userIndex";
+    private static final String ITEM_INDEX = "/itemIndex";
+    private static final String ITEM_POPULARITY_SCORE = "/itemPopularityScore";
+    private static final String ITEMS = "/items";
+
+    private final JavaPairRDD<Integer, double[]> userFeatures;
+    private final JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures;
+    private final JavaPairRDD<String, Integer> userIndex;
+    private final JavaPairRDD<String, Integer> itemIndex;
+    private final JavaRDD<ItemScore> itemPopularityScore;
+    private final Map<String, Item> items;
+
+    /** Stores the given components as-is; no copies are made. */
+    public Model(JavaPairRDD<Integer, double[]> userFeatures, JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures, JavaPairRDD<String, Integer> userIndex, JavaPairRDD<String, Integer> itemIndex, JavaRDD<ItemScore> itemPopularityScore, Map<String, Item> items) {
+        this.userFeatures = userFeatures;
+        this.indexItemFeatures = indexItemFeatures;
+        this.userIndex = userIndex;
+        this.itemIndex = itemIndex;
+        this.itemPopularityScore = itemPopularityScore;
+        this.items = items;
+    }
+
+    public JavaPairRDD<Integer, double[]> getUserFeatures() {
+        return userFeatures;
+    }
+
+    public JavaPairRDD<Integer, Tuple2<String, double[]>> getIndexItemFeatures() {
+        return indexItemFeatures;
+    }
+
+    public JavaPairRDD<String, Integer> getUserIndex() {
+        return userIndex;
+    }
+
+    public JavaPairRDD<String, Integer> getItemIndex() {
+        return itemIndex;
+    }
+
+    public JavaRDD<ItemScore> getItemPopularityScore() {
+        return itemPopularityScore;
+    }
+
+    public Map<String, Item> getItems() {
+        return items;
+    }
+
+    /** Returns the directory under which all components of the model with this id are stored. */
+    private static String modelDir(String id) {
+        return "/tmp/" + id;
+    }
+
+    /**
+     * Writes every component of the model as an object file under {@code /tmp/<id>/}.
+     *
+     * @param id     model id; used as the directory name
+     * @param params algorithm parameters (not used here)
+     * @param sc     Spark context, used to turn the item map into an RDD
+     * @return always {@code true}
+     */
+    @Override
+    public boolean save(String id, AlgorithmParams params, SparkContext sc) {
+        String dir = modelDir(id);
+        userFeatures.saveAsObjectFile(dir + USER_FEATURES);
+        indexItemFeatures.saveAsObjectFile(dir + INDEX_ITEM_FEATURES);
+        userIndex.saveAsObjectFile(dir + USER_INDEX);
+        itemIndex.saveAsObjectFile(dir + ITEM_INDEX);
+        itemPopularityScore.saveAsObjectFile(dir + ITEM_POPULARITY_SCORE);
+        // The item map is a plain java.util.Map, so wrap it in a one-element RDD to save it the same way.
+        JavaSparkContext.fromSparkContext(sc).parallelize(Collections.singletonList(items)).saveAsObjectFile(dir + ITEMS);
+
+        logger.info("Saved model to {}", dir);
+        return true;
+    }
+
+    /**
+     * Reads back the components written by {@link #save(String, AlgorithmParams, SparkContext)}.
+     *
+     * @param id     model id that was used when saving
+     * @param params parameters (not used here)
+     * @param sc     Spark context used to read the object files
+     * @return the model rebuilt from the stored object files
+     */
+    public static Model load(String id, Params params, SparkContext sc) {
+        String dir = modelDir(id);
+        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
+        JavaPairRDD<Integer, double[]> userFeatures = JavaPairRDD.<Integer, double[]>fromJavaRDD(jsc.<Tuple2<Integer, double[]>>objectFile(dir + USER_FEATURES));
+        JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = JavaPairRDD.<Integer, Tuple2<String, double[]>>fromJavaRDD(jsc.<Tuple2<Integer, Tuple2<String, double[]>>>objectFile(dir + INDEX_ITEM_FEATURES));
+        JavaPairRDD<String, Integer> userIndex = JavaPairRDD.<String, Integer>fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile(dir + USER_INDEX));
+        JavaPairRDD<String, Integer> itemIndex = JavaPairRDD.<String, Integer>fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile(dir + ITEM_INDEX));
+        JavaRDD<ItemScore> itemPopularityScore = jsc.objectFile(dir + ITEM_POPULARITY_SCORE);
+        // save() stored the map as the only element of an RDD.
+        Map<String, Item> items = jsc.<Map<String, Item>>objectFile(dir + ITEMS).collect().get(0);
+
+        logger.info("loaded model");
+        return new Model(userFeatures, indexItemFeatures, userIndex, itemIndex, itemPopularityScore, items);
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/PredictedResult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/PredictedResult.java b/src/main/java/org/example/recommendation/PredictedResult.java
new file mode 100644
index 0000000..54d7ade
--- /dev/null
+++ b/src/main/java/org/example/recommendation/PredictedResult.java
@@ -0,0 +1,23 @@
+package org.example.recommendation;
+
+import java.io.Serializable;
+import java.util.List;
+
+/** Result returned for a query: a list of scored items. */
+public class PredictedResult implements Serializable {
+    private final List<ItemScore> itemScores;
+
+    /**
+     * @param itemScores scored items; stored as given (no copy is made)
+     */
+    public PredictedResult(List<ItemScore> itemScores) {
+        this.itemScores = itemScores;
+    }
+
+    /** @return the list passed to the constructor */
+    public List<ItemScore> getItemScores() {
+        return itemScores;
+    }
+
+    /** Renders the result as {@code PredictedResult{itemScores=...}}. */
+    @Override
+    public String toString() {
+        StringBuilder text = new StringBuilder("PredictedResult{");
+        text.append("itemScores=").append(itemScores);
+        text.append('}');
+        return text.toString();
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Preparator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/Preparator.java b/src/main/java/org/example/recommendation/Preparator.java
new file mode 100644
index 0000000..33beb50
--- /dev/null
+++ b/src/main/java/org/example/recommendation/Preparator.java
@@ -0,0 +1,12 @@
+package org.example.recommendation;
+
+import org.apache.predictionio.controller.java.PJavaPreparator;
+import org.apache.spark.SparkContext;
+
+/** Preparator that performs no transformation: it only wraps the training data in a {@link PreparedData}. */
+public class Preparator extends PJavaPreparator<TrainingData, PreparedData> {
+
+ /**
+ * @param sc Spark context (not used here)
+ * @param trainingData training data to wrap
+ * @return a {@link PreparedData} holding the given training data
+ */
+ @Override
+ public PreparedData prepare(SparkContext sc, TrainingData trainingData) {
+ return new PreparedData(trainingData);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/PreparedData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/PreparedData.java b/src/main/java/org/example/recommendation/PreparedData.java
new file mode 100644
index 0000000..802b7f2
--- /dev/null
+++ b/src/main/java/org/example/recommendation/PreparedData.java
@@ -0,0 +1,15 @@
+package org.example.recommendation;
+
+import java.io.Serializable;
+
+/** Serializable wrapper around a {@link TrainingData} instance. */
+public class PreparedData implements Serializable {
+ private final TrainingData trainingData;
+
+ /**
+ * @param trainingData training data to hold
+ */
+ public PreparedData(TrainingData trainingData) {
+ this.trainingData = trainingData;
+ }
+
+ /** @return the training data given at construction */
+ public TrainingData getTrainingData() {
+ return trainingData;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Query.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/Query.java b/src/main/java/org/example/recommendation/Query.java
new file mode 100644
index 0000000..977f566
--- /dev/null
+++ b/src/main/java/org/example/recommendation/Query.java
@@ -0,0 +1,55 @@
+package org.example.recommendation;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.Set;
+
+/**
+ * A recommendation query: the user id, the number of items requested and three
+ * optional id/category sets. The set getters never return {@code null}; a missing
+ * set is reported as an empty set.
+ */
+public class Query implements Serializable {
+    private final String userEntityId;
+    private final int number;
+    private final Set<String> categories;
+    private final Set<String> whitelist;
+    private final Set<String> blacklist;
+
+    /**
+     * @param userEntityId id of the user the query is for
+     * @param number       number of items requested
+     * @param categories   categories set, may be {@code null}
+     * @param whitelist    whitelist set, may be {@code null}
+     * @param blacklist    blacklist set, may be {@code null}
+     */
+    public Query(String userEntityId, int number, Set<String> categories, Set<String> whitelist, Set<String> blacklist) {
+        this.userEntityId = userEntityId;
+        this.number = number;
+        this.categories = categories;
+        this.whitelist = whitelist;
+        this.blacklist = blacklist;
+    }
+
+    /** Returns the given set, or an empty set when it is {@code null}. */
+    private static Set<String> orEmpty(Set<String> values) {
+        if (values != null) {
+            return values;
+        }
+        return Collections.emptySet();
+    }
+
+    /** @return id of the user the query is for */
+    public String getUserEntityId() {
+        return userEntityId;
+    }
+
+    /** @return number of items requested */
+    public int getNumber() {
+        return number;
+    }
+
+    /** @return the categories, or an empty set if none were set */
+    public Set<String> getCategories() {
+        return orEmpty(categories);
+    }
+
+    /** @return the whitelist, or an empty set if none was set */
+    public Set<String> getWhitelist() {
+        return orEmpty(whitelist);
+    }
+
+    /** @return the blacklist, or an empty set if none was set */
+    public Set<String> getBlacklist() {
+        return orEmpty(blacklist);
+    }
+
+    /** Renders the raw field values; unset sets appear as {@code null}. */
+    @Override
+    public String toString() {
+        StringBuilder text = new StringBuilder("Query{");
+        text.append("userEntityId='").append(userEntityId).append('\'');
+        text.append(", number=").append(number);
+        text.append(", categories=").append(categories);
+        text.append(", whitelist=").append(whitelist);
+        text.append(", blacklist=").append(blacklist);
+        text.append('}');
+        return text.toString();
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/RecommendationEngine.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/RecommendationEngine.java b/src/main/java/org/example/recommendation/RecommendationEngine.java
new file mode 100644
index 0000000..ead9aa7
--- /dev/null
+++ b/src/main/java/org/example/recommendation/RecommendationEngine.java
@@ -0,0 +1,23 @@
+package org.example.recommendation;
+
+import org.apache.predictionio.controller.EmptyParams;
+import org.apache.predictionio.controller.Engine;
+import org.apache.predictionio.controller.EngineFactory;
+import org.apache.predictionio.core.BaseAlgorithm;
+import org.apache.predictionio.core.BaseEngine;
+
+import java.util.Collections;
+import java.util.Set;
+
+/**
+ * Engine factory for the recommender: wires together {@link DataSource},
+ * {@link Preparator}, {@link Algorithm} and {@link Serving}.
+ */
+public class RecommendationEngine extends EngineFactory {
+
+ /**
+ * @return an engine built from the four component classes, with the algorithm registered under the name "algo"
+ */
+ @Override
+ public BaseEngine<EmptyParams, Query, PredictedResult, Set<String>> apply() {
+ return new Engine<>(
+ DataSource.class,
+ Preparator.class,
+ // Explicit type arguments are given for the single-entry algorithm map.
+ Collections.<String, Class<? extends BaseAlgorithm<PreparedData, ?, Query, PredictedResult>>>singletonMap("algo", Algorithm.class),
+ Serving.class
+ );
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/Serving.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/Serving.java b/src/main/java/org/example/recommendation/Serving.java
new file mode 100644
index 0000000..80d6c83
--- /dev/null
+++ b/src/main/java/org/example/recommendation/Serving.java
@@ -0,0 +1,12 @@
+package org.example.recommendation;
+
+import org.apache.predictionio.controller.java.LJavaServing;
+import scala.collection.Seq;
+
+/** Serving component that passes through the first prediction it is given. */
+public class Serving extends LJavaServing<Query, PredictedResult> {
+
+ /**
+ * @param query the query being served (not used here)
+ * @param predictions predictions for the query
+ * @return the first element of {@code predictions}
+ */
+ @Override
+ public PredictedResult serve(Query query, Seq<PredictedResult> predictions) {
+ return predictions.head();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/TrainingData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/TrainingData.java b/src/main/java/org/example/recommendation/TrainingData.java
new file mode 100644
index 0000000..35af8a0
--- /dev/null
+++ b/src/main/java/org/example/recommendation/TrainingData.java
@@ -0,0 +1,50 @@
+package org.example.recommendation;
+
+import org.apache.predictionio.controller.SanityCheck;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+
+import java.io.Serializable;
+
+/** Training input: user and item RDDs keyed by id, plus the view and buy event RDDs. */
+public class TrainingData implements Serializable, SanityCheck {
+    private final JavaPairRDD<String, User> users;
+    private final JavaPairRDD<String, Item> items;
+    private final JavaRDD<UserItemEvent> viewEvents;
+    private final JavaRDD<UserItemEvent> buyEvents;
+
+    /** Stores the four RDDs as given. */
+    public TrainingData(JavaPairRDD<String, User> users, JavaPairRDD<String, Item> items, JavaRDD<UserItemEvent> viewEvents, JavaRDD<UserItemEvent> buyEvents) {
+        this.users = users;
+        this.items = items;
+        this.viewEvents = viewEvents;
+        this.buyEvents = buyEvents;
+    }
+
+    /** @return users keyed by id */
+    public JavaPairRDD<String, User> getUsers() {
+        return users;
+    }
+
+    /** @return items keyed by id */
+    public JavaPairRDD<String, Item> getItems() {
+        return items;
+    }
+
+    /** @return the view events */
+    public JavaRDD<UserItemEvent> getViewEvents() {
+        return viewEvents;
+    }
+
+    /** @return the buy events */
+    public JavaRDD<UserItemEvent> getBuyEvents() {
+        return buyEvents;
+    }
+
+    /** Throws an {@link AssertionError} with the given message when {@code empty} is true. */
+    private static void failIfEmpty(boolean empty, String message) {
+        if (empty) {
+            throw new AssertionError(message);
+        }
+    }
+
+    /**
+     * Fails with an {@link AssertionError} if users, items or view events are empty,
+     * checked in that order. Buy events are not checked.
+     */
+    @Override
+    public void sanityCheck() {
+        failIfEmpty(users.isEmpty(), "User data is empty");
+        failIfEmpty(items.isEmpty(), "Item data is empty");
+        failIfEmpty(viewEvents.isEmpty(), "View Event data is empty");
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/User.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/User.java b/src/main/java/org/example/recommendation/User.java
new file mode 100644
index 0000000..d187a20
--- /dev/null
+++ b/src/main/java/org/example/recommendation/User.java
@@ -0,0 +1,30 @@
+package org.example.recommendation;
+
+import java.io.Serializable;
+import java.util.Map;
+
+/** Serializable value holder for a user: its entity id and a map of string properties. */
+public class User implements Serializable {
+    private final String entityId;
+    private final Map<String, String> properties;
+
+    /**
+     * @param entityId   id of the user entity
+     * @param properties properties of the user; stored as given (no copy is made)
+     */
+    public User(String entityId, Map<String, String> properties) {
+        this.properties = properties;
+        this.entityId = entityId;
+    }
+
+    /** @return the properties passed to the constructor */
+    public Map<String, String> getProperties() {
+        return properties;
+    }
+
+    /** @return the entity id passed to the constructor */
+    public String getEntityId() {
+        return entityId;
+    }
+
+    /** Renders both fields in the form {@code User{entityId='...', properties=...}}. */
+    @Override
+    public String toString() {
+        StringBuilder text = new StringBuilder("User{");
+        text.append("entityId='").append(entityId).append('\'');
+        text.append(", properties=").append(properties);
+        text.append('}');
+        return text.toString();
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/UserItemEvent.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/UserItemEvent.java b/src/main/java/org/example/recommendation/UserItemEvent.java
new file mode 100644
index 0000000..d548a18
--- /dev/null
+++ b/src/main/java/org/example/recommendation/UserItemEvent.java
@@ -0,0 +1,43 @@
+package org.example.recommendation;
+
+import java.io.Serializable;
+
+/** A single user-to-item event: who, which item, when, and the event type. */
+public class UserItemEvent implements Serializable {
+    private final String user;
+    private final String item;
+    private final long time;
+    private final UserItemEventType type;
+
+    /**
+     * @param user id of the user
+     * @param item id of the item
+     * @param time event time as a long value
+     * @param type type of the event
+     */
+    public UserItemEvent(String user, String item, long time, UserItemEventType type) {
+        this.user = user;
+        this.item = item;
+        this.time = time;
+        this.type = type;
+    }
+
+    /** @return id of the user */
+    public String getUser() {
+        return user;
+    }
+
+    /** @return id of the item */
+    public String getItem() {
+        return item;
+    }
+
+    /** @return event time as given at construction */
+    public long getTime() {
+        return time;
+    }
+
+    /** @return type of the event */
+    public UserItemEventType getType() {
+        return type;
+    }
+
+    /** Renders all four fields in declaration order. */
+    @Override
+    public String toString() {
+        StringBuilder text = new StringBuilder("UserItemEvent{");
+        text.append("user='").append(user).append('\'');
+        text.append(", item='").append(item).append('\'');
+        text.append(", time=").append(time);
+        text.append(", type=").append(type);
+        text.append('}');
+        return text.toString();
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/UserItemEventType.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/UserItemEventType.java b/src/main/java/org/example/recommendation/UserItemEventType.java
new file mode 100644
index 0000000..f86b411
--- /dev/null
+++ b/src/main/java/org/example/recommendation/UserItemEventType.java
@@ -0,0 +1,5 @@
+package org.example.recommendation;
+
+/** Kinds of user-to-item events: a view or a buy. */
+public enum UserItemEventType {
+ VIEW, BUY
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java b/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java
new file mode 100644
index 0000000..33028eb
--- /dev/null
+++ b/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java
@@ -0,0 +1,28 @@
+package org.example.recommendation.evaluation;
+
+import org.apache.predictionio.controller.EmptyParams;
+import org.apache.predictionio.controller.EngineParams;
+import org.apache.predictionio.controller.java.JavaEngineParamsGenerator;
+import org.example.recommendation.AlgorithmParams;
+import org.example.recommendation.DataSourceParams;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+/**
+ * Supplies the engine parameters used for evaluation: a list containing a single
+ * {@link EngineParams} instance.
+ */
+public class EvaluationParameter extends JavaEngineParamsGenerator {
+ public EvaluationParameter() {
+ this.setEngineParamsList(
+ Collections.singletonList(
+ new EngineParams(
+ "",
+ // Data source reads from the app named "javadase".
+ new DataSourceParams("javadase"),
+ "",
+ new EmptyParams(),
+ // Parameters for the algorithm registered as "algo".
+ // NOTE(review): the meaning of each AlgorithmParams argument is defined in AlgorithmParams — confirm there.
+ Collections.singletonMap("algo", new AlgorithmParams(1, 10, 10, 0.01, "javadase", Collections.singletonList("view"), true, Arrays.asList("buy", "view"))),
+ "",
+ new EmptyParams()
+ )
+ )
+ );
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java b/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java
new file mode 100644
index 0000000..2bafc7b
--- /dev/null
+++ b/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java
@@ -0,0 +1,28 @@
+package org.example.recommendation.evaluation;
+
+import org.apache.predictionio.controller.Engine;
+import org.apache.predictionio.controller.java.JavaEvaluation;
+import org.apache.predictionio.core.BaseAlgorithm;
+import org.example.recommendation.Algorithm;
+import org.example.recommendation.DataSource;
+import org.example.recommendation.PredictedResult;
+import org.example.recommendation.Preparator;
+import org.example.recommendation.PreparedData;
+import org.example.recommendation.Query;
+import org.example.recommendation.Serving;
+
+import java.util.Collections;
+
+/**
+ * Evaluation definition: pairs an engine built from {@link DataSource},
+ * {@link Preparator}, {@link Algorithm} and {@link Serving} with {@link PrecisionMetric}.
+ */
+public class EvaluationSpec extends JavaEvaluation {
+ public EvaluationSpec() {
+ this.setEngineMetric(
+ new Engine<>(
+ DataSource.class,
+ Preparator.class,
+ // Single algorithm, registered under the name "algo".
+ Collections.<String, Class<? extends BaseAlgorithm<PreparedData, ?, Query, PredictedResult>>>singletonMap("algo", Algorithm.class),
+ Serving.class
+ ),
+ new PrecisionMetric()
+ );
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java b/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java
new file mode 100644
index 0000000..e412fd5
--- /dev/null
+++ b/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java
@@ -0,0 +1,62 @@
+package org.example.recommendation.evaluation;
+
+import org.apache.predictionio.controller.EmptyParams;
+import org.apache.predictionio.controller.Metric;
+import org.apache.predictionio.controller.java.SerializableComparator;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.rdd.RDD;
+import org.example.recommendation.ItemScore;
+import org.example.recommendation.PredictedResult;
+import org.example.recommendation.Query;
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Mean precision over all evaluated queries: for each query, the fraction of
+ * predicted items that appear in the actual item set, averaged over every query
+ * in every evaluation set.
+ */
+public class PrecisionMetric extends Metric<EmptyParams, Query, PredictedResult, Set<String>, Double> {
+
+    /** Orders metric results by the natural ordering of {@code Double}. */
+    private static final class MetricComparator implements SerializableComparator<Double> {
+        @Override
+        public int compare(Double o1, Double o2) {
+            return o1.compareTo(o2);
+        }
+    }
+
+    /**
+     * Computes the precision of one (query, prediction, actual) triple.
+     * Declared as a static nested class so the function shipped to Spark does not
+     * hold a reference to the enclosing metric instance.
+     */
+    private static final class QueryPrecision implements Function<Tuple3<Query, PredictedResult, Set<String>>, Double> {
+        @Override
+        public Double call(Tuple3<Query, PredictedResult, Set<String>> qpa) throws Exception {
+            List<ItemScore> itemScores = qpa._2().getItemScores();
+            if (itemScores.isEmpty()) {
+                // No predictions: count as precision 0 instead of 0/0 = NaN, which would poison the average.
+                return 0.0;
+            }
+            Set<String> hits = new HashSet<>();
+            for (ItemScore itemScore : itemScores) {
+                hits.add(itemScore.getItemEntityId());
+            }
+            // Keep only the predicted ids that are also in the actual set.
+            hits.retainAll(qpa._3());
+
+            return 1.0 * hits.size() / itemScores.size();
+        }
+    }
+
+    public PrecisionMetric() {
+        super(new MetricComparator());
+    }
+
+    /**
+     * Averages the per-query precision over all evaluation sets.
+     *
+     * @param sc   Spark context (not used here)
+     * @param qpas evaluation sets, each holding an RDD of (query, predicted result, actual item ids)
+     * @return the mean per-query precision, or 0.0 when there are no queries at all
+     */
+    @Override
+    public Double calculate(SparkContext sc, Seq<Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>>> qpas) {
+        List<Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>>> sets = JavaConversions.asJavaList(qpas);
+        List<Double> allSetResults = new ArrayList<>();
+
+        for (Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>> set : sets) {
+            // Per-query results are collected to the driver and averaged there.
+            allSetResults.addAll(set._2().toJavaRDD().map(new QueryPrecision()).collect());
+        }
+        if (allSetResults.isEmpty()) {
+            // Avoid 0.0 / 0 = NaN when nothing was evaluated.
+            return 0.0;
+        }
+        double sum = 0.0;
+        for (Double value : allSetResults) {
+            sum += value;
+        }
+
+        return sum / allSetResults.size();
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/template/recommendation/Algorithm.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/template/recommendation/Algorithm.java b/src/main/java/org/template/recommendation/Algorithm.java
deleted file mode 100644
index 24b4e5c..0000000
--- a/src/main/java/org/template/recommendation/Algorithm.java
+++ /dev/null
@@ -1,409 +0,0 @@
-package org.template.recommendation;
-
-import com.google.common.collect.Sets;
-import io.prediction.controller.java.PJavaAlgorithm;
-import io.prediction.data.storage.Event;
-import io.prediction.data.store.java.LJavaEventStore;
-import io.prediction.data.store.java.OptionHelper;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.mllib.recommendation.ALS;
-import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
-import org.apache.spark.mllib.recommendation.Rating;
-import org.apache.spark.rdd.RDD;
-import org.jblas.DoubleMatrix;
-import org.joda.time.DateTime;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import scala.Option;
-import scala.Tuple2;
-import scala.concurrent.duration.Duration;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.TimeUnit;
-
-public class Algorithm extends PJavaAlgorithm<PreparedData, Model, Query, PredictedResult> {
-
- private static final Logger logger = LoggerFactory.getLogger(Algorithm.class);
- private final AlgorithmParams ap;
-
- /**
-  * Creates the algorithm and keeps the supplied parameters for later use.
-  *
-  * @param ap the algorithm parameters stored in the {@code ap} field
-  */
- public Algorithm(AlgorithmParams ap) {
- this.ap = ap;
- }
-
- /**
-  * Trains the recommendation model from the prepared data.
-  * <p>
-  * User and item entity IDs are mapped to integer indices (their {@code zipWithIndex} position),
-  * view events are aggregated into per-(user, item) counts that are passed to
-  * {@code ALS.trainImplicit}, and buy events are counted per item to obtain popularity scores.
-  *
-  * @param sc           the Spark context (not used directly here)
-  * @param preparedData wrapper around the training data
-  * @return a model built from the user features, the indexed item features, both index mappings,
-  *         the popularity scores and the collected item map
-  * @throws AssertionError if no view event refers to both a known user and a known item
-  */
- @Override
- public Model train(SparkContext sc, PreparedData preparedData) {
-     final TrainingData trainingData = preparedData.getTrainingData();
-
-     // Small functions shared by several pipelines below.
-     final PairFunction<Tuple2<String, Long>, String, Integer> positionToInt = new PairFunction<Tuple2<String, Long>, String, Integer>() {
-         @Override
-         public Tuple2<String, Integer> call(Tuple2<String, Long> idAndPosition) throws Exception {
-             return new Tuple2<>(idAndPosition._1(), idAndPosition._2().intValue());
-         }
-     };
-     final Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean> isPresent = new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
-         @Override
-         public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> candidate) throws Exception {
-             return (candidate != null);
-         }
-     };
-     final Function2<Integer, Integer, Integer> sum = new Function2<Integer, Integer, Integer>() {
-         @Override
-         public Integer call(Integer left, Integer right) throws Exception {
-             return left + right;
-         }
-     };
-     final PairFunction<Tuple2<Object, double[]>, Integer, double[]> keyAsInteger = new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
-         @Override
-         public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> idAndFeatures) throws Exception {
-             return new Tuple2<>((Integer) idAndFeatures._1(), idAndFeatures._2());
-         }
-     };
-
-     // Users: entity ID -> integer index.
-     JavaRDD<String> userIds = trainingData.getUsers().map(new Function<Tuple2<String, User>, String>() {
-         @Override
-         public String call(Tuple2<String, User> userEntry) throws Exception {
-             return userEntry._1();
-         }
-     });
-     JavaPairRDD<String, Integer> userIndexRDD = userIds.zipWithIndex().mapToPair(positionToInt);
-     final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap();
-
-     // Items: entity ID -> integer index, plus the reverse mapping.
-     JavaRDD<String> itemIds = trainingData.getItems().map(new Function<Tuple2<String, Item>, String>() {
-         @Override
-         public String call(Tuple2<String, Item> itemEntry) throws Exception {
-             return itemEntry._1();
-         }
-     });
-     JavaPairRDD<String, Integer> itemIndexRDD = itemIds.zipWithIndex().mapToPair(positionToInt);
-     final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap();
-     JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() {
-         @Override
-         public Tuple2<Integer, String> call(Tuple2<String, Integer> idAndIndex) throws Exception {
-             return idAndIndex.swap();
-         }
-     });
-     final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap();
-
-     // Turns an event into ((userIndex, itemIndex), 1); yields null when either ID is unknown.
-     final PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer> toIndexedOccurrence = new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
-         @Override
-         public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent event) throws Exception {
-             Integer userIndex = userIndexMap.get(event.getUser());
-             Integer itemIndex = itemIndexMap.get(event.getItem());
-             if (userIndex == null || itemIndex == null) {
-                 return null;
-             }
-             return new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
-         }
-     };
-
-     // Ratings: number of view events per (user, item) pair.
-     JavaPairRDD<Tuple2<Integer, Integer>, Integer> viewCounts = trainingData.getViewEvents()
-             .mapToPair(toIndexedOccurrence)
-             .filter(isPresent)
-             .reduceByKey(sum);
-     JavaRDD<Rating> ratings = viewCounts.map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() {
-         @Override
-         public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> pairAndCount) throws Exception {
-             Tuple2<Integer, Integer> userAndItem = pairAndCount._1();
-             return new Rating(userAndItem._1(), userAndItem._2(), pairAndCount._2().doubleValue());
-         }
-     });
-
-     if (ratings.isEmpty()) {
-         throw new AssertionError("Please check if your events contain valid user and item ID.");
-     }
-
-     // MLlib implicit-feedback ALS; the literals -1 and 1.0 are passed through unchanged.
-     MatrixFactorizationModel factorization = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed());
-     JavaPairRDD<Integer, double[]> userFeatures = factorization.userFeatures().toJavaRDD().mapToPair(keyAsInteger);
-     JavaPairRDD<Integer, double[]> productFeaturesRDD = factorization.productFeatures().toJavaRDD().mapToPair(keyAsInteger);
-
-     // Popularity: number of buy events per item, reported under the item's entity ID.
-     JavaPairRDD<Integer, Integer> buysPerItem = trainingData.getBuyEvents()
-             .mapToPair(toIndexedOccurrence)
-             .filter(isPresent)
-             .mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() {
-                 @Override
-                 public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> occurrence) throws Exception {
-                     return new Tuple2<>(occurrence._1()._2(), occurrence._2());
-                 }
-             })
-             .reduceByKey(sum);
-     JavaRDD<ItemScore> itemPopularityScore = buysPerItem.map(new Function<Tuple2<Integer, Integer>, ItemScore>() {
-         @Override
-         public ItemScore call(Tuple2<Integer, Integer> indexAndCount) throws Exception {
-             return new ItemScore(indexItemMap.get(indexAndCount._1()), indexAndCount._2().doubleValue());
-         }
-     });
-
-     JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD);
-     Map<String, Item> itemsById = trainingData.getItems().collectAsMap();
-
-     return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, itemsById);
- }
-
- @Override
- public PredictedResult predict(Model model, final Query query) {
- final JavaPairRDD<String, Integer> matchedUser = model.getUserIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
- @Override
- public Boolean call(Tuple2<String, Integer> userIndex) throws Exception {
- return userIndex._1().equals(query.getUserEntityId());
- }
- });
-
- double[] userFeature = null;
- if (!matchedUser.isEmpty()) {
- final Integer matchedUserIndex = matchedUser.first()._2();
- userFeature = model.getUserFeatures().filter(new Function<Tuple2<Integer, double[]>, Boolean>() {
- @Override
- public Boolean call(Tuple2<Integer, double[]> element) throws Exception {
- return element._1().equals(matchedUserIndex);
- }
- }).first()._2();
- }
-
- if (userFeature != null) {
- return new PredictedResult(topItemsForUser(userFeature, model, query));
- } else {
- List<double[]> recentProductFeatures = getRecentProductFeatures(query, model);
- if (recentProductFeatures.isEmpty()) {
- return new PredictedResult(mostPopularItems(model, query));
- } else {
- return new PredictedResult(similarItems(recentProductFeatures, model, query));
- }
- }
- }
-
- /**
-  * Answers a batch of indexed queries by collecting them and calling
-  * {@link #predict(Model, Query)} on each one in turn.
-  *
-  * @param model the trained model
-  * @param qs    RDD of (index, query) pairs
-  * @return an RDD of (index, prediction) pairs created on the Spark context of {@code qs}
-  */
- @Override
- public RDD<Tuple2<Object, PredictedResult>> batchPredict(Model model, RDD<Tuple2<Object, Query>> qs) {
-     List<Tuple2<Object, Query>> collectedQueries = qs.toJavaRDD().collect();
-     List<Tuple2<Object, PredictedResult>> predictions = new ArrayList<>();
-
-     for (Tuple2<Object, Query> indexedQuery : collectedQueries) {
-         Object queryIndex = indexedQuery._1();
-         PredictedResult prediction = predict(model, indexedQuery._2());
-         predictions.add(new Tuple2<>(queryIndex, prediction));
-     }
-
-     JavaSparkContext javaContext = new JavaSparkContext(qs.sparkContext());
-     return javaContext.parallelize(predictions).rdd();
- }
-
- /**
-  * Looks up the feature vectors of the items targeted by the user's most recent events.
-  * <p>
-  * Up to 10 events of the configured "similar item" event names are read from the event store.
-  * Events without a target, items missing from the trained item index, and items without a
-  * feature vector are skipped.
-  *
-  * @param query the query whose user is looked up
-  * @param model the trained model
-  * @return the feature vectors that were found, possibly empty
-  * @throws RuntimeException wrapping any failure while reading events or querying the model
-  */
- private List<double[]> getRecentProductFeatures(Query query, Model model) {
-     try {
-         List<double[]> result = new ArrayList<>();
-
-         List<Event> events = LJavaEventStore.findByEntity(
-                 ap.getAppName(),
-                 "user",
-                 query.getUserEntityId(),
-                 OptionHelper.<String>none(),
-                 OptionHelper.some(ap.getSimilarItemEvents()),
-                 OptionHelper.some(OptionHelper.some("item")),
-                 OptionHelper.<Option<String>>none(),
-                 OptionHelper.<DateTime>none(),
-                 OptionHelper.<DateTime>none(),
-                 OptionHelper.some(10),
-                 true,
-                 Duration.apply(10, TimeUnit.SECONDS));
-
-         for (Event event : events) {
-             if (!event.targetEntityId().isDefined()) {
-                 continue;
-             }
-             // Pull the ID out so the Spark closure captures a String rather than the whole event.
-             final String targetItemId = event.targetEntityId().get();
-
-             List<Tuple2<String, Integer>> matchedItems = model.getItemIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
-                 @Override
-                 public Boolean call(Tuple2<String, Integer> element) throws Exception {
-                     return element._1().equals(targetItemId);
-                 }
-             }).take(1);
-
-             // The emptiness check must come before reading the match: first() on an empty RDD
-             // throws, which used to fail the whole query for any item unknown to the model.
-             if (matchedItems.isEmpty()) {
-                 continue;
-             }
-             final Integer itemIndex = matchedItems.get(0)._2();
-
-             List<Tuple2<Integer, Tuple2<String, double[]>>> oneIndexItemFeatures = model.getIndexItemFeatures().filter(new Function<Tuple2<Integer, Tuple2<String, double[]>>, Boolean>() {
-                 @Override
-                 public Boolean call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
-                     return itemIndex.equals(element._1());
-                 }
-             }).take(1);
-
-             if (!oneIndexItemFeatures.isEmpty()) {
-                 result.add(oneIndexItemFeatures.get(0)._2()._2());
-             }
-         }
-
-         return result;
-     } catch (Exception e) {
-         logger.error("Error reading recent events for user " + query.getUserEntityId());
-         throw new RuntimeException(e.getMessage(), e);
-     }
- }
-
- private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) {
- final DoubleMatrix userMatrix = new DoubleMatrix(userFeature);
-
- JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
- @Override
- public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
- return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2())));
- }
- });
-
- itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
- return sortAndTake(itemScores, query.getNumber());
- }
-
- private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) {
- JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
- @Override
- public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
- double similarity = 0.0;
- for (double[] recentFeature : recentProductFeatures) {
- similarity += cosineSimilarity(element._2()._2(), recentFeature);
- }
-
- return new ItemScore(element._2()._1(), similarity);
- }
- });
-
- itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
- return sortAndTake(itemScores, query.getNumber());
- }
-
- /**
-  * Returns the items with the highest popularity score that pass the query filters.
-  *
-  * @param model the trained model providing the popularity scores
-  * @param query the query supplying the filters and the number of items wanted
-  * @return at most {@code query.getNumber()} item scores, highest score first
-  */
- private List<ItemScore> mostPopularItems(Model model, Query query) {
-     JavaRDD<ItemScore> popularity = model.getItemPopularityScore();
-     JavaRDD<ItemScore> eligibleScores = validScores(popularity, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
-     return sortAndTake(eligibleScores, query.getNumber());
- }
-
- /**
-  * Computes the cosine similarity of two vectors: their dot product divided by the product of
-  * their Euclidean norms.
-  *
-  * @param a first vector
-  * @param b second vector
-  * @return the cosine similarity of {@code a} and {@code b}
-  */
- private double cosineSimilarity(double[] a, double[] b) {
-     DoubleMatrix left = new DoubleMatrix(a);
-     DoubleMatrix right = new DoubleMatrix(b);
-
-     double dotProduct = left.dot(right);
-     double normProduct = left.norm2() * right.norm2();
-     return dotProduct / normProduct;
- }
-
- /**
-  * Sorts the scores in descending order and returns the leading ones.
-  *
-  * @param all    the scores to sort
-  * @param number how many scores to return at most
-  * @return the {@code number} highest item scores, highest first
-  */
- private List<ItemScore> sortAndTake(JavaRDD<ItemScore> all, int number) {
-     Function<ItemScore, Double> byScore = new Function<ItemScore, Double>() {
-         @Override
-         public Double call(ItemScore candidate) throws Exception {
-             return candidate.getScore();
-         }
-     };
-     // keep the partition count of the input; "false" requests descending order
-     int partitionCount = all.partitions().size();
-     JavaRDD<ItemScore> descending = all.sortBy(byScore, false, partitionCount);
-     return descending.take(number);
- }
-
- /**
-  * Keeps only the scores whose item is known and passes every filter: whitelist, blacklist,
-  * categories, not-yet-seen and availability.
-  *
-  * @param all          the scores to filter
-  * @param whitelist    item IDs to restrict to; an empty set disables the restriction
-  * @param blacklist    item IDs to exclude
-  * @param categories   categories to restrict to; an empty set disables the restriction
-  * @param items        known items by entity ID
-  * @param userEntityId the user whose seen items are looked up
-  * @return the filtered scores
-  */
- private JavaRDD<ItemScore> validScores(JavaRDD<ItemScore> all, final Set<String> whitelist, final Set<String> blacklist, final Set<String> categories, final Map<String, Item> items, String userEntityId) {
-     // Both lookups read the event store once, before the filter runs.
-     final Set<String> seenIds = seenItemEntityIds(userEntityId);
-     final Set<String> unavailableIds = unavailableItemEntityIds();
-
-     Function<ItemScore, Boolean> isEligible = new Function<ItemScore, Boolean>() {
-         @Override
-         public Boolean call(ItemScore itemScore) throws Exception {
-             Item item = items.get(itemScore.getItemEntityId());
-             if (item == null) {
-                 return false;
-             }
-             if (!passWhitelistCriteria(whitelist, item.getEntityId())) {
-                 return false;
-             }
-             if (!passBlacklistCriteria(blacklist, item.getEntityId())) {
-                 return false;
-             }
-             if (!passCategoryCriteria(categories, item)) {
-                 return false;
-             }
-             if (!passUnseenCriteria(seenIds, item.getEntityId())) {
-                 return false;
-             }
-             return passAvailabilityCriteria(unavailableIds, item.getEntityId());
-         }
-     };
-
-     return all.filter(isEligible);
- }
-
- /** Returns true when the whitelist is empty or contains the given item ID. */
- private boolean passWhitelistCriteria(Set<String> whitelist, String itemEntityId) {
-     if (whitelist.isEmpty()) {
-         return true;
-     }
-     return whitelist.contains(itemEntityId);
- }
-
- /** Returns true when the blacklist does not contain the given item ID. */
- private boolean passBlacklistCriteria(Set<String> blacklist, String itemEntityId) {
-     boolean blacklisted = blacklist.contains(itemEntityId);
-     return !blacklisted;
- }
-
- /** Returns true when no categories are requested or the item shares at least one of them. */
- private boolean passCategoryCriteria(Set<String> categories, Item item) {
-     if (categories.isEmpty()) {
-         return true;
-     }
-     return !Sets.intersection(categories, item.getCategories()).isEmpty();
- }
-
- /** Returns true when the given item ID is not among the seen item IDs. */
- private boolean passUnseenCriteria(Set<String> seen, String itemEntityId) {
-     boolean alreadySeen = seen.contains(itemEntityId);
-     return !alreadySeen;
- }
-
- /** Returns true when the given item ID is not among the unavailable item IDs. */
- private boolean passAvailabilityCriteria(Set<String> unavailableItemEntityIds, String entityId) {
-     boolean unavailable = unavailableItemEntityIds.contains(entityId);
-     return !unavailable;
- }
-
- /**
-  * Reads a single {@code $set} event of the {@code constraint} entity {@code unavailableItems}
-  * from the event store and returns the item IDs listed in its {@code items} property.
-  *
-  * @return the unavailable item IDs, or an empty set when no such event exists
-  * @throws RuntimeException wrapping any failure while reading or decoding the event
-  */
- private Set<String> unavailableItemEntityIds() {
-     try {
-         List<Event> constraintEvents = LJavaEventStore.findByEntity(
-                 ap.getAppName(),
-                 "constraint",
-                 "unavailableItems",
-                 OptionHelper.<String>none(),
-                 OptionHelper.some(Collections.singletonList("$set")),
-                 OptionHelper.<Option<String>>none(),
-                 OptionHelper.<Option<String>>none(),
-                 OptionHelper.<DateTime>none(),
-                 OptionHelper.<DateTime>none(),
-                 OptionHelper.some(1),
-                 true,
-                 Duration.apply(10, TimeUnit.SECONDS));
-
-         if (constraintEvents.isEmpty()) {
-             return Collections.emptySet();
-         }
-
-         Event constraint = constraintEvents.get(0);
-         List<String> listedItems = constraint.properties().getStringList("items");
-         return new HashSet<>(listedItems);
-     } catch (Exception e) {
-         logger.error("Error reading constraint events");
-         throw new RuntimeException(e.getMessage(), e);
-     }
- }
-
- /**
-  * Collects the IDs of the items targeted by the user's "seen" events.
-  * <p>
-  * When the {@code unseenOnly} parameter is off, the event store is not consulted and an empty
-  * set is returned.
-  *
-  * @param userEntityId the user whose events are read
-  * @return the IDs of the items the user has already seen
-  * @throws RuntimeException wrapping any failure while reading the events
-  */
- private Set<String> seenItemEntityIds(String userEntityId) {
-     if (!ap.isUnseenOnly()) {
-         return Collections.emptySet();
-     }
-
-     try {
-         List<Event> seenEvents = LJavaEventStore.findByEntity(
-                 ap.getAppName(),
-                 "user",
-                 userEntityId,
-                 OptionHelper.<String>none(),
-                 OptionHelper.some(ap.getSeenItemEvents()),
-                 OptionHelper.some(OptionHelper.some("item")),
-                 OptionHelper.<Option<String>>none(),
-                 OptionHelper.<DateTime>none(),
-                 OptionHelper.<DateTime>none(),
-                 OptionHelper.<Integer>none(),
-                 true,
-                 Duration.apply(10, TimeUnit.SECONDS));
-
-         Set<String> seenIds = new HashSet<>();
-         for (Event seenEvent : seenEvents) {
-             seenIds.add(seenEvent.targetEntityId().get());
-         }
-         return seenIds;
-     } catch (Exception e) {
-         logger.error("Error reading seen events for user " + userEntityId);
-         throw new RuntimeException(e.getMessage(), e);
-     }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio-template-java-ecom-recommender/blob/36995dfc/src/main/java/org/template/recommendation/AlgorithmParams.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/template/recommendation/AlgorithmParams.java b/src/main/java/org/template/recommendation/AlgorithmParams.java
deleted file mode 100644
index 0466334..0000000
--- a/src/main/java/org/template/recommendation/AlgorithmParams.java
+++ /dev/null
@@ -1,74 +0,0 @@
-package org.template.recommendation;
-
-import io.prediction.controller.Params;
-
-import java.util.List;
-
-public class AlgorithmParams implements Params{
- private final long seed;
- private final int rank;
- private final int iteration;
- private final double lambda;
- private final String appName;
- private final List<String> similarItemEvents;
- private final boolean unseenOnly;
- private final List<String> seenItemEvents;
-
-
- public AlgorithmParams(long seed, int rank, int iteration, double lambda, String appName, List<String> similarItemEvents, boolean unseenOnly, List<String> seenItemEvents) {
- this.seed = seed;
- this.rank = rank;
- this.iteration = iteration;
- this.lambda = lambda;
- this.appName = appName;
- this.similarItemEvents = similarItemEvents;
- this.unseenOnly = unseenOnly;
- this.seenItemEvents = seenItemEvents;
- }
-
- public long getSeed() {
- return seed;
- }
-
- public int getRank() {
- return rank;
- }
-
- public int getIteration() {
- return iteration;
- }
-
- public double getLambda() {
- return lambda;
- }
-
- public String getAppName() {
- return appName;
- }
-
- public List<String> getSimilarItemEvents() {
- return similarItemEvents;
- }
-
- public boolean isUnseenOnly() {
- return unseenOnly;
- }
-
- public List<String> getSeenItemEvents() {
- return seenItemEvents;
- }
-
- @Override
- public String toString() {
- return "AlgorithmParams{" +
- "seed=" + seed +
- ", rank=" + rank +
- ", iteration=" + iteration +
- ", lambda=" + lambda +
- ", appName='" + appName + '\'' +
- ", similarItemEvents=" + similarItemEvents +
- ", unseenOnly=" + unseenOnly +
- ", seenItemEvents=" + seenItemEvents +
- '}';
- }
-}