Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/28 03:17:27 UTC
[1/3] incubator-hivemall git commit: Close #117,
Close #111: [HIVEMALL-17] Support SLIM neighborhood-learning
recommendation algorithm
Repository: incubator-hivemall
Updated Branches:
refs/heads/master c2b95783c -> 995b9a885
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_slim.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_slim.md b/docs/gitbook/recommend/movielens_slim.md
new file mode 100644
index 0000000..60d52b3
--- /dev/null
+++ b/docs/gitbook/recommend/movielens_slim.md
@@ -0,0 +1,589 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+Hivemall supports a neighborhood-learning scheme using SLIM.
+SLIM is a representative neighborhood-learning recommendation algorithm, introduced in the following paper:
+
+- Xia Ning and George Karypis, [SLIM: Sparse Linear Methods for Top-N Recommender Systems](https://dl.acm.org/citation.cfm?id=2118303), Proc. ICDM, 2011.
+
+_Caution: SLIM is supported in Hivemall v0.5-rc.1 and later._
+
+<!-- toc -->
+
+# SLIM optimization objective
+
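+The optimization objective of [SLIM](http://glaros.dtc.umn.edu/gkhome/fetch/papers/SLIM2011icdm.pdf) is similar to Elastic Net (L1+L2 regularization) with additional constraints. For each item $$j$$, SLIM learns the $$j$$-th column $$w_{j}$$ of the item-item weight matrix $$W$$ from the $$j$$-th column $$r_{j}$$ of the user-item matrix $$R$$ as follows: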
+
+$$
+\begin{aligned}
+& \;{\tiny\begin{matrix}\\ \normalsize \text{minimize} \\ ^{\scriptsize w_{j}}\end{matrix}}\;
+&& \frac{1}{2}\Vert r_{j} - Rw_{j} \Vert_2^2 + \frac{\beta}{2} \Vert w_{j} \Vert_2^2 + \lambda \Vert w_{j} \Vert_1 \\
+& \text{subject to}
+&& w_{j} \geq 0 \\
+&&& diag(W)= 0
+\end{aligned}
+$$
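The `-iters` option of `train_slim` refers to coordinate descent, so a minimal sketch of that idea may help. For a fixed column $$w_j$$, minimizing the objective over a single entry $$w_{kj}$$ ($$k \neq j$$) has a closed form: $$w_{kj} \leftarrow \max\left(0,\; R_k^\top \big(r_j - \sum_{l \neq k} R_l w_{lj}\big) - \lambda\right) / \left(\Vert R_k \Vert_2^2 + \beta\right)$$. Here $$R_k$$ is the $$k$$-th column of $$R$$. The Java code below cycles over these updates on a small dense matrix. It only illustrates the objective under these assumptions. It is not Hivemall's `SlimUDTF` implementation, which works on sparse data restricted to the top-$$k$$ neighbors. `SlimColumnSketch` and `trainColumn` are hypothetical names.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch, not Hivemall's SlimUDTF: cyclic coordinate descent for one SLIM column,
 * minimizing 1/2 ||r_j - R w_j||^2 + beta/2 ||w_j||^2 + lambda ||w_j||_1
 * subject to w_j >= 0 and w_jj = 0, on a small dense user-item matrix R[user][item].
 */
final class SlimColumnSketch {

    static double[] trainColumn(double[][] r, int j, double beta, double lambda, int iters) {
        final int numUsers = r.length;
        final int numItems = r[0].length;
        final double[] w = new double[numItems];        // w_j starts at 0
        final double[] residual = new double[numUsers]; // r_j - R w_j, which equals r_j at the start
        final double[] sqNorm = new double[numItems];   // ||R_k||_2^2 for each item column k
        for (int u = 0; u < numUsers; u++) {
            residual[u] = r[u][j];
            for (int k = 0; k < numItems; k++) {
                sqNorm[k] += r[u][k] * r[u][k];
            }
        }
        for (int it = 0; it < iters; it++) {
            for (int k = 0; k < numItems; k++) {
                final double denom = sqNorm[k] + beta;
                if (k == j || denom == 0.d) {
                    continue; // diag(W) = 0; an all-zero column with beta = 0 keeps its weight at 0
                }
                final double old = w[k];
                // rho = R_k^T (r_j - sum_{l != k} R_l w_l): add column k's own contribution back
                double rho = 0.d;
                for (int u = 0; u < numUsers; u++) {
                    rho += r[u][k] * (residual[u] + r[u][k] * old);
                }
                // soft-threshold by lambda, clip at 0 (w_j >= 0), then scale by the L2 term
                final double updated = Math.max(0.d, rho - lambda) / denom;
                if (updated != old) {
                    for (int u = 0; u < numUsers; u++) {
                        residual[u] -= r[u][k] * (updated - old);
                    }
                    w[k] = updated;
                }
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // item 1 has exactly the same ratings as item 0; item 2 is unrelated
        double[][] r = {{1, 1, 0}, {1, 1, 0}, {0, 0, 1}};
        System.out.println(Arrays.toString(trainColumn(r, 0, 1.0, 0.5, 10)));
    }
}
```

With `beta = 1` and `lambda = 0.5`, item 1 gets $$(2 - 0.5)/(2 + 1) = 0.5$$ and item 2 stays at 0, so `main` prints `[0.0, 0.5, 0.0]`. With `beta = lambda = 0`, the weight of item 1 would be exactly 1.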
+
+# Data preparation
+
+## Rating binarization
+
+In this article, the user-movie matrix is binarized to reduce the number of training samples and to consider only highly rated movies: each rating of 4 or 5 becomes 1.0, and elements with a rating lower than 4 are not used for training.
+
+```sql
+SET hivevar:seed=31;
+
+DROP TABLE ratings2;
+CREATE TABLE ratings2 as
+select
+ rand(${seed}) as rnd,
+ userid,
+ movieid as itemid,
+ cast(1.0 as float) as rating -- double is also accepted
+from
+ ratings
+where rating >= 4.
+;
+```
+
+A `rnd` column is appended to each record so that `ratings2` can be split into training and testing data later.
+
+Binarization is optional; you can also train a SLIM model on raw rating values.
+
+## Splitting dataset
+
+To evaluate a recommendation model, this tutorial uses two types of cross validation:
+
+- Leave-one-out cross validation
+- $$K$$-fold cross validation
+
+The former is used in the [SLIM paper](http://glaros.dtc.umn.edu/gkhome/fetch/papers/SLIM2011icdm.pdf) and the latter in [Mendeley's slides](http://slideshare.net/MarkLevy/efficient-slides/).
+
+### Leave-one-out cross validation
+
+For leave-one-out cross validation, the dataset is split into a training set and a testing set by randomly selecting one of the non-zero entries of each user and placing it into the testing set.
+In the following query, for each user, the movie with the largest `rnd` value is used as test data (`testing` table),
+and the other movies are used as training data (`training` table).
+
+When selecting the best SLIM hyperparameters, repeat the [evaluation](#evaluation) several times with different test data (e.g., by regenerating `ratings2` with a different `seed`).
+
+``` sql
+DROP TABLE testing;
+CREATE TABLE testing
+as
+WITH top_k as (
+ select
+ each_top_k(1, userid, rnd, userid, itemid, rating)
+ as (rank, rnd, userid, itemid, rating)
+ from (
+ select * from ratings2
+ CLUSTER BY userid
+ ) t
+)
+select
+ userid, itemid, rating
+from
+ top_k
+;
+
+DROP TABLE training;
+CREATE TABLE training as
+select
+ l.*
+from
+ ratings2 l
+ LEFT OUTER JOIN testing r ON (l.userid=r.userid and l.itemid=r.itemid)
+where
+ r.itemid IS NULL -- anti join
+;
+```
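The query relies on `each_top_k` to keep exactly one row per user. The following Java sketch mirrors that logic in memory. For each user, it holds out the row with the largest `rnd`, which is the row a positive `k` in `each_top_k` selects. It keeps every other row for training, like the anti join. `LeaveOneOutSketch`, `Row`, and `split` are hypothetical names used only for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class LeaveOneOutSketch {

    /** Hypothetical in-memory counterpart of one ratings2 row (the rating is omitted; it is always 1.0). */
    static final class Row {
        final int userId;
        final int itemId;
        final double rnd;

        Row(int userId, int itemId, double rnd) {
            this.userId = userId;
            this.itemId = itemId;
            this.rnd = rnd;
        }
    }

    /** Moves, per user, the row with the largest rnd to test; all other rows go to train. */
    static void split(List<Row> rows, List<Row> train, List<Row> test) {
        final Map<Integer, Row> heldOut = new HashMap<>();
        for (Row row : rows) {
            Row best = heldOut.get(row.userId);
            if (best == null || row.rnd > best.rnd) {
                heldOut.put(row.userId, row);
            }
        }
        for (Row row : rows) {
            if (heldOut.get(row.userId) == row) { // identity check: exactly the held-out row
                test.add(row);
            } else {
                train.add(row);
            }
        }
    }

    public static void main(String[] args) {
        List<Row> rows = Arrays.asList(new Row(1, 10, 0.2), new Row(1, 11, 0.9), new Row(2, 20, 0.5));
        List<Row> train = new ArrayList<>();
        List<Row> test = new ArrayList<>();
        split(rows, train, test);
        System.out.println("train=" + train.size() + ", test=" + test.size());
    }
}
```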
+
+### $$K$$-fold cross validation
+
+With $$K=2$$, the dataset is divided into a training set and a testing set of roughly equal size.
+
+To select the best SLIM hyperparameters, first train a SLIM model on the training data and then evaluate it on the testing data.
+
+Optionally, swap the training and testing data and evaluate again.
+
+```sql
+DROP TABLE testing;
+CREATE TABLE testing
+as
+select * from ratings2
+where rnd >= 0.5
+;
+
+DROP TABLE training;
+CREATE TABLE training
+as
+select * from ratings2
+where rnd < 0.5
+;
+```
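The `rnd < 0.5` split generalizes to any $$K$$. Each row is assigned to fold $$\lfloor K \cdot rnd \rfloor$$. In each round, one fold is used for testing and the rest for training. Below is a minimal Java sketch of that fold assignment; `KFoldSketch` and `foldOf` are hypothetical names.

```java
final class KFoldSketch {

    /** Maps a uniform random value rnd in [0, 1), such as Hive's rand(), to a fold index in [0, k). */
    static int foldOf(double rnd, int k) {
        if (k < 2) {
            throw new IllegalArgumentException("k must be at least 2: " + k);
        }
        if (rnd < 0.d || rnd >= 1.d) {
            throw new IllegalArgumentException("rnd must be in [0, 1): " + rnd);
        }
        // e.g. k = 2: rnd < 0.5 -> fold 0, rnd >= 0.5 -> fold 1; min() guards against rounding up to k
        return Math.min(k - 1, (int) (rnd * k));
    }

    public static void main(String[] args) {
        for (double rnd : new double[] {0.1, 0.49, 0.5, 0.99}) {
            // with k = 2, fold 0 plays the role of `training` and fold 1 that of `testing`
            System.out.println(rnd + " -> " + (foldOf(rnd, 2) == 0 ? "training" : "testing"));
        }
    }
}
```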
+
+> #### Note
+>
+> In the following sections, except for the evaluation section,
+> the example queries and their results are based on $$K$$-fold cross validation.
+> The same queries also work for leave-one-out cross validation.
+
+## Precompute movie-movie similarity
+
+SLIM needs the top-$$k$$ most similar movies of each movie to approximate the user-item matrix.
+Here, we particularly focus on [DIMSUM](item_based_cf.html#dimsum-approximated-all-pairs-cosine-similarity-computation),
+an efficient and approximated similarity computation scheme.
+
+Since `k=20` here, the output holds the 20 most similar movies for each `itemid`.
+Varying `k` adjusts the trade-off between training and prediction time and the precision of the matrix approximation:
+a larger `k` approximates the raw user-item matrix better, but training time and memory usage tend to increase.
+
+[As explained in the general introduction to item-based CF](item_based_cf.html#dimsum-approximated-all-pairs-cosine-similarity-computation),
+the following query finds the top-$$k$$ nearest-neighbor movies of each movie:
+
+```sql
+set hivevar:k=20;
+
+DROP TABLE knn_train;
+CREATE TABLE knn_train
+as
+with item_magnitude as (
+ select
+ to_map(j, mag) as mags
+ from (
+ select
+ itemid as j,
+ l2_norm(rating) as mag
+ from
+ training
+ group by
+ itemid
+ ) t0
+),
+item_features as (
+ select
+ userid as i,
+ collect_list(
+ feature(itemid, rating)
+ ) as feature_vector
+ from
+ training
+ group by
+ userid
+),
+partial_result as (
+ select
+ dimsum_mapper(f.feature_vector, m.mags, '-threshold 0.1 -int_feature')
+ as (itemid, other, s)
+ from
+ item_features f
+ CROSS JOIN item_magnitude m
+),
+similarity as (
+ select
+ itemid,
+ other,
+ sum(s) as similarity
+ from
+ partial_result
+ group by
+ itemid, other
+),
+topk as (
+ select
+ each_top_k(
+ ${k}, itemid, similarity, -- use top k items
+ itemid, other
+ ) as (rank, similarity, itemid, other)
+ from (
+ select * from similarity
+ CLUSTER BY itemid
+ ) t
+)
+select
+ itemid, other, similarity
+from
+ topk
+;
+```
+
+| itemid | other | similarity |
+|:---:|:---:|:---|
+| 1 | 3114 | 0.28432244 |
+| 1 | 1265 | 0.25180137 |
+| 1 | 2355 | 0.24781825 |
+| 1 | 2396 | 0.24435896 |
+| 1 | 588 | 0.24359442 |
+|...|...|...|
+
+
+> #### Caution
+> To run the query above, you may need to run the following statements:
+>
+> ```sql
+> set hive.strict.checks.cartesian.product=false;
+> set hive.mapred.mode=nonstrict;
+> ```
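To show what the similarity query computes, the Java sketch below finds the top-$$k$$ neighbors of each item using exact, brute-force cosine similarity over item rating vectors. It is a simplified stand-in, not DIMSUM. DIMSUM approximates the same cosine similarities by sampling so that it scales to large matrices; this sketch instead compares every item pair. `CosineTopKSketch` and `topK` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class CosineTopKSketch {

    /**
     * Exact (brute-force) cosine top-k, not DIMSUM.
     * itemRatings: item id -> (user id -> rating). Returns item id -> up to k most similar item ids.
     */
    static Map<Integer, List<Integer>> topK(Map<Integer, Map<Integer, Double>> itemRatings, int k) {
        final Map<Integer, Double> norms = new HashMap<>(); // L2 norm of each item vector (cf. l2_norm)
        for (Map.Entry<Integer, Map<Integer, Double>> e : itemRatings.entrySet()) {
            double sq = 0.d;
            for (double v : e.getValue().values()) {
                sq += v * v;
            }
            norms.put(e.getKey(), Math.sqrt(sq));
        }
        final Map<Integer, List<Integer>> result = new HashMap<>();
        for (Map.Entry<Integer, Map<Integer, Double>> a : itemRatings.entrySet()) {
            final int i = a.getKey();
            final Map<Integer, Double> sims = new HashMap<>();
            for (Map.Entry<Integer, Map<Integer, Double>> b : itemRatings.entrySet()) {
                final int other = b.getKey();
                if (other == i) {
                    continue;
                }
                double dot = 0.d; // sum over users who rated both items
                for (Map.Entry<Integer, Double> ur : a.getValue().entrySet()) {
                    Double r = b.getValue().get(ur.getKey());
                    if (r != null) {
                        dot += ur.getValue() * r;
                    }
                }
                final double denom = norms.get(i) * norms.get(other);
                if (dot > 0.d && denom > 0.d) {
                    sims.put(other, dot / denom);
                }
            }
            final List<Integer> ranked = new ArrayList<>(sims.keySet());
            // highest similarity first; ties broken by smaller item id for determinism
            ranked.sort(Comparator.comparingDouble((Integer o) -> sims.get(o)).reversed()
                .thenComparing(Comparator.naturalOrder()));
            result.put(i, new ArrayList<>(ranked.subList(0, Math.min(k, ranked.size()))));
        }
        return result;
    }

    public static void main(String[] args) {
        Map<Integer, Map<Integer, Double>> items = new HashMap<>();
        items.put(1, Map.of(100, 1.0, 200, 1.0));
        items.put(2, Map.of(100, 1.0, 200, 1.0));
        items.put(3, Map.of(100, 1.0));
        // item 2 (identical ratings) ranks before item 3 as a neighbor of item 1
        System.out.println(topK(items, 2).get(1));
    }
}
```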
+
+## Create training input tables
+
+Here, we prepare input tables for SLIM training.
+
+The SLIM input table `slim_training_item` consists of the following columns:
+
+- `i`: the axis item id
+- `R_i`: the user-rating vector of the axis item $$i$$, expressed as `map<userid, rating>`
+- `knn_i`: the user-item rating matrix of the top-$$k$$ items most similar to item $$i$$, expressed as `map<userid, map<itemid, rating>>`
+- `j`: an item id in `knn_i`
+- `R_j`: the user-rating vector of item $$j$$, expressed as `map<userid, rating>`
+
+```sql
+DROP TABLE item_matrix;
+CREATE table item_matrix as
+select
+ itemid as i,
+ to_map(userid, rating) as R_i
+from
+ training
+group by
+ itemid;
+
+-- Temporarily disable map join because the following query does not work well with it
+set hive.auto.convert.join=false;
+-- set mapred.reduce.tasks=64;
+
+-- Create SLIM input features
+DROP TABLE slim_training_item;
+CREATE TABLE slim_training_item as
+WITH knn_item_user_matrix as (
+ select
+ l.itemid,
+ r.userid,
+ to_map(l.other, r.rating) ratings
+ from
+ knn_train l
+ JOIN training r ON (l.other = r.itemid)
+ group by
+ l.itemid, r.userid
+),
+knn_item_matrix as (
+ select
+ itemid as i,
+ to_map(userid, ratings) as KNN_i -- map<userid, map<itemid, rating>>
+ from
+ knn_item_user_matrix
+ group by
+ itemid
+)
+select
+ l.itemid as i,
+ r1.R_i,
+ r2.knn_i,
+ l.other as j,
+ r3.R_i as R_j
+from
+ knn_train l
+ JOIN item_matrix r1 ON (l.itemid = r1.i)
+ JOIN knn_item_matrix r2 ON (l.itemid = r2.i)
+ JOIN item_matrix r3 ON (l.other = r3.i)
+;
+
+-- set to the default value
+set hive.auto.convert.join=true;
+```
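The nested `knn_i` value is the least obvious input. The Java snippet below gives a rough in-memory picture of it; it is a sketch, not how Hivemall materializes the value. It builds `knn_i` for one axis item from per-item rating vectors and that item's neighbor list, mirroring the `knn_item_user_matrix` and `knn_item_matrix` CTEs. `KnnInputSketch` and `knnOf` are hypothetical names.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class KnnInputSketch {

    /**
     * Builds knn_i for one axis item i as user id -> (neighbor item id -> rating),
     * restricted to the neighbors of i (its rows in knn_train).
     * itemRatings: item id -> (user id -> rating), i.e. the R_i column of item_matrix for every item.
     */
    static Map<Integer, Map<Integer, Float>> knnOf(Map<Integer, Map<Integer, Float>> itemRatings,
            List<Integer> neighbors) {
        final Map<Integer, Map<Integer, Float>> knn = new HashMap<>();
        for (int other : neighbors) {
            for (Map.Entry<Integer, Float> e : itemRatings.getOrDefault(other, Map.of()).entrySet()) {
                // one inner map per user, keyed by the neighbor items that user rated
                knn.computeIfAbsent(e.getKey(), u -> new HashMap<>()).put(other, e.getValue());
            }
        }
        return knn;
    }

    public static void main(String[] args) {
        Map<Integer, Map<Integer, Float>> itemRatings = new HashMap<>();
        itemRatings.put(10, Map.of(1, 1f, 2, 1f));
        itemRatings.put(11, Map.of(2, 1f));
        System.out.println(knnOf(itemRatings, Arrays.asList(10, 11)));
    }
}
```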
+
+# Training
+
+## Build a prediction model by SLIM
+
+The `train_slim` function outputs the non-zero elements of the item-item weight matrix $$W$$.
+The matrix is stored in the `slim_model` table for item recommendation and prediction.
+
+```sql
+DROP TABLE slim_model;
+CREATE TABLE slim_model as
+select
+ i, nn, avg(w) as w
+from (
+ select
+ train_slim(i, r_i, knn_i, j, r_j) as (i, nn, w)
+ from (
+ select * from slim_training_item
+ CLUSTER BY i
+ ) t1
+) t2
+group by i, nn
+;
+```
+
+## Usage of `train_slim`
+
+You can get information about the `train_slim` function and its arguments with the `-help` option:
+
+``` sql
+select train_slim("-help");
+```
+
+```
+usage: train_slim( int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI,
+ int j, map<int, double> r_j [, constant string options])
+ - Returns row index, column index and non-zero weight value of prediction model
+ [-cv_rate <arg>] [-disable_cv] [-help] [-iters <arg>] [-l1 <arg>] [-l2 <arg>]
+ -cv_rate,--convergence_rate <arg> Threshold to determine convergence
+ [default: 0.005]
+ -disable_cv,--disable_cvtest Whether to disable convergence check
+ [default: enabled]
+ -help Show function help
+ -iters,--iterations <arg> The number of iterations for
+ coordinate descent [default: 30]
+ -l1,--l1coefficient <arg> Coefficient for l1 regularizer
+ [default: 0.001]
+ -l2,--l2coefficient <arg> Coefficient for l2 regularizer
+ [default: 0.0005]
+```
+
+# Prediction and recommendation
+
+Here, we predict rating values of the binarized user-item matrix for the testing dataset, based on the ratings in the training dataset.
+
+Based on the predicted scores, we can recommend to each user the top-$$k$$ items that he or she is likely to rate highly.
+
+## Predict unknown value of user-item matrix
+
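+Using the known ratings and the SLIM weight matrix, we predict unknown values of the user-item matrix and store them in the `predicted` table.
+SLIM predicts the rating of a user-item pair from the user's ratings on the top-$$k$$ items most similar to the item, i.e., $$\hat{r}_{ui} = \sum_{l \in \mathrm{kNN}(i)} r_{ul} \, w_{il}$$, where $$w_{il}$$ is the weight `w` stored in `slim_model` for `i` and `nn` $$= l$$.
+
+The `predict_pair` view lists the candidate user-movie pairs for recommendation, excluding pairs already rated in the training dataset.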
+
+```sql
+CREATE OR REPLACE VIEW predict_pair
+as
+WITH testing_users as (
+ select DISTINCT(userid) as userid from testing
+),
+training_items as (
+ select DISTINCT(itemid) as itemid from training
+),
+user_items as (
+ select
+ l.userid,
+ r.itemid
+ from
+ testing_users l
+ CROSS JOIN training_items r
+)
+select
+ l.userid,
+ l.itemid
+from
+ user_items l
+ LEFT OUTER JOIN training r ON (l.userid=r.userid and l.itemid=r.itemid)
+where
+ r.itemid IS NULL -- anti join
+;
+```
+
+```sql
+-- optionally set the mean/default value of prediction
+set hivevar:mu=0.0;
+
+DROP TABLE predicted;
+CREATE TABLE predicted
+as
+WITH knn_exploded as (
+ select
+ l.userid as u,
+ l.itemid as i, -- axis
+ r1.other as k, -- other
+ r2.rating as r_uk
+ from
+ predict_pair l
+ LEFT OUTER JOIN knn_train r1
+ ON (r1.itemid = l.itemid)
+ JOIN training r2
+ ON (r2.userid = l.userid and r2.itemid = r1.other)
+)
+select
+ l.u as userid,
+ l.i as itemid,
+ coalesce(sum(l.r_uk * r.w), ${mu}) as predicted
+ -- coalesce(sum(l.r_uk * r.w)) as predicted
+from
+ knn_exploded l
+ LEFT OUTER JOIN slim_model r ON (l.i = r.i and l.k = r.nn)
+group by
+ l.u, l.i
+;
+```
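The aggregation in the query can be sketched for a single user-item pair as follows. This is a hypothetical Java illustration, not part of Hivemall. It sums $$r_{ul} \, w_{il}$$ over neighbors that the user rated and that have a weight, and falls back to `mu` like `coalesce(sum(...), ${mu})`. There is one simplification: because the query inner-joins `training`, it emits no row at all for a pair whose user rated none of the neighbors, whereas this sketch returns `mu` in that case too.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

final class SlimPredictSketch {

    /**
     * Predicted score of user u for axis item i: the sum over neighbors l of i of r_ul * w_il.
     * userRatings: item id -> rating of user u in the training data
     * neighbors:   ids of the top-k items similar to i (its rows in knn_train)
     * weights:     neighbor id l -> w_il, i.e. the slim_model rows for axis item i
     * mu:          fallback when no neighbor contributes, like coalesce(sum(...), ${mu})
     */
    static double predict(Map<Integer, Double> userRatings, List<Integer> neighbors,
            Map<Integer, Double> weights, double mu) {
        double sum = 0.d;
        boolean contributed = false;
        for (int l : neighbors) {
            Double r = userRatings.get(l);
            Double w = weights.get(l);
            if (r != null && w != null) {
                sum += r * w;
                contributed = true;
            }
        }
        return contributed ? sum : mu;
    }

    public static void main(String[] args) {
        Map<Integer, Double> ratings = Map.of(1, 1.0, 2, 1.0);
        Map<Integer, Double> weights = Map.of(1, 0.5, 3, 0.7);
        // only neighbor 1 is both rated by the user and weighted, so the score is 1.0 * 0.5
        System.out.println(predict(ratings, Arrays.asList(1, 2, 3), weights, 0.0));
    }
}
```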
+
+> #### Caution
+> When $$k$$ is small, the SLIM prediction may be `null`; in that case, `${mu}` replaces the `null` value.
+> The mean of item ratings is a good choice for `${mu}`.
+
+## Top-$$K$$ item recommendation for each user
+
+Here, we recommend top-3 items for each user based on predicted values.
+
+```sql
+SET hivevar:k=3;
+
+DROP TABLE IF EXISTS recommend;
+CREATE TABLE recommend
+as
+WITH top_n as (
+ select
+ each_top_k(${k}, userid, predicted, userid, itemid)
+ as (rank, predicted, userid, itemid)
+ from (
+ select * from predicted
+ CLUSTER BY userid
+ ) t
+)
+select
+ userid,
+ collect_list(itemid) as items
+from
+ top_n
+group by
+ userid
+;
+
+select * from recommend limit 5;
+```
+
+| userid | items |
+|:---:|:---:|
+| 1 | [364,594,2081] |
+| 2 | [2028,3256,589] |
+| 3 | [260,1291,2791] |
+| 4 | [1196,1200,1210] |
+| 5 | [3813,1366,89] |
+|...|...|
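For each user, `each_top_k(${k}, userid, predicted, ...)` keeps the $$k$$ items with the highest `predicted` score. Below is a minimal Java equivalent for a single user's predictions. The names are hypothetical, and breaking ties by item id is a choice made in this sketch; the query does not guarantee it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final class TopKRecommendSketch {

    /** Ids of the k items with the highest predicted score for one user; ties go to the smaller id. */
    static List<Integer> recommend(Map<Integer, Double> predictedByItem, int k) {
        final List<Map.Entry<Integer, Double>> entries = new ArrayList<>(predictedByItem.entrySet());
        entries.sort(Map.Entry.<Integer, Double>comparingByValue().reversed()
            .thenComparing(Map.Entry.<Integer, Double>comparingByKey()));
        final List<Integer> top = new ArrayList<>();
        for (int n = 0; n < Math.min(k, entries.size()); n++) {
            top.add(entries.get(n).getKey());
        }
        return top;
    }

    public static void main(String[] args) {
        System.out.println(recommend(Map.of(10, 0.9, 20, 0.8, 30, 0.7, 40, 0.1), 3));
    }
}
```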
+
+# Evaluation
+
+## Top-$$K$$ ranking measures: Hit-Rate@K, MRR@K, and Precision@K
+
+In this section, `Hit-Rate@k`, `MRR@k`, and `Precision@k` are computed from the top-$$n$$ recommended items of each user (`n=10` in the following query).
+
+[`Precision@K`](../eval/rank.html#precision-at-k) is a good evaluation measure for $$K$$-fold cross validation.
+
+On the other hand, `Hit-Rate` and [`Mean Reciprocal Rank`](https://en.wikipedia.org/wiki/Mean_reciprocal_rank) (i.e., Average Reciprocal Hit-Rate) are good evaluation measures for leave-one-out cross validation.
+
+```sql
+SET hivevar:n=10;
+
+WITH top_k as (
+ select
+ each_top_k(${n}, userid, predicted, userid, itemid)
+ as (rank, predicted, userid, itemid)
+ from (
+ select * from predicted
+ CLUSTER BY userid
+ ) t
+),
+rec_items as (
+ select
+ userid,
+ collect_list(itemid) as items
+ from
+ top_k
+ group by
+ userid
+),
+ground_truth as (
+ select
+ userid,
+ collect_list(itemid) as truth
+ from
+ testing
+ group by
+ userid
+)
+select
+ hitrate(l.items, r.truth) as hitrate,
+ mrr(l.items, r.truth) as mrr,
+ precision_at(l.items, r.truth) as prec
+from
+ rec_items l
+ join ground_truth r on (l.userid=r.userid)
+;
+```
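The per-user quantities behind these measures take only a few lines of code. The Java sketch below uses the common textbook definitions, which are an assumption here. Hivemall's `hitrate`, `mrr`, and `precision_at` UDAFs may differ in details, for example in how lists shorter than $$n$$ are handled. Under these definitions, with a single held-out item per user, precision@$$n$$ is the hit rate divided by $$n$$. That is consistent with the leave-one-out result below. `RankingMetricsSketch` and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;

final class RankingMetricsSketch {

    /** 1 if any ground-truth item appears in the top-n of the ranked list, else 0. */
    static double hitRate(List<Integer> ranked, List<Integer> truth, int n) {
        return reciprocalRank(ranked, truth, n) > 0.d ? 1.d : 0.d;
    }

    /** 1 / (1-based position of the first ground-truth item within the top-n), or 0 if none. */
    static double reciprocalRank(List<Integer> ranked, List<Integer> truth, int n) {
        for (int pos = 0; pos < Math.min(n, ranked.size()); pos++) {
            if (truth.contains(ranked.get(pos))) {
                return 1.d / (pos + 1);
            }
        }
        return 0.d;
    }

    /** Number of ground-truth items in the top-n, divided by n. */
    static double precisionAt(List<Integer> ranked, List<Integer> truth, int n) {
        if (n <= 0) {
            throw new IllegalArgumentException("n must be positive: " + n);
        }
        int hits = 0;
        for (int pos = 0; pos < Math.min(n, ranked.size()); pos++) {
            if (truth.contains(ranked.get(pos))) {
                hits++;
            }
        }
        return (double) hits / n;
    }

    public static void main(String[] args) {
        List<Integer> ranked = Arrays.asList(5, 3, 9);
        List<Integer> truth = Arrays.asList(9);
        System.out.println(hitRate(ranked, truth, 3) + " " + reciprocalRank(ranked, truth, 3) + " "
            + precisionAt(ranked, truth, 3));
    }
}
```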
+
+### Leave-one-out result
+
+| hitrate | mrr | prec |
+|:-------:|:---:|:----:|
+| 0.21517309922146763 | 0.09377752536606271 | 0.021517309922146725 |
+
+The hit rate and MRR are similar to those reported in [Table II of the SLIM paper](http://glaros.dtc.umn.edu/gkhome/fetch/papers/SLIM2011icdm.pdf).
+
+### $$K$$-fold result
+
+| hitrate | mrr | prec |
+|:-------:|:---:|:----:|
+| 0.8952775476387739 | 1.1751514972186057 | 0.3564871582435789 |
+
+The precision is similar to [the result in Mendeley's slides](https://www.slideshare.net/MarkLevy/efficient-slides/13).
+
+## Ranking measures: MRR
+
+In this example, MRR is computed over the whole ranked list of recommended items instead of only the top-$$n$$ items.
+
+``` sql
+WITH rec_items as (
+ select
+ userid,
+ to_ordered_list(itemid, predicted, '-reverse') as items
+ from
+ predicted
+ group by
+ userid
+),
+ground_truth as (
+ select
+ userid,
+ collect_list(itemid) as truth
+ from
+ testing
+ group by
+ userid
+)
+select
+ mrr(l.items, r.truth) as mrr
+from
+ rec_items l
+ join ground_truth r on (l.userid=r.userid)
+;
+```
+
+### Leave-one-out result
+
+| mrr |
+|:---:|
+| 0.10782647321821472 |
+
+### $$K$$-fold result
+
+| mrr |
+|:---:|
+| 0.6179983058881773 |
+
+This MRR value is similar to the one in [Mendeley's slides](https://www.slideshare.net/MarkLevy/efficient-slides/13).
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index 8cdd371..d2f0b9f 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -591,6 +591,9 @@ CREATE FUNCTION precision_at as 'hivemall.evaluation.PrecisionUDAF' USING JAR '$
DROP FUNCTION IF EXISTS recall_at;
CREATE FUNCTION recall_at as 'hivemall.evaluation.RecallUDAF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS hitrate;
+CREATE FUNCTION hitrate as 'hivemall.evaluation.HitRateUDAF' USING JAR '${hivemall_jar}';
+
DROP FUNCTION IF EXISTS mrr;
CREATE FUNCTION mrr as 'hivemall.evaluation.MRRUDAF' USING JAR '${hivemall_jar}';
@@ -718,6 +721,13 @@ CREATE FUNCTION rf_ensemble as 'hivemall.smile.tools.RandomForestEnsembleUDAF' U
DROP FUNCTION IF EXISTS guess_attribute_types;
CREATE FUNCTION guess_attribute_types as 'hivemall.smile.tools.GuessAttributesUDF' USING JAR '${hivemall_jar}';
+--------------------
+-- Recommendation --
+--------------------
+
+DROP FUNCTION IF EXISTS train_slim;
+CREATE FUNCTION train_slim as 'hivemall.recommend.SlimUDTF' USING JAR '${hivemall_jar}';
+
------------------------------
-- XGBoost related features --
------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 756c57a..0ef36c3 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -583,6 +583,9 @@ create temporary function precision_at as 'hivemall.evaluation.PrecisionUDAF';
drop temporary function if exists recall_at;
create temporary function recall_at as 'hivemall.evaluation.RecallUDAF';
+drop temporary function if exists hitrate;
+create temporary function hitrate as 'hivemall.evaluation.HitRateUDAF';
+
drop temporary function if exists mrr;
create temporary function mrr as 'hivemall.evaluation.MRRUDAF';
@@ -710,6 +713,13 @@ create temporary function rf_ensemble as 'hivemall.smile.tools.RandomForestEnsem
drop temporary function if exists guess_attribute_types;
create temporary function guess_attribute_types as 'hivemall.smile.tools.GuessAttributesUDF';
+--------------------
+-- Recommendation --
+--------------------
+
+drop temporary function if exists train_slim;
+create temporary function train_slim as 'hivemall.recommend.SlimUDTF';
+
--------------------------------------------------------------------------------------------------
-- macros available from hive 0.12.0
-- see https://issues.apache.org/jira/browse/HIVE-2655
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index bddbc85..97307c2 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -567,6 +567,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION precision_at AS 'hivemall.evaluation.P
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS recall_at")
sqlContext.sql("CREATE TEMPORARY FUNCTION recall_at AS 'hivemall.evaluation.RecallUDAF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS hitrate")
+sqlContext.sql("CREATE TEMPORARY FUNCTION hitrate AS 'hivemall.evaluation.HitRateUDAF'")
+
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS mrr")
sqlContext.sql("CREATE TEMPORARY FUNCTION mrr AS 'hivemall.evaluation.MRRUDAF'")
@@ -695,3 +698,10 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION guess_attribute_types AS 'hivemall.smi
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_gradient_tree_boosting_classifier")
sqlContext.sql("CREATE TEMPORARY FUNCTION train_gradient_tree_boosting_classifier AS 'hivemall.smile.classification.GradientTreeBoostingClassifierUDTF'")
+
+/**
+ * Recommendation
+ */
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_slim")
+sqlContext.sql("CREATE TEMPORARY FUNCTION train_slim AS 'hivemall.recommend.SlimUDTF'")
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index c59b120..a281b72 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -180,6 +180,8 @@ create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF';
create temporary function add_field_indicies as 'hivemall.ftvec.trans.AddFieldIndicesUDF';
create temporary function to_ordered_list as 'hivemall.tools.list.UDAFToOrderedList';
create temporary function singularize as 'hivemall.tools.text.SingularizeUDF';
+create temporary function train_slim as 'hivemall.recommend.SlimUDTF';
+create temporary function hitrate as 'hivemall.evaluation.HitRateUDAF';
-- NLP features
create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';
[2/3] incubator-hivemall git commit: Close #117,
Close #111: [HIVEMALL-17] Support SLIM neighborhood-learning
recommendation algorithm
Posted by my...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/recommend/SlimUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/recommend/SlimUDTF.java b/core/src/main/java/hivemall/recommend/SlimUDTF.java
new file mode 100644
index 0000000..e205c18
--- /dev/null
+++ b/core/src/main/java/hivemall/recommend/SlimUDTF.java
@@ -0,0 +1,759 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.recommend;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.common.ConversionState;
+import hivemall.math.matrix.sparse.DoKFloatMatrix;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+import hivemall.utils.lang.mutable.MutableDouble;
+import hivemall.utils.lang.mutable.MutableInt;
+import hivemall.utils.lang.mutable.MutableObject;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+/**
+ * Sparse Linear Methods (SLIM) for Top-N Recommender Systems.
+ *
+ * <pre>
+ * Xia Ning and George Karypis, SLIM: Sparse Linear Methods for Top-N Recommender Systems, Proc. ICDM, 2011.
+ * </pre>
+ */
+@Description(
+ name = "train_slim",
+ value = "_FUNC_( int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]) "
+ + "- Returns row index, column index and non-zero weight value of prediction model")
+public class SlimUDTF extends UDTFWithOptions {
+ private static final Log logger = LogFactory.getLog(SlimUDTF.class);
+
+ //--------------------------------------------
+ // input OIs
+
+ private PrimitiveObjectInspector itemIOI;
+ private PrimitiveObjectInspector itemJOI;
+ private MapObjectInspector riOI;
+ private MapObjectInspector rjOI;
+
+ private MapObjectInspector knnItemsOI;
+ private PrimitiveObjectInspector knnItemsKeyOI;
+ private MapObjectInspector knnItemsValueOI;
+ private PrimitiveObjectInspector knnItemsValueKeyOI;
+ private PrimitiveObjectInspector knnItemsValueValueOI;
+
+ private PrimitiveObjectInspector riKeyOI;
+ private PrimitiveObjectInspector riValueOI;
+
+ private PrimitiveObjectInspector rjKeyOI;
+ private PrimitiveObjectInspector rjValueOI;
+
+ //--------------------------------------------
+ // hyperparameters
+
+ private double l1;
+ private double l2;
+ private int numIterations;
+
+ //--------------------------------------------
+ // model parameters and else
+
+ /** item-item weight matrix */
+ private transient DoKFloatMatrix _weightMatrix;
+
+ //--------------------------------------------
+ // caching for each item i
+
+ private int _previousItemId;
+
+ @Nullable
+ private transient Int2FloatOpenHashTable _ri;
+ @Nullable
+ private transient IntOpenHashTable<Int2FloatOpenHashTable> _kNNi;
+ /** The number of elements in kNNi */
+ @Nullable
+ private transient MutableInt _nnzKNNi;
+
+ //--------------------------------------------
+ // variables for iteration supports
+
+ /** item-user matrix holding the input data */
+ @Nullable
+ private transient DoKFloatMatrix _dataMatrix;
+
+ // used to store KNN data into temporary file for iterative training
+ private transient NioStatefullSegment _fileIO;
+ private transient ByteBuffer _inputBuf;
+
+ private ConversionState _cvState;
+ private long _observedTrainingExamples;
+
+ //--------------------------------------------
+
+ public SlimUDTF() {}
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs == 1 && HiveUtils.isConstString(argOIs[0])) {// for -help option
+ String rawArgs = HiveUtils.getConstString(argOIs[0]);
+ parseOptions(rawArgs);
+ }
+
+ if (numArgs != 5 && numArgs != 6) {
+ throw new UDFArgumentException(
+ "_FUNC_ takes 5 or 6 arguments: int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
+
+ this.riOI = HiveUtils.asMapOI(argOIs[1]);
+ this.riKeyOI = HiveUtils.asIntCompatibleOI((riOI.getMapKeyObjectInspector()));
+ this.riValueOI = HiveUtils.asPrimitiveObjectInspector((riOI.getMapValueObjectInspector()));
+
+ this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
+ this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
+ this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
+ this.knnItemsValueKeyOI = HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector());
+ this.knnItemsValueValueOI = HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector());
+
+ this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);
+
+ this.rjOI = HiveUtils.asMapOI(argOIs[4]);
+ this.rjKeyOI = HiveUtils.asIntCompatibleOI((rjOI.getMapKeyObjectInspector()));
+ this.rjValueOI = HiveUtils.asPrimitiveObjectInspector((rjOI.getMapValueObjectInspector()));
+
+ processOptions(argOIs);
+
+ this._observedTrainingExamples = 0L;
+ this._previousItemId = Integer.MIN_VALUE;
+ this._weightMatrix = null;
+ this._dataMatrix = null;
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("j");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("nn");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("w");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("l1", "l1coefficient", true,
+ "Coefficient for l1 regularizer [default: 0.001]");
+ opts.addOption("l2", "l2coefficient", true,
+ "Coefficient for l2 regularizer [default: 0.0005]");
+ opts.addOption("iters", "iterations", true,
+ "The number of iterations for coordinate descent [default: 30]");
+ opts.addOption("disable_cv", "disable_cvtest", false,
+ "Whether to disable convergence check [default: enabled]");
+ opts.addOption("cv_rate", "convergence_rate", true,
+ "Threshold to determine convergence [default: 0.005]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
+ throws UDFArgumentException {
+ CommandLine cl = null;
+ double l1 = 0.001d;
+ double l2 = 0.0005d;
+ int numIterations = 30;
+ boolean conversionCheck = true;
+ double cv_rate = 0.005d;
+
+ if (argOIs.length >= 6) {
+ String rawArgs = HiveUtils.getConstString(argOIs[5]);
+ cl = parseOptions(rawArgs);
+
+ l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1);
+ if (l1 < 0.d) {
+ throw new UDFArgumentException("Argument `double l1` must be non-negative: " + l1);
+ }
+
+ l2 = Primitives.parseDouble(cl.getOptionValue("l2"), l2);
+ if (l2 < 0.d) {
+ throw new UDFArgumentException("Argument `double l2` must be non-negative: " + l2);
+ }
+
+ numIterations = Primitives.parseInt(cl.getOptionValue("iters"), numIterations);
+ if (numIterations <= 0) {
+ throw new UDFArgumentException("Argument `int iters` must be greater than 0: "
+ + numIterations);
+ }
+
+ conversionCheck = !cl.hasOption("disable_cvtest");
+
+ cv_rate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), cv_rate);
+ if (cv_rate <= 0) {
+ throw new UDFArgumentException(
+ "Argument `double cv_rate` must be greater than 0.0: " + cv_rate);
+ }
+ }
+
+ this.l1 = l1;
+ this.l2 = l2;
+ this.numIterations = numIterations;
+ this._cvState = new ConversionState(conversionCheck, cv_rate);
+
+ return cl;
+ }
+
+ @Override
+ public void process(@Nonnull Object[] args) throws HiveException {
+ if (_weightMatrix == null) {// initialize variables
+ this._weightMatrix = new DoKFloatMatrix();
+ if (numIterations >= 2) {
+ this._dataMatrix = new DoKFloatMatrix();
+ }
+ this._nnzKNNi = new MutableInt();
+ }
+
+ final int itemI = PrimitiveObjectInspectorUtils.getInt(args[0], itemIOI);
+
+ if (itemI != _previousItemId || _ri == null) {
+ // cache Ri and kNNi
+ this._ri = int2floatMap(itemI, riOI.getMap(args[1]), riKeyOI, riValueOI, _dataMatrix,
+ _ri);
+ this._kNNi = kNNentries(args[2], knnItemsOI, knnItemsKeyOI, knnItemsValueOI,
+ knnItemsValueKeyOI, knnItemsValueValueOI, _kNNi, _nnzKNNi);
+
+ final int numKNNItems = _nnzKNNi.getValue();
+ if (numIterations >= 2 && numKNNItems >= 1) {
+ recordTrainingInput(itemI, _kNNi, numKNNItems);
+ }
+ this._previousItemId = itemI;
+ }
+
+ int itemJ = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI);
+ Int2FloatOpenHashTable rj = int2floatMap(itemJ, rjOI.getMap(args[4]), rjKeyOI, rjValueOI,
+ _dataMatrix);
+
+ train(itemI, _ri, _kNNi, itemJ, rj);
+ _observedTrainingExamples++;
+ }
+
+ private void recordTrainingInput(final int itemI,
+ @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int numKNNItems)
+ throws HiveException {
+ ByteBuffer buf = this._inputBuf;
+ NioStatefullSegment dst = this._fileIO;
+
+ if (buf == null) {
+ // invoke only at task node (initialize is also invoked in compilation)
+ final File file;
+ try {
+ file = File.createTempFile("hivemall_slim", ".sgmt"); // to save KNN data
+ file.deleteOnExit();
+ if (!file.canWrite()) {
+ throw new UDFArgumentException("Cannot write a temporary file: "
+ + file.getAbsolutePath());
+ }
+ } catch (IOException ioe) {
+ throw new UDFArgumentException(ioe);
+ }
+
+ this._inputBuf = buf = ByteBuffer.allocateDirect(8 * 1024 * 1024); // 8MB
+ this._fileIO = dst = new NioStatefullSegment(file, false);
+ }
+
+ int recordBytes = SizeOf.INT + SizeOf.INT + SizeOf.INT * 2 * knnItems.size()
+ + (SizeOf.INT + SizeOf.FLOAT) * numKNNItems;
+ int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself
+
+ int remain = buf.remaining();
+ if (remain < requiredBytes) {
+ writeBuffer(buf, dst);
+ }
+
+ buf.putInt(recordBytes);
+ buf.putInt(itemI);
+ buf.putInt(knnItems.size());
+
+ final IMapIterator<Int2FloatOpenHashTable> entries = knnItems.entries();
+ while (entries.next() != -1) {
+ int user = entries.getKey();
+ buf.putInt(user);
+
+ Int2FloatOpenHashTable ru = entries.getValue();
+ buf.putInt(ru.size());
+ final Int2FloatOpenHashTable.IMapIterator itor = ru.entries();
+ while (itor.next() != -1) {
+ buf.putInt(itor.getKey());
+ buf.putFloat(itor.getValue());
+ }
+ }
+ }
+
+ private static void writeBuffer(@Nonnull final ByteBuffer srcBuf,
+ @Nonnull final NioStatefullSegment dst) throws HiveException {
+ srcBuf.flip();
+ try {
+ dst.write(srcBuf);
+ } catch (IOException e) {
+ throw new HiveException("Exception caused while writing a buffer to a file", e);
+ }
+ srcBuf.clear();
+ }
+
+ private void train(final int itemI, @Nonnull final Int2FloatOpenHashTable ri,
+ @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> kNNi, final int itemJ,
+ @Nonnull final Int2FloatOpenHashTable rj) {
+ final DoKFloatMatrix W = _weightMatrix;
+
+ final int N = rj.size();
+ if (N == 0) {
+ return;
+ }
+
+ double gradSum = 0.d;
+ double rateSum = 0.d;
+ double lossSum = 0.d;
+
+ final Int2FloatOpenHashTable.IMapIterator itor = rj.entries();
+ while (itor.next() != -1) {
+ int user = itor.getKey();
+ double ruj = itor.getValue();
+ double rui = ri.get(user, 0.f);
+
+ double eui = rui - predict(user, itemI, kNNi, itemJ, W);
+ gradSum += ruj * eui;
+ rateSum += ruj * ruj;
+ lossSum += eui * eui;
+ }
+
+ gradSum /= N;
+ rateSum /= N;
+
+ double wij = W.get(itemI, itemJ, 0.d);
+ double loss = lossSum / N + 0.5d * l2 * wij * wij + l1 * wij;
+ _cvState.incrLoss(loss);
+
+ W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
+ }
+
+ private void train(final int itemI,
+ @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int itemJ) {
+ final DoKFloatMatrix A = _dataMatrix;
+ final DoKFloatMatrix W = _weightMatrix;
+
+ final int N = A.numColumns(itemJ);
+ if (N == 0) {
+ return;
+ }
+
+ final MutableDouble mutableGradSum = new MutableDouble(0.d);
+ final MutableDouble mutableRateSum = new MutableDouble(0.d);
+ final MutableDouble mutableLossSum = new MutableDouble(0.d);
+
+ A.eachNonZeroInRow(itemJ, new VectorProcedure() {
+ @Override
+ public void apply(int user, double ruj) {
+ double rui = A.get(itemI, user, 0.d);
+ double eui = rui - predict(user, itemI, knnItems, itemJ, W);
+
+ mutableGradSum.addValue(ruj * eui);
+ mutableRateSum.addValue(ruj * ruj);
+ mutableLossSum.addValue(eui * eui);
+ }
+ });
+
+ double gradSum = mutableGradSum.getValue() / N;
+ double rateSum = mutableRateSum.getValue() / N;
+
+ double wij = W.get(itemI, itemJ, 0.d);
+ double loss = mutableLossSum.getValue() / N + 0.5 * l2 * wij * wij + l1 * wij;
+ _cvState.incrLoss(loss);
+
+ W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
+ }
+
+ private static double predict(final int user, final int itemI,
+ @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems,
+ final int excludeIndex, @Nonnull final DoKFloatMatrix weightMatrix) {
+ final Int2FloatOpenHashTable kNNu = knnItems.get(user);
+ if (kNNu == null) {
+ return 0.d;
+ }
+
+ double pred = 0.d;
+ final Int2FloatOpenHashTable.IMapIterator itor = kNNu.entries();
+ while (itor.next() != -1) {
+ final int itemK = itor.getKey();
+ if (itemK == excludeIndex) {
+ continue;
+ }
+ float ruk = itor.getValue();
+ pred += ruk * weightMatrix.get(itemI, itemK, 0.d);
+ }
+ return pred;
+ }
+
+ private static double getUpdateTerm(final double gradSum, final double rateSum,
+ final double l1, final double l2) {
+ double update = 0.d;
+ if (Math.abs(gradSum) > l1) {
+ if (gradSum > 0.d) {
+ update = (gradSum - l1) / (rateSum + l2);
+ } else {
+ update = (gradSum + l1) / (rateSum + l2);
+ }
+ // non-negative constraints
+ if (update < 0.d) {
+ update = 0.d;
+ }
+ }
+ return update;
+ }
+
+ @Override
+ public void close() throws HiveException {
+ finalizeTraining();
+ forwardModel();
+ this._weightMatrix = null;
+ }
+
+ @VisibleForTesting
+ void finalizeTraining() throws HiveException {
+ if (numIterations > 1) {
+ this._ri = null;
+ this._kNNi = null;
+
+ runIterativeTraining();
+
+ this._dataMatrix = null;
+ }
+ }
+
+ private void runIterativeTraining() throws HiveException {
+ final ByteBuffer buf = this._inputBuf;
+ final NioStatefullSegment dst = this._fileIO;
+ assert (buf != null);
+ assert (dst != null);
+
+ final Reporter reporter = getReporter();
+ final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
+ "hivemall.recommend.slim$Counter", "iteration");
+
+ try {
+ if (dst.getPosition() == 0L) {// run iterations w/o temporary file
+ if (buf.position() == 0) {
+ return; // no training example
+ }
+ buf.flip();
+ for (int iter = 2; iter < numIterations; iter++) {
+ _cvState.next();
+ reportProgress(reporter);
+ setCounterValue(iterCounter, iter);
+
+ while (buf.remaining() > 0) {
+ int recordBytes = buf.getInt();
+ assert (recordBytes > 0) : recordBytes;
+ replayTrain(buf);
+ }
+ buf.rewind();
+ if (_cvState.isConverged(_observedTrainingExamples)) {
+ break;
+ }
+ }
+ logger.info("Performed "
+ + _cvState.getCurrentIteration()
+ + " iterations of "
+ + NumberUtils.formatNumber(_observedTrainingExamples)
+ + " training examples in memory (thus "
+ + NumberUtils.formatNumber(_observedTrainingExamples
+ * _cvState.getCurrentIteration()) + " training updates in total) ");
+
+ } else { // read training examples in the temporary file and invoke train for each example
+ // flush the KNNi records still pending in the buffer to the temporary file
+ if (buf.position() > 0) {
+ writeBuffer(buf, dst);
+ }
+
+ try {
+ dst.flush();
+ } catch (IOException e) {
+ throw new HiveException("Failed to flush a file: "
+ + dst.getFile().getAbsolutePath(), e);
+ }
+
+ if (logger.isInfoEnabled()) {
+ File tmpFile = dst.getFile();
+ logger.info("Wrote KNN entries of axis items to a temporary file for iterative training: "
+ + tmpFile.getAbsolutePath()
+ + " ("
+ + FileUtils.prettyFileSize(tmpFile)
+ + ")");
+ }
+
+ // run iterations
+ for (int iter = 2; iter < numIterations; iter++) {
+ _cvState.next();
+ setCounterValue(iterCounter, iter);
+
+ buf.clear();
+ dst.resetPosition();
+ while (true) {
+ reportProgress(reporter);
+ // load KNNi records from the temporary file into the buffer
+ final int bytesRead;
+ try {
+ bytesRead = dst.read(buf);
+ } catch (IOException e) {
+ throw new HiveException("Failed to read a file: "
+ + dst.getFile().getAbsolutePath(), e);
+ }
+ if (bytesRead == 0) { // reached file EOF
+ break;
+ }
+ assert (bytesRead > 0) : bytesRead;
+
+ // reads training examples from a buffer
+ buf.flip();
+ int remain = buf.remaining();
+ if (remain < SizeOf.INT) {
+ throw new HiveException("Illegal file format was detected");
+ }
+ while (remain >= SizeOf.INT) {
+ int pos = buf.position();
+ int recordBytes = buf.getInt();
+ remain -= SizeOf.INT;
+ if (remain < recordBytes) {
+ buf.position(pos);
+ break;
+ }
+
+ replayTrain(buf);
+ remain -= recordBytes;
+ }
+ buf.compact();
+ }
+ if (_cvState.isConverged(_observedTrainingExamples)) {
+ break;
+ }
+ }
+ logger.info("Performed "
+ + _cvState.getCurrentIteration()
+ + " iterations of "
+ + NumberUtils.formatNumber(_observedTrainingExamples)
+ + " training examples in memory and KNNi data on secondary storage (thus "
+ + NumberUtils.formatNumber(_observedTrainingExamples
+ * _cvState.getCurrentIteration()) + " training updates in total) ");
+
+ }
+ } catch (Throwable e) {
+ throw new HiveException("Exception caused in the iterative training", e);
+ } finally {
+ // delete the temporary file and release resources
+ try {
+ dst.close(true);
+ } catch (IOException e) {
+ throw new HiveException("Failed to close a file: "
+ + dst.getFile().getAbsolutePath(), e);
+ }
+ this._inputBuf = null;
+ this._fileIO = null;
+ }
+ }
+
+ private void replayTrain(@Nonnull final ByteBuffer buf) {
+ final int itemI = buf.getInt();
+ final int knnSize = buf.getInt();
+
+ final IntOpenHashTable<Int2FloatOpenHashTable> knnItems = new IntOpenHashTable<>(1024);
+ final Set<Integer> pairItems = new HashSet<>();
+ for (int i = 0; i < knnSize; i++) {
+ int user = buf.getInt();
+ int ruSize = buf.getInt();
+ Int2FloatOpenHashTable ru = new Int2FloatOpenHashTable(ruSize);
+ ru.defaultReturnValue(0.f);
+
+ for (int j = 0; j < ruSize; j++) {
+ int itemK = buf.getInt();
+ pairItems.add(itemK);
+ float ruk = buf.getFloat();
+ ru.put(itemK, ruk);
+ }
+ knnItems.put(user, ru);
+ }
+
+ for (int itemJ : pairItems) {
+ train(itemI, knnItems, itemJ);
+ }
+ }
+
+ private void forwardModel() throws HiveException {
+ final IntWritable f0 = new IntWritable(); // i
+ final IntWritable f1 = new IntWritable(); // nn
+ final FloatWritable f2 = new FloatWritable(); // w
+ final Object[] forwardObj = new Object[] {f0, f1, f2};
+
+ final MutableObject<HiveException> catched = new MutableObject<>();
+ _weightMatrix.eachNonZeroCell(new VectorProcedure() {
+ @Override
+ public void apply(int i, int j, float value) {
+ if (value == 0.f) {
+ return;
+ }
+ f0.set(i);
+ f1.set(j);
+ f2.set(value);
+ try {
+ forward(forwardObj);
+ } catch (HiveException e) {
+ catched.setIfAbsent(e);
+ }
+ }
+ });
+ HiveException ex = catched.get();
+ if (ex != null) {
+ throw ex;
+ }
+ logger.info("Forwarded SLIM's weight matrix");
+ }
+
+ @Nonnull
+ private static IntOpenHashTable<Int2FloatOpenHashTable> kNNentries(
+ @Nonnull final Object kNNiObj, @Nonnull final MapObjectInspector knnItemsOI,
+ @Nonnull final PrimitiveObjectInspector knnItemsKeyOI,
+ @Nonnull final MapObjectInspector knnItemsValueOI,
+ @Nonnull final PrimitiveObjectInspector knnItemsValueKeyOI,
+ @Nonnull final PrimitiveObjectInspector knnItemsValueValueOI,
+ @Nullable IntOpenHashTable<Int2FloatOpenHashTable> knnItems,
+ @Nonnull final MutableInt nnzKNNi) {
+ if (knnItems == null) {
+ knnItems = new IntOpenHashTable<>(1024);
+ } else {
+ knnItems.clear();
+ }
+
+ int numElementOfKNNItems = 0;
+ for (Map.Entry<?, ?> entry : knnItemsOI.getMap(kNNiObj).entrySet()) {
+ int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), knnItemsKeyOI);
+ Int2FloatOpenHashTable ru = int2floatMap(knnItemsValueOI.getMap(entry.getValue()),
+ knnItemsValueKeyOI, knnItemsValueValueOI);
+ knnItems.put(user, ru);
+ numElementOfKNNItems += ru.size();
+ }
+
+ nnzKNNi.setValue(numElementOfKNNItems);
+ return knnItems;
+ }
+
+ @Nonnull
+ private static Int2FloatOpenHashTable int2floatMap(@Nonnull final Map<?, ?> map,
+ @Nonnull final PrimitiveObjectInspector keyOI,
+ @Nonnull final PrimitiveObjectInspector valueOI) {
+ final Int2FloatOpenHashTable result = new Int2FloatOpenHashTable(map.size());
+ result.defaultReturnValue(0.f);
+
+ for (Map.Entry<?, ?> entry : map.entrySet()) {
+ float v = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
+ if (v == 0.f) {
+ continue;
+ }
+ int k = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
+ result.put(k, v);
+ }
+
+ return result;
+ }
+
+ @Nonnull
+ private static Int2FloatOpenHashTable int2floatMap(final int item,
+ @Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI,
+ @Nonnull final PrimitiveObjectInspector valueOI,
+ @Nullable final DoKFloatMatrix dataMatrix) {
+ return int2floatMap(item, map, keyOI, valueOI, dataMatrix, null);
+ }
+
+ @Nonnull
+ private static Int2FloatOpenHashTable int2floatMap(final int item,
+ @Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI,
+ @Nonnull final PrimitiveObjectInspector valueOI,
+ @Nullable final DoKFloatMatrix dataMatrix, @Nullable Int2FloatOpenHashTable dst) {
+ if (dst == null) {
+ dst = new Int2FloatOpenHashTable(map.size());
+ dst.defaultReturnValue(0.f);
+ } else {
+ dst.clear();
+ }
+
+ for (Map.Entry<?, ?> entry : map.entrySet()) {
+ float rating = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
+ if (rating == 0.f) {
+ continue;
+ }
+ int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
+ dst.put(user, rating);
+ if (dataMatrix != null) {
+ dataMatrix.set(item, user, rating);
+ }
+ }
+
+ return dst;
+ }
+}
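For readers following the training code above: each call to `train()` performs one coordinate-descent step on the SLIM objective for the weight `W[i][j]`, averaging over the users who rated item `j`, and `getUpdateTerm()` applies soft-thresholding for the L1 term followed by a clamp at zero. The following is a minimal dense sketch of that update for a single target item, assuming a small in-memory users x items matrix. Unlike Hivemall, it predicts from all items rather than only the kNN items and uses no sparse structures. `SlimCdSketch`, `learnColumn` and `updateTerm` are hypothetical names and not part of Hivemall's API.

```java
import java.util.Arrays;

/** Hypothetical sketch, not Hivemall API: coordinate descent for one SLIM weight column. */
final class SlimCdSketch {

    /** Soft-thresholding with a non-negativity clamp, mirroring getUpdateTerm(). */
    static double updateTerm(double gradSum, double rateSum, double l1, double l2) {
        if (Math.abs(gradSum) <= l1) {
            return 0.d; // gradient lies inside the L1 dead zone
        }
        double update = (gradSum > 0.d ? gradSum - l1 : gradSum + l1) / (rateSum + l2);
        return Math.max(update, 0.d); // SLIM keeps weights non-negative
    }

    /**
     * Learns w so that A[u][target] is approximated by sum_k A[u][k] * w[k], keeping
     * w[target] = 0. A is a dense users x items matrix (an assumption for brevity).
     */
    static double[] learnColumn(double[][] A, int target, double l1, double l2, int iters) {
        final int numUsers = A.length;
        final int numItems = A[0].length;
        final double[] w = new double[numItems];
        for (int iter = 0; iter < iters; iter++) {
            for (int j = 0; j < numItems; j++) {
                if (j == target) {
                    continue; // no self-similarity
                }
                double gradSum = 0.d, rateSum = 0.d;
                int n = 0; // number of users who rated item j, like rj.size()
                for (int u = 0; u < numUsers; u++) {
                    double ruj = A[u][j];
                    if (ruj == 0.d) {
                        continue; // contributes nothing to either sum
                    }
                    double pred = 0.d; // prediction that excludes item j
                    for (int k = 0; k < numItems; k++) {
                        if (k != j) {
                            pred += A[u][k] * w[k];
                        }
                    }
                    double eui = A[u][target] - pred;
                    gradSum += ruj * eui;
                    rateSum += ruj * ruj;
                    n++;
                }
                if (n > 0) {
                    w[j] = updateTerm(gradSum / n, rateSum / n, l1, l2);
                }
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // item 1 is rated exactly like item 0; item 2 is unrelated
        double[][] A = {
            {5, 5, 0},
            {3, 3, 1},
            {0, 0, 4},
            {4, 4, 0},
        };
        System.out.println(Arrays.toString(learnColumn(A, 0, 0.001, 0.0005, 30)));
    }
}
```

With the matrix in `main` and the default `l1`/`l2` values from `processOptions`, the weight learned for item 1 ends up just below 1 (shrunk slightly by the regularizers). The weight for item 2 is thresholded to exactly 0 because its averaged gradient never exceeds `l1`.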
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java
new file mode 100644
index 0000000..3b5585e
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod(m-2))
+ * </pre>
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
+ */
+public class Int2DoubleOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.75f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+ protected double defaultReturnValue = -1.d;
+
+ protected int[] _keys;
+ protected double[] _values;
+ protected byte[] _states;
+
+ protected Int2DoubleOpenHashTable(int size, float loadFactor, float growFactor,
+ boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new int[actualSize];
+ this._values = new double[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Int2DoubleOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public Int2DoubleOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ /**
+ * Only for {@link Externalizable}
+ */
+ public Int2DoubleOpenHashTable() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public void defaultReturnValue(double v) {
+ this.defaultReturnValue = v;
+ }
+
+ public boolean containsKey(final int key) {
+ return findKey(key) >= 0;
+ }
+
+ /**
+ * @return defaultReturnValue (-1.d unless changed) if not found
+ */
+ public double get(final int key) {
+ return get(key, defaultReturnValue);
+ }
+
+ public double get(final int key, final double defaultValue) {
+ final int i = findKey(key);
+ if (i < 0) {
+ return defaultValue;
+ }
+ return _values[i];
+ }
+
+ public double put(final int key, final double value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final int[] keys = _keys;
+ final double[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ final int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Return whether the required slot is free for a new entry */
+ protected boolean isFree(final int index, final int key) {
+ final byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// too filled
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ protected int findKey(final int key) {
+ final int[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ final int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public double remove(final int key) {
+ final int[] keys = _keys;
+ final double[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ final int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(final int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(final int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final int[] newkeys = new int[newCapacity];
+ final double[] newValues = new double[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ final int k = _keys[i];
+ final double v = _values[i];
+ final int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(int key) {
+ return key & 0x7fffffff;
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ out.writeInt(i.getKey());
+ out.writeDouble(i.getValue());
+ }
+ }
+
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ int keylen = in.readInt();
+ int[] keys = new int[keylen];
+ double[] values = new double[keylen];
+ byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ int k = in.readInt();
+ double v = in.readDouble();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return -1 if not found
+ */
+ public int next();
+
+ public int getKey();
+
+ public double getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public int getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public double getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
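The class comment of `Int2DoubleOpenHashTable` summarizes its probe sequence in two lines. As a hedged illustration, here is a stripped-down sketch that keeps only that scheme. The home slot is `h1(k) = k mod m`. On a collision, the index steps backwards by `h2(k) = 1 + (k mod (m - 2))` with wraparound, as `put()` and `findKey()` do above. The sketch assumes a prime capacity of at least 3 (the real class enforces primality via `Primes.findLeastPrimeNumber`). It omits load-factor resizing, `REMOVED` tombstones and serialization. `DoubleHashingSketch` is a hypothetical name.

```java
/**
 * Hypothetical sketch of the double-hashing probe used by Int2DoubleOpenHashTable.
 * No resizing, removal or serialization.
 */
final class DoubleHashingSketch {
    private static final byte FREE = 0, FULL = 1;

    private final int[] keys;
    private final double[] values;
    private final byte[] states;
    private int used = 0;

    DoubleHashingSketch(int primeCapacity) { // capacity is assumed to be prime
        if (primeCapacity < 3) {
            throw new IllegalArgumentException("capacity must be >= 3: " + primeCapacity);
        }
        keys = new int[primeCapacity];
        values = new double[primeCapacity];
        states = new byte[primeCapacity];
    }

    private static int keyHash(int key) {
        return key & 0x7fffffff; // clear the sign bit so % yields a non-negative index
    }

    /** Returns the slot holding key, or the first free slot on its probe sequence. */
    private int probe(int key) {
        final int m = keys.length;
        final int hash = keyHash(key);
        int idx = hash % m; // h1
        final int decr = 1 + (hash % (m - 2)); // h2, always in [1, m - 2]
        while (states[idx] == FULL && keys[idx] != key) {
            idx -= decr; // step backwards, as in the Hivemall tables
            if (idx < 0) {
                idx += m;
            }
        }
        return idx;
    }

    void put(int key, double value) {
        if (used + 1 >= keys.length) {
            throw new IllegalStateException("table full (this sketch does not resize)");
        }
        int idx = probe(key);
        if (states[idx] != FULL) {
            used++; // new key; overwriting an existing key keeps the count
        }
        keys[idx] = key;
        values[idx] = value;
        states[idx] = FULL;
    }

    double get(int key, double defaultValue) {
        int idx = probe(key);
        return states[idx] == FULL ? values[idx] : defaultValue;
    }

    int size() {
        return used;
    }
}
```

With a prime capacity `m`, the step `h2` lies in `[1, m - 2]` and is therefore coprime to `m`. The probe sequence visits every slot before repeating, which is why the loops terminate as long as at least one slot is free.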
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
index e9b5c8a..22de115 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
@@ -90,23 +90,27 @@ public class Int2FloatOpenHashTable implements Externalizable {
this.defaultReturnValue = v;
}
- public boolean containsKey(int key) {
+ public boolean containsKey(final int key) {
return findKey(key) >= 0;
}
/**
* @return -1.f if not found
*/
- public float get(int key) {
- int i = findKey(key);
+ public float get(final int key) {
+ return get(key, defaultReturnValue);
+ }
+
+ public float get(final int key, final float defaultValue) {
+ final int i = findKey(key);
if (i < 0) {
- return defaultReturnValue;
+ return defaultValue;
}
return _values[i];
}
- public float put(int key, float value) {
- int hash = keyHash(key);
+ public float put(final int key, final float value) {
+ final int hash = keyHash(key);
int keyLength = _keys.length;
int keyIdx = hash % keyLength;
@@ -116,9 +120,9 @@ public class Int2FloatOpenHashTable implements Externalizable {
keyIdx = hash % keyLength;
}
- int[] keys = _keys;
- float[] values = _values;
- byte[] states = _states;
+ final int[] keys = _keys;
+ final float[] values = _values;
+ final byte[] states = _states;
if (states[keyIdx] == FULL) {// double hashing
if (keys[keyIdx] == key) {
@@ -127,7 +131,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
return old;
}
// try second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -151,8 +155,8 @@ public class Int2FloatOpenHashTable implements Externalizable {
}
/** Return weather the required slot is free for new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
+ protected boolean isFree(final int index, final int key) {
+ final byte stat = _states[index];
if (stat == FREE) {
return true;
}
@@ -163,7 +167,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
}
/** @return expanded or not */
- protected boolean preAddEntry(int index) {
+ protected boolean preAddEntry(final int index) {
if ((_used + 1) >= _threshold) {// too filled
int newCapacity = Math.round(_keys.length * _growFactor);
ensureCapacity(newCapacity);
@@ -172,19 +176,19 @@ public class Int2FloatOpenHashTable implements Externalizable {
return false;
}
- protected int findKey(int key) {
- int[] keys = _keys;
- byte[] states = _states;
- int keyLength = keys.length;
+ protected int findKey(final int key) {
+ final int[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
- int hash = keyHash(key);
+ final int hash = keyHash(key);
int keyIdx = hash % keyLength;
if (states[keyIdx] != FREE) {
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
return keyIdx;
}
// try second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -201,13 +205,13 @@ public class Int2FloatOpenHashTable implements Externalizable {
return -1;
}
- public float remove(int key) {
- int[] keys = _keys;
- float[] values = _values;
- byte[] states = _states;
- int keyLength = keys.length;
+ public float remove(final int key) {
+ final int[] keys = _keys;
+ final float[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
- int hash = keyHash(key);
+ final int hash = keyHash(key);
int keyIdx = hash % keyLength;
if (states[keyIdx] != FREE) {
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
@@ -217,7 +221,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
return old;
}
// second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -242,6 +246,9 @@ public class Int2FloatOpenHashTable implements Externalizable {
}
public void clear() {
+ if (_used == 0) {
+ return; // no need to clear
+ }
Arrays.fill(_states, FREE);
this._used = 0;
}
@@ -274,21 +281,21 @@ public class Int2FloatOpenHashTable implements Externalizable {
this._threshold = Math.round(prime * _loadFactor);
}
- private void rehash(int newCapacity) {
+ private void rehash(final int newCapacity) {
int oldCapacity = _keys.length;
if (newCapacity <= oldCapacity) {
throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
}
- int[] newkeys = new int[newCapacity];
- float[] newValues = new float[newCapacity];
- byte[] newStates = new byte[newCapacity];
+ final int[] newkeys = new int[newCapacity];
+ final float[] newValues = new float[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
int used = 0;
for (int i = 0; i < oldCapacity; i++) {
if (_states[i] == FULL) {
used++;
int k = _keys[i];
float v = _values[i];
- int hash = keyHash(k);
+ final int hash = keyHash(k);
int keyIdx = hash % newCapacity;
if (newStates[keyIdx] == FULL) {// second hashing
int decr = 1 + (hash % (newCapacity - 2));
@@ -310,7 +317,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
this._used = used;
}
- private static int keyHash(int key) {
+ private static int keyHash(final int key) {
return key & 0x7fffffff;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
index 8e87fce..73431d1 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
@@ -77,7 +77,10 @@ public final class Int2IntOpenHashTable implements Externalizable {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
- public Int2IntOpenHashTable() {// required for serialization
+ /**
+ * Only for {@link Externalizable}
+ */
+ public Int2IntOpenHashTable() {
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
index dbade74..1c90ae0 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
@@ -58,7 +58,10 @@ public final class IntOpenHashTable<V> implements Externalizable {
protected V[] _values;
protected byte[] _states;
- public IntOpenHashTable() {} // for Externalizable
+ /**
+ * Only for {@link Externalizable}
+ */
+ public IntOpenHashTable() {}
public IntOpenHashTable(int size) {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
index b4356ff..115571e 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
@@ -78,6 +78,9 @@ public final class Long2DoubleOpenHashTable implements Externalizable {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
+ /**
+ * Only for {@link Externalizable}
+ */
public Long2DoubleOpenHashTable() {// required for serialization
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
index 6b0ab59..ba2de76 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
@@ -78,7 +78,10 @@ public final class Long2FloatOpenHashTable implements Externalizable {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
- public Long2FloatOpenHashTable() {// required for serialization
+ /**
+ * Only for {@link Externalizable}
+ */
+ public Long2FloatOpenHashTable() {
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
}
@@ -113,7 +116,23 @@ public final class Long2FloatOpenHashTable implements Externalizable {
return _values[index];
}
+ public float _set(final int index, final float value) {
+ float old = _values[index];
+ _values[index] = value;
+ return old;
+ }
+
+ public float _remove(final int index) {
+ _states[index] = REMOVED;
+ --_used;
+ return _values[index];
+ }
+
public float put(final long key, final float value) {
+ return put(key, value, defaultReturnValue);
+ }
+
+ public float put(final long key, final float value, final float defaultValue) {
final int hash = keyHash(key);
int keyLength = _keys.length;
int keyIdx = hash % keyLength;
@@ -155,7 +174,7 @@ public final class Long2FloatOpenHashTable implements Externalizable {
values[keyIdx] = value;
states[keyIdx] = FULL;
++_used;
- return defaultReturnValue;
+ return defaultValue;
}
/** Return whether the required slot is free for new entry */
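The new `put(key, value, defaultValue)` overload lets a caller tell "the key was absent" apart from a stored value that happens to equal `defaultReturnValue`. The sketch below is a hypothetical, simplified long-to-float table (`LongFloatProbeTable` is not a Hivemall class) that reuses the probing scheme visible in these diffs. `keyHash` clears the sign bit, the first slot is `hash % length`, and collisions step backwards by `decr = 1 + (hash % (length - 2))`. It assumes `put` returns the previous value when the key already exists. Folding the 64-bit key into 32 bits, the boolean slot states, and the lack of resizing and removal are simplifications made for this sketch.

```java
/**
 * Hypothetical long -> float open-addressing table (not a Hivemall class).
 * Simplified: fixed capacity, no resizing, no removal.
 */
final class LongFloatProbeTable {
    private final long[] keys;
    private final float[] values;
    private final boolean[] full; // true if the slot holds an entry

    LongFloatProbeTable(final int capacity) {
        // a prime capacity > 2 makes every probe sequence cover all slots
        if (capacity < 3) {
            throw new IllegalArgumentException("capacity must be >= 3: " + capacity);
        }
        this.keys = new long[capacity];
        this.values = new float[capacity];
        this.full = new boolean[capacity];
    }

    /** Folds the 64-bit key into 32 bits (an assumption) and clears the sign bit. */
    private static int keyHash(final long key) {
        return ((int) (key ^ (key >>> 32))) & 0x7fffffff;
    }

    /** Returns the previous value if key was present, otherwise defaultValue. */
    float put(final long key, final float value, final float defaultValue) {
        final int hash = keyHash(key);
        final int length = keys.length;
        final int decr = 1 + (hash % (length - 2)); // secondary hash: probe step size
        int idx = hash % length;
        for (int probes = 0; probes < length; probes++) {
            if (!full[idx]) { // free slot: insert a new entry
                keys[idx] = key;
                values[idx] = value;
                full[idx] = true;
                return defaultValue;
            }
            if (keys[idx] == key) { // existing entry: overwrite, return old value
                final float old = values[idx];
                values[idx] = value;
                return old;
            }
            idx -= decr; // step backwards, wrapping around
            if (idx < 0) {
                idx += length;
            }
        }
        throw new IllegalStateException("table is full");
    }

    /** Returns the stored value, or defaultValue if the key is absent. */
    float get(final long key, final float defaultValue) {
        final int hash = keyHash(key);
        final int length = keys.length;
        final int decr = 1 + (hash % (length - 2));
        int idx = hash % length;
        // without removal, a key can only sit before the first free slot on its path
        for (int probes = 0; probes < length && full[idx]; probes++) {
            if (keys[idx] == key) {
                return values[idx];
            }
            idx -= decr;
            if (idx < 0) {
                idx += length;
            }
        }
        return defaultValue;
    }
}
```

With a prime capacity, every step size in [1, length - 2] is coprime with the length, so each probe sequence visits all slots; the bounded loops rely on that to terminate even when the table is full.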
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
index 1ca4c40..6445231 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
@@ -77,6 +77,9 @@ public final class Long2IntOpenHashTable implements Externalizable {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
+ /**
+ * Only for {@link Externalizable}
+ */
public Long2IntOpenHashTable() {// required for serialization
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
index 4599bfc..c16567a 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
@@ -59,7 +59,10 @@ public final class OpenHashTable<K, V> implements Externalizable {
protected V[] _values;
protected byte[] _states;
- public OpenHashTable() {} // for Externalizable
+ /**
+ * Only for {@link Externalizable}
+ */
+ public OpenHashTable() {}
public OpenHashTable(int size) {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java b/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java
new file mode 100644
index 0000000..bea2a9d
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.lang.mutable;
+
+import javax.annotation.Nullable;
+
+public final class MutableObject<T> {
+
+ @Nullable
+ private T _value;
+
+ public MutableObject() {}
+
+ public MutableObject(@Nullable T obj) {
+ this._value = obj;
+ }
+
+ public boolean isSet() {
+ return _value != null;
+ }
+
+ @Nullable
+ public T get() {
+ return _value;
+ }
+
+ public void set(@Nullable T obj) {
+ this._value = obj;
+ }
+
+ public void setIfAbsent(@Nullable T obj) {
+ if (_value == null) {
+ this._value = obj;
+ }
+ }
+
+ @Override
+ public boolean equals(@Nullable Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ MutableObject<?> other = (MutableObject<?>) obj;
+ if (_value == null) {
+ return other._value == null;
+ }
+ return _value.equals(other._value);
+ }
+
+ @Override
+ public int hashCode() {
+ return _value == null ? 0 : _value.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return _value == null ? "null" : _value.toString();
+ }
+
+}
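`MutableObject` is a small value holder. A common reason for such a class is to carry a result out of a callback, because Java lambdas and anonymous classes can only capture effectively final locals; how `SlimUDTF` actually uses it is not shown in this part of the diff, so the scenario below is an assumption. The sketch uses a simplified, hypothetical `Holder` class and writes `equals` with `java.util.Objects.equals`, which also covers the edge case of comparing two holders whose values are both `null`.

```java
import java.util.Objects;
import java.util.function.IntConsumer;

/** Hypothetical holder with the same shape as MutableObject (simplified). */
final class Holder<T> {
    private T value;

    boolean isSet() { return value != null; }

    T get() { return value; }

    void set(final T v) { this.value = v; }

    void setIfAbsent(final T v) {
        if (value == null) {
            this.value = v;
        }
    }

    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return Objects.equals(value, ((Holder<?>) obj).value); // null-safe comparison
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(value); // 0 for null, consistent with equals
    }
}

final class HolderDemo {
    /** Returns the first even number seen by a callback, or null if there is none. */
    static Integer firstEven(final int[] xs) {
        final Holder<Integer> first = new Holder<>(); // effectively final, so the lambda may capture it
        final IntConsumer visitor = x -> {
            if (x % 2 == 0) {
                first.setIfAbsent(x);
            }
        };
        for (int x : xs) {
            visitor.accept(x);
        }
        return first.get();
    }
}
```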
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index ee533dc..71d4c29 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -43,7 +43,7 @@ import javax.annotation.Nullable;
import org.apache.commons.math3.special.Gamma;
public final class MathUtils {
- private static final double LOG2 = Math.log(2);
+ public static final double LOG2 = Math.log(2);
private MathUtils() {}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
index 5e8f253..574fc04 100644
--- a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
+++ b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
@@ -103,7 +103,7 @@ public class BinaryResponsesMeasuresTest {
List<Integer> groundTruth = Arrays.asList(1, 2, 4);
double actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth,
- rankedList.size());
+ rankedList.size());
Assert.assertEquals(1.0d, actual, 0.0001d);
Collections.reverse(rankedList);
@@ -115,6 +115,22 @@ public class BinaryResponsesMeasuresTest {
Assert.assertEquals(0.0d, actual, 0.0001d);
}
+ @Test
+ public void testHit() {
+ List<Integer> rankedList = Arrays.asList(1, 3, 2, 6);
+ List<Integer> groundTruth = Arrays.asList(1, 2, 4);
+
+ double actual = BinaryResponsesMeasures.Hit(rankedList, groundTruth, rankedList.size());
+ Assert.assertEquals(1.d, actual, 0.0001d);
+
+ actual = BinaryResponsesMeasures.Hit(rankedList, groundTruth, 2);
+ Assert.assertEquals(1.d, actual, 0.0001d);
+
+ // not hitting case
+ rankedList = Arrays.asList(5, 6);
+ actual = BinaryResponsesMeasures.Hit(rankedList, groundTruth, 2);
+ Assert.assertEquals(0.d, actual, 0.0001d);
+ }
+
@Test
public void testAP() {
List<Integer> rankedList = Arrays.asList(1, 3, 2, 6);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java
index 6a7cc9d..96ac030 100644
--- a/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java
+++ b/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java
@@ -18,12 +18,12 @@
*/
package hivemall.evaluation;
-import org.junit.Assert;
-import org.junit.Test;
-
import java.util.Arrays;
import java.util.List;
+import org.junit.Assert;
+import org.junit.Test;
+
public class GradedResponsesMeasuresTest {
@Test
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
index decd7df..af3f024 100644
--- a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
+++ b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
@@ -225,7 +225,6 @@ public class MatrixBuilderTest {
Assert.assertEquals(Double.NaN, csc2.get(5, 4, Double.NaN), 0.d);
}
-
@Test
public void testDoKMatrixFromLibSVM() {
Matrix matrix = dokMatrixFromLibSVM();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java b/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java
new file mode 100644
index 0000000..c9e6afd
--- /dev/null
+++ b/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DoKFloatMatrixTest {
+
+ @Test
+ public void testGetSet() {
+ DoKFloatMatrix matrix = new DoKFloatMatrix();
+ Random rnd = new Random(43);
+
+ for (int i = 0; i < 1000; i++) {
+ int row = Math.abs(rnd.nextInt());
+ int col = Math.abs(rnd.nextInt());
+ double v = rnd.nextDouble();
+ matrix.set(row, col, v);
+ Assert.assertEquals(v, matrix.get(row, col), 0.00001d);
+ }
+
+ }
+
+}
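`DoKFloatMatrix` is a dictionary-of-keys (DoK) sparse matrix: only the cells that are set are stored, keyed by their (row, column) position, and values are kept at float precision, which is why the test compares doubles with a `0.00001d` tolerance. The hypothetical `DokFloatSketch` below illustrates the idea with a `HashMap<Long, Float>` and a packed 64-bit key; it does not claim to mirror how Hivemall's class stores entries internally. Note that `Math.abs(rnd.nextInt())` in the test can still be negative, because `Math.abs(Integer.MIN_VALUE)` returns `Integer.MIN_VALUE`; `rnd.nextInt(Integer.MAX_VALUE)` avoids that edge case.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical dictionary-of-keys float matrix (not Hivemall's DoKFloatMatrix). */
final class DokFloatSketch {
    private final Map<Long, Float> entries = new HashMap<>();

    /** Packs (row, col) into one key: row in the high 32 bits, col in the low 32 bits. */
    private static long index(final int row, final int col) {
        if (row < 0 || col < 0) {
            throw new IllegalArgumentException("negative index: (" + row + ", " + col + ")");
        }
        return ((long) row << 32) | (col & 0xffffffffL);
    }

    void set(final int row, final int col, final double value) {
        entries.put(index(row, col), (float) value); // stored with float precision
    }

    double get(final int row, final int col, final double defaultValue) {
        final Float v = entries.get(index(row, col));
        return v == null ? defaultValue : v;
    }

    /** Number of stored cells. */
    int nnz() {
        return entries.size();
    }
}
```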
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/recommend/SlimUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/recommend/SlimUDTFTest.java b/core/src/test/java/hivemall/recommend/SlimUDTFTest.java
new file mode 100644
index 0000000..00b78f0
--- /dev/null
+++ b/core/src/test/java/hivemall/recommend/SlimUDTFTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.recommend;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Test;
+
+public class SlimUDTFTest {
+ @Test
+ public void testAllSamples() throws HiveException {
+ SlimUDTF slim = new SlimUDTF();
+ ObjectInspector itemIOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+ ObjectInspector itemJOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+
+ ObjectInspector itemIRatesOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ PrimitiveObjectInspectorFactory.javaFloatObjectInspector);
+ ObjectInspector itemJRatesOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ PrimitiveObjectInspectorFactory.javaFloatObjectInspector);
+ ObjectInspector topKRatesOfIOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ PrimitiveObjectInspectorFactory.javaFloatObjectInspector));
+ ObjectInspector optionArgumentOI = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-l2 0.01 -l1 0.01");
+
+ ObjectInspector[] argOIs = {itemIOI, itemIRatesOI, topKRatesOfIOI, itemJOI, itemJRatesOI,
+ optionArgumentOI};
+
+ slim.initialize(argOIs);
+ int numUser = 4;
+ int numItem = 5;
+
+ float[][] data = { {1.f, 4.f, 0.f, 0.f, 0.f}, {0.f, 3.f, 0.f, 1.f, 2.f},
+ {2.f, 2.f, 0.f, 0.f, 3.f}, {0.f, 1.f, 1.f, 0.f, 0.f}};
+
+ for (int i = 0; i < numItem; i++) {
+ Map<Integer, Float> Ri = new HashMap<>();
+ for (int u = 0; u < numUser; u++) {
+ if (data[u][i] != 0.) {
+ Ri.put(u, data[u][i]);
+ }
+ }
+
+ // stand-in for the top-k neighbors of item i: every other item's rating, keyed by user
+ Map<Integer, Map<Integer, Float>> knnRatesOfI = new HashMap<>();
+ for (int u = 0; u < numUser; u++) {
+ Map<Integer, Float> Ru = new HashMap<>();
+ for (int k = 0; k < numItem; k++) {
+ if (k == i)
+ continue;
+ Ru.put(k, data[u][k]);
+ }
+ knnRatesOfI.put(u, Ru);
+ }
+
+ for (int j = 0; j < numItem; j++) {
+ if (i == j)
+ continue;
+ Map<Integer, Float> Rj = new HashMap<>();
+ for (int u = 0; u < numUser; u++) {
+ if (data[u][j] != 0.) {
+ Rj.put(u, data[u][j]);
+ }
+ }
+
+ Object[] args = {i, Ri, knnRatesOfI, j, Rj};
+ slim.process(args);
+ }
+ }
+ slim.finalizeTraining();
+ }
+
+}
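SLIM (Ning and Karypis, ICDM 2011) learns a sparse, non-negative item-item weight matrix W. For each item j it solves

minimize over w_j: 1/2 ||a_j - A w_j||^2 + (beta/2) ||w_j||^2 + lambda ||w_j||_1, subject to w_j >= 0 and w_jj = 0,

where A is the user-item rating matrix and a_j is its j-th column; a user u's score for item j is then a_u · w_j. The sketch below solves this per-column problem by dense coordinate descent with a non-negative soft-thresholding update, on the 4×5 matrix used in `SlimUDTFTest`. It is a hypothetical illustration, not `SlimUDTF`'s implementation (which consumes item pairs together with neighbor ratings); reading `-l1` as lambda and `-l2` as beta is an assumption.

```java
/** Hypothetical dense coordinate-descent solver for one SLIM column (not Hivemall's SlimUDTF). */
final class SlimColumnSketch {

    /**
     * Learns w_j for item j.
     *
     * @param a user x item rating matrix A (rows are users)
     * @param j target item
     * @param l1 L1 penalty (lambda)
     * @param l2 L2 penalty (beta)
     * @param iters number of passes over all coordinates
     */
    static double[] learnColumn(final double[][] a, final int j, final double l1,
            final double l2, final int iters) {
        final int numUsers = a.length;
        final int numItems = a[0].length;
        final double[] w = new double[numItems];
        // residual r = a_j - A w; w starts at zero, so r = a_j
        final double[] r = new double[numUsers];
        for (int u = 0; u < numUsers; u++) {
            r[u] = a[u][j];
        }
        for (int it = 0; it < iters; it++) {
            for (int k = 0; k < numItems; k++) {
                if (k == j) {
                    continue; // enforce w_jj = 0
                }
                double rho = 0.d; // A_k^T (residual without item k's own contribution)
                double sq = 0.d; // ||A_k||^2
                for (int u = 0; u < numUsers; u++) {
                    rho += a[u][k] * (r[u] + a[u][k] * w[k]);
                    sq += a[u][k] * a[u][k];
                }
                final double denom = sq + l2;
                // non-negative soft-thresholding: exact minimizer along coordinate k
                final double updated = denom > 0.d ? Math.max(0.d, rho - l1) / denom : 0.d;
                final double delta = updated - w[k];
                if (delta != 0.d) {
                    for (int u = 0; u < numUsers; u++) {
                        r[u] -= a[u][k] * delta; // keep the residual in sync with w
                    }
                    w[k] = updated;
                }
            }
        }
        return w;
    }

    /** 1/2 ||a_j - A w||^2 + beta/2 ||w||^2 + lambda ||w||_1 */
    static double objective(final double[][] a, final int j, final double[] w,
            final double l1, final double l2) {
        double loss = 0.d;
        for (double[] row : a) {
            double pred = 0.d;
            for (int k = 0; k < w.length; k++) {
                pred += row[k] * w[k];
            }
            final double err = row[j] - pred;
            loss += err * err;
        }
        double sqNorm = 0.d;
        double absNorm = 0.d;
        for (double x : w) {
            sqNorm += x * x;
            absNorm += Math.abs(x);
        }
        return 0.5d * loss + 0.5d * l2 * sqNorm + l1 * absNorm;
    }
}
```

Because each coordinate update minimizes the objective exactly along that coordinate, the objective never increases from its value at w_j = 0.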
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 3d640f8..8b76a7f 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -155,6 +155,7 @@
* [Item-based Collaborative Filtering](recommend/movielens_cf.md)
* [Matrix Factorization](recommend/movielens_mf.md)
* [Factorization Machine](recommend/movielens_fm.md)
+ * [SLIM for Fast Top-K Recommendation](recommend/movielens_slim.md)
* [10-fold Cross Validation (Matrix Factorization)](recommend/movielens_cv.md)
## Part X - Anomaly Detection
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/item_based_cf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/item_based_cf.md b/docs/gitbook/recommend/item_based_cf.md
index 053b225..dcd4f57 100644
--- a/docs/gitbook/recommend/item_based_cf.md
+++ b/docs/gitbook/recommend/item_based_cf.md
@@ -325,7 +325,7 @@ similarity as (
o.other,
cosine_similarity(t1.feature_vector, t2.feature_vector) as similarity
from
- cooccurrence_top100 o 
+ cooccurrence_top100 o
-- cooccurrence_upper_triangular o
JOIN item_features t1 ON (o.itemid = t1.itemid)
JOIN item_features t2 ON (o.other = t2.itemid)
@@ -652,7 +652,8 @@ partial_result as ( -- launch DIMSUM in a MapReduce fashion
item_features f
left outer join item_magnitude m
),
-similarity as ( -- reduce (i.e., sum up) mappers' partial results
+similarity as (
+ -- reduce (i.e., sum up) mappers' partial results
select
itemid,
other,
@@ -702,7 +703,8 @@ partial_result as (
item_features f
left outer join item_magnitude m
),
-similarity_upper_triangular as ( -- if similarity of (i1, i2) pair is in this table, (i2, i1)'s similarity is omitted
+similarity_upper_triangular as (
+ -- if similarity of (i1, i2) pair is in this table, (i2, i1)'s similarity is omitted
select
itemid,
other,
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_cf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_cf.md b/docs/gitbook/recommend/movielens_cf.md
index 1cf5aee..0602611 100644
--- a/docs/gitbook/recommend/movielens_cf.md
+++ b/docs/gitbook/recommend/movielens_cf.md
@@ -66,7 +66,8 @@ partial_result as ( -- launch DIMSUM in a MapReduce fashion
movie_features f
left outer join movie_magnitude m
),
-similarity as ( -- reduce (i.e., sum up) mappers' partial results
+similarity as (
+ -- reduce (i.e., sum up) mappers' partial results
select
movieid,
other,
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_cv.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_cv.md b/docs/gitbook/recommend/movielens_cv.md
index 6ac54c7..80c0d19 100644
--- a/docs/gitbook/recommend/movielens_cv.md
+++ b/docs/gitbook/recommend/movielens_cv.md
@@ -17,7 +17,7 @@
under the License.
-->
-[Cross-validation](http://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validationk-fold cross validation) is a model validation technique for assessing how a prediction model will generalize to an independent data set. This example shows a way to perform [k-fold cross validation](http://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation) to evaluate prediction performance.
+[Cross-validation](http://en.wikipedia.org/wiki/Cross-validation_%28statistics%29) is a model validation technique for assessing how a prediction model will generalize to an independent data set. This example shows a way to perform [k-fold cross validation](http://en.wikipedia.org/wiki/Cross-validation_%28statistics%29#k-fold_cross-validation) to evaluate prediction performance.
*Caution:* Matrix factorization is supported in Hivemall v0.3 or later.
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_fm.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_fm.md b/docs/gitbook/recommend/movielens_fm.md
index 64039fe..d3d2c82 100644
--- a/docs/gitbook/recommend/movielens_fm.md
+++ b/docs/gitbook/recommend/movielens_fm.md
@@ -19,6 +19,8 @@
_Caution: Factorization Machine is supported from Hivemall v0.4 or later._
+<!-- toc -->
+
# Data preparation
First of all, please create `ratings` table described in [this article](../recommend/movielens_dataset.html).
@@ -89,7 +91,7 @@ set hivevar:factor=10;
set hivevar:iters=50;
```
-## Build a prediction mdoel by Factorization Machine
+## Build a prediction model by Factorization Machine
```sql
drop table fm_model;
[3/3] incubator-hivemall git commit: Close #117,
Close #111: [HIVEMALL-17] Support SLIM neighborhood-learning
recommendation algorithm
Posted by my...@apache.org.
Close #117, Close #111: [HIVEMALL-17] Support SLIM neighborhood-learning recommendation algorithm
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/995b9a88
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/995b9a88
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/995b9a88
Branch: refs/heads/master
Commit: 995b9a885f6538138935dbf0fe9aae051ec47f9e
Parents: c2b9578
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Authored: Thu Sep 28 12:16:17 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Thu Sep 28 12:16:45 2017 +0900
----------------------------------------------------------------------
.../main/java/hivemall/evaluation/AUCUDAF.java | 17 +-
.../evaluation/BinaryResponsesMeasures.java | 37 +-
.../evaluation/GradedResponsesMeasures.java | 16 +-
.../java/hivemall/evaluation/HitRateUDAF.java | 262 +++++++
.../main/java/hivemall/evaluation/MAPUDAF.java | 19 +-
.../main/java/hivemall/evaluation/MRRUDAF.java | 19 +-
.../main/java/hivemall/evaluation/NDCGUDAF.java | 17 +-
.../java/hivemall/evaluation/PrecisionUDAF.java | 24 +-
.../java/hivemall/evaluation/RecallUDAF.java | 19 +-
.../hivemall/math/matrix/sparse/CSCMatrix.java | 2 +
.../hivemall/math/matrix/sparse/CSRMatrix.java | 4 +-
.../math/matrix/sparse/DoKFloatMatrix.java | 368 +++++++++
.../hivemall/math/matrix/sparse/DoKMatrix.java | 34 +-
.../hivemall/math/vector/VectorProcedure.java | 6 +
.../hivemall/mf/BPRMatrixFactorizationUDTF.java | 3 +-
.../mf/OnlineMatrixFactorizationUDTF.java | 7 +-
.../main/java/hivemall/recommend/SlimUDTF.java | 759 +++++++++++++++++++
.../maps/Int2DoubleOpenHashTable.java | 427 +++++++++++
.../maps/Int2FloatOpenHashTable.java | 71 +-
.../collections/maps/Int2IntOpenHashTable.java | 5 +-
.../collections/maps/IntOpenHashTable.java | 5 +-
.../maps/Long2DoubleOpenHashTable.java | 3 +
.../maps/Long2FloatOpenHashTable.java | 23 +-
.../collections/maps/Long2IntOpenHashTable.java | 3 +
.../utils/collections/maps/OpenHashTable.java | 5 +-
.../utils/lang/mutable/MutableObject.java | 83 ++
.../java/hivemall/utils/math/MathUtils.java | 2 +-
.../evaluation/BinaryResponsesMeasuresTest.java | 18 +-
.../evaluation/GradedResponsesMeasuresTest.java | 6 +-
.../hivemall/math/matrix/MatrixBuilderTest.java | 1 -
.../math/matrix/sparse/DoKFloatMatrixTest.java | 43 ++
.../java/hivemall/recommend/SlimUDTFTest.java | 99 +++
docs/gitbook/SUMMARY.md | 1 +
docs/gitbook/recommend/item_based_cf.md | 8 +-
docs/gitbook/recommend/movielens_cf.md | 3 +-
docs/gitbook/recommend/movielens_cv.md | 2 +-
docs/gitbook/recommend/movielens_fm.md | 4 +-
docs/gitbook/recommend/movielens_slim.md | 589 ++++++++++++++
resources/ddl/define-all-as-permanent.hive | 10 +
resources/ddl/define-all.hive | 10 +
resources/ddl/define-all.spark | 10 +
resources/ddl/define-udfs.td.hql | 2 +
42 files changed, 2916 insertions(+), 130 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/AUCUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/AUCUDAF.java b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
index 7cbdb52..508e36a 100644
--- a/core/src/main/java/hivemall/evaluation/AUCUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
@@ -52,7 +52,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -430,7 +429,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -448,7 +447,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -507,12 +506,12 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be greater than or equal to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
index 7c21849..c3b4f6a 100644
--- a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
+++ b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
@@ -45,7 +45,7 @@ public final class BinaryResponsesMeasures {
*/
public static double nDCG(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
double dcg = 0.d;
@@ -92,6 +92,8 @@ public final class BinaryResponsesMeasures {
*/
public static double Precision(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize >= 0);
+
if (rankedList.isEmpty()) {
if (groundTruth.isEmpty()) {
return 1.d;
@@ -99,8 +101,6 @@ public final class BinaryResponsesMeasures {
return 0.d;
}
- Preconditions.checkArgument(recommendSize > 0); // can be zero when groundTruth is empty
-
int nTruePositive = 0;
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
@@ -135,6 +135,29 @@ public final class BinaryResponsesMeasures {
}
/**
+ * Computes Hit@`recommendSize`
+ *
+ * @param rankedList a list of ranked item IDs (first item is highest-ranked)
+ * @param groundTruth a collection of positive/correct item IDs
+ * @param recommendSize top-`recommendSize` items in `rankedList` are recommended
+ * @return 1.0 if hit, 0.0 otherwise
+ */
+ public static double Hit(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
+ @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize >= 0);
+
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
+ Object item_id = rankedList.get(i);
+ if (groundTruth.contains(item_id)) {
+ return 1.d;
+ }
+ }
+
+ return 0.d;
+ }
+
+ /**
* Counts the number of true positives
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
@@ -144,7 +167,7 @@ public final class BinaryResponsesMeasures {
*/
public static int TruePositives(final List<?> rankedList, final List<?> groundTruth,
@Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
int nTruePositive = 0;
@@ -170,7 +193,7 @@ public final class BinaryResponsesMeasures {
*/
public static double ReciprocalRank(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
@@ -193,7 +216,7 @@ public final class BinaryResponsesMeasures {
*/
public static double AveragePrecision(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
if (groundTruth.isEmpty()) {
if (rankedList.isEmpty()) {
@@ -231,7 +254,7 @@ public final class BinaryResponsesMeasures {
*/
public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
int nTruePositive = 0, nCorrectPairs = 0;
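`Hit` returns 1.0 when at least one of the top-`recommendSize` items appears in the ground truth, and 0.0 otherwise; since `k = Math.min(rankedList.size(), recommendSize)`, a `recommendSize` of 0 now yields 0.0 instead of failing the precondition. The new `HitRateUDAF` presumably averages this value over users, but its body is not shown in this part of the diff, so the `hitRate` helper below is an assumption, as are the `HitSketch` class and method names.

```java
import java.util.List;

/** Hypothetical Hit@k and hit-rate helpers (names are not Hivemall APIs). */
final class HitSketch {

    /** 1.0 if any of the top-recommendSize ranked items is in groundTruth, else 0.0. */
    static double hit(final List<?> rankedList, final List<?> groundTruth,
            final int recommendSize) {
        if (recommendSize < 0) {
            throw new IllegalArgumentException("recommendSize must be >= 0: " + recommendSize);
        }
        final int k = Math.min(rankedList.size(), recommendSize);
        for (int i = 0; i < k; i++) {
            if (groundTruth.contains(rankedList.get(i))) {
                return 1.d;
            }
        }
        return 0.d;
    }

    /** Mean Hit@k over users; user u's ranking and ground truth share index u (assumed). */
    static double hitRate(final List<? extends List<?>> rankedLists,
            final List<? extends List<?>> truths, final int recommendSize) {
        if (rankedLists.size() != truths.size()) {
            throw new IllegalArgumentException("size mismatch");
        }
        if (rankedLists.isEmpty()) {
            return 0.d;
        }
        double sum = 0.d;
        for (int u = 0; u < rankedLists.size(); u++) {
            sum += hit(rankedLists.get(u), truths.get(u), recommendSize);
        }
        return sum / rankedLists.size();
    }
}
```

`List.contains` is linear in the ground-truth size; for long ground-truth lists, copying them into a `HashSet` first would make each lookup constant time on average.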
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java
index 688ba53..5bbbb7e 100644
--- a/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java
+++ b/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java
@@ -18,8 +18,12 @@
*/
package hivemall.evaluation;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.math.MathUtils;
+
import java.util.List;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
@@ -32,7 +36,7 @@ public final class GradedResponsesMeasures {
private GradedResponsesMeasures() {}
public static double nDCG(@Nonnull final List<Double> recommendTopRelScoreList,
- @Nonnull final List<Double> truthTopRelScoreList, @Nonnull final int recommendSize) {
+ @Nonnull final List<Double> truthTopRelScoreList, @Nonnegative final int recommendSize) {
double dcg = DCG(recommendTopRelScoreList, recommendSize);
double idcg = DCG(truthTopRelScoreList, recommendSize);
return dcg / idcg;
@@ -45,11 +49,15 @@ public final class GradedResponsesMeasures {
* @param recommendSize the number of positive items
* @return DCG
*/
- public static double DCG(final List<Double> topRelScoreList, final int recommendSize) {
+ public static double DCG(@Nonnull final List<Double> topRelScoreList,
+ @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize >= 0);
+
double dcg = 0.d;
- for (int i = 0; i < recommendSize; i++) {
+ final int k = Math.min(topRelScoreList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
double relScore = topRelScoreList.get(i);
- dcg += ((Math.pow(2, relScore) - 1) * Math.log(2)) / Math.log(i + 2);
+ dcg += ((Math.pow(2, relScore) - 1) * MathUtils.LOG2) / Math.log(i + 2);
}
return dcg;
}
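The `DCG` change does two things. It clamps the loop bound to the list length, and it replaces `Math.log(2)` with the constant `MathUtils.LOG2`. Because (2^rel - 1) * ln 2 / ln(i + 2) equals (2^rel - 1) / log2(i + 2), each term is the usual log2-discounted gain. The sketch below restates both formulas in plain Java. It assumes `MathUtils.LOG2` equals `Math.log(2)`, and the class name `DcgSketch` is hypothetical. The diff's `nDCG` divides `dcg / idcg` without a guard. Unlike it, this sketch returns 0 when the ideal DCG is 0.

```java
import java.util.Arrays;
import java.util.List;

// Standalone sketch of the DCG / nDCG formulas in the diff; not Hivemall code.
final class DcgSketch {

    // ASSUMPTION: MathUtils.LOG2 in the diff is the natural log of 2.
    private static final double LOG2 = Math.log(2.d);

    private DcgSketch() {}

    // DCG@k = sum over i < k of (2^rel_i - 1) / log2(i + 2); k is clamped to the list length.
    static double dcg(List<Double> topRelScoreList, int recommendSize) {
        if (recommendSize < 0) {
            throw new IllegalArgumentException("recommendSize must be >= 0: " + recommendSize);
        }
        final int k = Math.min(topRelScoreList.size(), recommendSize);
        double dcg = 0.d;
        for (int i = 0; i < k; i++) {
            double relScore = topRelScoreList.get(i);
            // Multiplying by ln(2) and dividing by ln(i + 2) is dividing by log2(i + 2).
            dcg += ((Math.pow(2, relScore) - 1) * LOG2) / Math.log(i + 2);
        }
        return dcg;
    }

    // nDCG@k = DCG of the recommended order divided by DCG of the ideal order.
    static double ndcg(List<Double> recommendTopRelScoreList,
            List<Double> truthTopRelScoreList, int recommendSize) {
        final double idcg = dcg(truthTopRelScoreList, recommendSize);
        // Guard added in this sketch only: 0 instead of NaN/Infinity when the ideal DCG is 0.
        return idcg == 0.d ? 0.d : dcg(recommendTopRelScoreList, recommendSize) / idcg;
    }

    public static void main(String[] args) {
        List<Double> recommended = Arrays.asList(1.d, 3.d, 0.d);
        List<Double> ideal = Arrays.asList(3.d, 1.d, 0.d);
        System.out.println(dcg(ideal, 1)); // top gain alone, (2^3 - 1) / log2(2), about 7
        System.out.println(ndcg(recommended, ideal, 10)); // about 0.71: best item ranked second
    }
}
```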
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/HitRateUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/HitRateUDAF.java b/core/src/main/java/hivemall/evaluation/HitRateUDAF.java
new file mode 100644
index 0000000..6df6087
--- /dev/null
+++ b/core/src/main/java/hivemall/evaluation/HitRateUDAF.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.evaluation;
+
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.LongWritable;
+
+@Description(
+ name = "hitrate",
+ value = "_FUNC_(array rankItems, array correctItems [, const int recommendSize = rankItems.size])"
+ + " - Returns HitRate")
+public final class HitRateUDAF extends AbstractGenericUDAFResolver {
+
+ // prevent instantiation
+ private HitRateUDAF() {}
+
+ @Override
+ public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
+ if (typeInfo.length != 2 && typeInfo.length != 3) {
+ throw new UDFArgumentTypeException(typeInfo.length - 1,
+ "_FUNC_ takes two or three arguments");
+ }
+
+ ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
+ if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
+ throw new UDFArgumentTypeException(0,
+ "The first argument `array rankItems` has an invalid form: " + typeInfo[0]);
+ }
+ ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
+ if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) {
+ throw new UDFArgumentTypeException(1,
+ "The second argument `array correctItems` has an invalid form: " + typeInfo[1]);
+ }
+
+ return new HitRateUDAF.Evaluator();
+ }
+
+ public static class Evaluator extends GenericUDAFEvaluator {
+
+ private ListObjectInspector recommendListOI;
+ private ListObjectInspector truthListOI;
+ private PrimitiveObjectInspector recommendSizeOI;
+
+ private StructObjectInspector internalMergeOI;
+ private StructField countField;
+ private StructField sumField;
+
+ public Evaluator() {}
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+ assert (parameters.length == 2 || parameters.length == 3) : parameters.length;
+ super.init(mode, parameters);
+
+ // initialize input
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+ this.recommendListOI = (ListObjectInspector) parameters[0];
+ this.truthListOI = (ListObjectInspector) parameters[1];
+ if (parameters.length == 3) {
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
+ }
+ } else {// from partial aggregation
+ StructObjectInspector soi = (StructObjectInspector) parameters[0];
+ this.internalMergeOI = soi;
+ this.countField = soi.getStructFieldRef("count");
+ this.sumField = soi.getStructFieldRef("sum");
+ }
+
+ // initialize output
+ final ObjectInspector outputOI;
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+ outputOI = internalMergeOI();
+ } else {// terminate
+ outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+ }
+ return outputOI;
+ }
+
+ private static StructObjectInspector internalMergeOI() {
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+ fieldNames.add("sum");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("count");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public HitRateAggregationBuffer getNewAggregationBuffer() throws HiveException {
+ HitRateAggregationBuffer myAggr = new HitRateAggregationBuffer();
+ reset(myAggr);
+ return myAggr;
+ }
+
+ @Override
+ public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+ myAggr.reset();
+ }
+
+ @Override
+ public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+ Object[] parameters) throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+
+ List<?> recommendList = recommendListOI.getList(parameters[0]);
+ if (recommendList == null) {
+ recommendList = Collections.emptyList();
+ }
+ List<?> truthList = truthListOI.getList(parameters[1]);
+ if (truthList == null) {
+ return;
+ }
+
+ int recommendSize = recommendList.size();
+ if (parameters.length == 3) {
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be greater than or equal to 0: "
+ + recommendSize);
+ }
+ }
+
+ myAggr.iterate(recommendList, truthList, recommendSize);
+ }
+
+ @Override
+ public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+
+ Object[] partialResult = new Object[2];
+ partialResult[0] = new DoubleWritable(myAggr.sum);
+ partialResult[1] = new LongWritable(myAggr.count);
+ return partialResult;
+ }
+
+ @Override
+ public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
+ throws HiveException {
+ if (partial == null) {
+ return;
+ }
+
+ Object sumObj = internalMergeOI.getStructFieldData(partial, sumField);
+ Object countObj = internalMergeOI.getStructFieldData(partial, countField);
+ double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
+ long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);
+
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+ myAggr.merge(sum, count);
+ }
+
+ @Override
+ public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+ double result = myAggr.get();
+ return new DoubleWritable(result);
+ }
+
+ }
+
+ public static final class HitRateAggregationBuffer extends
+ GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+ private double sum;
+ private long count;
+
+ public HitRateAggregationBuffer() {
+ super();
+ }
+
+ void reset() {
+ this.sum = 0.d;
+ this.count = 0;
+ }
+
+ void merge(double o_sum, long o_count) {
+ this.sum += o_sum;
+ this.count += o_count;
+ }
+
+ double get() {
+ if (count == 0) {
+ return 0.d;
+ }
+ return sum / count;
+ }
+
+ void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList,
+ @Nonnegative int recommendSize) {
+ this.sum += BinaryResponsesMeasures.Hit(recommendList, truthList, recommendSize);
+ this.count++;
+ }
+ }
+}
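`HitRateUDAF` follows the same partial-aggregation pattern as the other evaluation UDAFs. Each mapper keeps a running `sum` and `count`, and `terminatePartial` ships both as a struct. `merge` then adds them field by field, and `terminate` returns `sum / count`, or 0 when no rows were seen. The sketch below replays that flow outside Hive. `BinaryResponsesMeasures.Hit` is not part of this excerpt, so the sketch assumes a hit is 1 when at least one of the top-k recommended items appears in the ground truth and 0 otherwise. The class name `HitRateSketch` is hypothetical.

```java
import java.util.Arrays;
import java.util.List;

// Standalone sketch of the sum/count aggregation used by HitRateUDAF; not Hive code.
final class HitRateSketch {

    double sum;  // running total of per-row hit values
    long count;  // number of rows (users) seen

    // ASSUMPTION: a hit is 1 when any of the top-k items is in the ground truth, else 0.
    static double hit(List<?> rankedList, List<?> groundTruth, int recommendSize) {
        final int k = Math.min(rankedList.size(), recommendSize);
        for (int i = 0; i < k; i++) {
            if (groundTruth.contains(rankedList.get(i))) {
                return 1.d;
            }
        }
        return 0.d;
    }

    // Mirrors iterate(): one input row adds its hit value and increments the count.
    void iterate(List<?> rankedList, List<?> groundTruth, int recommendSize) {
        sum += hit(rankedList, groundTruth, recommendSize);
        count++;
    }

    // Mirrors merge(): a partial (sum, count) pair from another task is added field by field.
    void merge(double otherSum, long otherCount) {
        sum += otherSum;
        count += otherCount;
    }

    // Mirrors get(): the mean hit value, or 0 when nothing was aggregated.
    double get() {
        return count == 0 ? 0.d : sum / count;
    }

    public static void main(String[] args) {
        HitRateSketch mapper1 = new HitRateSketch();
        mapper1.iterate(Arrays.asList(1, 2, 3), Arrays.asList(3), 3); // hit
        mapper1.iterate(Arrays.asList(4, 5, 6), Arrays.asList(9), 3); // miss

        HitRateSketch mapper2 = new HitRateSketch();
        mapper2.iterate(Arrays.asList(7, 8), Arrays.asList(7), 1); // hit at rank 1

        HitRateSketch reducer = new HitRateSketch();
        reducer.merge(mapper1.sum, mapper1.count);
        reducer.merge(mapper2.sum, mapper2.count);
        System.out.println(reducer.get()); // 2 hits out of 3 rows
    }
}
```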
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/MAPUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/MAPUDAF.java b/core/src/main/java/hivemall/evaluation/MAPUDAF.java
index 3878684..45e64cb 100644
--- a/core/src/main/java/hivemall/evaluation/MAPUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/MAPUDAF.java
@@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -80,7 +81,7 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -98,7 +99,7 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -159,12 +160,12 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be greater than or equal to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
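Besides the `PrimitiveObjectInspector` change, the `MAPUDAF` hunk drops the upper bound on `recommendSize`. The old code threw when it exceeded `recommendList.size()`. The new code rejects only negative values and relies on the measures to clamp the cutoff. The sketch below shows average precision and its mean over users under that convention. `AveragePrecision`'s body is only partly visible above, so two details here are assumptions. The normalization divides by the number of relevant items found in the top-k. An empty ground truth returns 0, although the diff special-cases it. Other AP definitions divide by the ground-truth size instead. The class name `AveragePrecisionSketch` is hypothetical.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical standalone sketch of AP and MAP; not Hivemall's implementation.
final class AveragePrecisionSketch {

    private AveragePrecisionSketch() {}

    // AP@k: average of precision@(i + 1) over the ranks where a relevant item appears.
    static double averagePrecision(List<?> rankedList, List<?> groundTruth, int recommendSize) {
        if (recommendSize < 0) {
            throw new IllegalArgumentException("recommendSize must be >= 0: " + recommendSize);
        }
        final int k = Math.min(rankedList.size(), recommendSize); // larger cutoffs are clamped
        int nTruePositive = 0;
        double sumPrecision = 0.d;
        for (int i = 0; i < k; i++) {
            if (groundTruth.contains(rankedList.get(i))) {
                nTruePositive++;
                sumPrecision += nTruePositive / (i + 1.d);
            }
        }
        // ASSUMPTION: normalize by hits found; no hits (including empty truth) yields 0.
        return nTruePositive == 0 ? 0.d : sumPrecision / nTruePositive;
    }

    // MAP: arithmetic mean of per-user AP, as the UDAF's sum/count buffer computes it.
    static double meanAveragePrecision(List<? extends List<?>> ranked,
            List<? extends List<?>> truth, int recommendSize) {
        if (ranked.isEmpty()) {
            return 0.d;
        }
        double sum = 0.d;
        for (int u = 0; u < ranked.size(); u++) {
            sum += averagePrecision(ranked.get(u), truth.get(u), recommendSize);
        }
        return sum / ranked.size();
    }

    public static void main(String[] args) {
        List<List<Integer>> ranked = Arrays.asList(Arrays.asList(1, 2, 3, 4), Arrays.asList(5, 6));
        List<List<Integer>> truth = Arrays.asList(Arrays.asList(1, 3), Arrays.asList(6));
        // AP is (1/1 + 2/3) / 2 for the first user and (1/2) / 1 for the second.
        // recommendSize = 100 exceeds both lists; the removed UDAF check rejected this.
        System.out.println(meanAveragePrecision(ranked, truth, 100));
    }
}
```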
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/MRRUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/MRRUDAF.java b/core/src/main/java/hivemall/evaluation/MRRUDAF.java
index f5aba3b..98b8c3d 100644
--- a/core/src/main/java/hivemall/evaluation/MRRUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/MRRUDAF.java
@@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -80,7 +81,7 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -98,7 +99,7 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -159,12 +160,12 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be greater than or equal to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
index f1ba832..4e4fde6 100644
--- a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
@@ -45,7 +45,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -85,7 +84,7 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -103,7 +102,7 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -164,12 +163,12 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be greater than or equal to 0: "
+ + recommendSize);
+ }
}
boolean isBinary = !HiveUtils.isStructOI(recommendListOI.getListElementObjectInspector());
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java b/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java
index 93af519..de8a876 100644
--- a/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java
@@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -80,7 +81,7 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -98,7 +99,7 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -117,9 +118,10 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver {
return outputOI;
}
+ @Nonnull
private static StructObjectInspector internalMergeOI() {
- ArrayList<String> fieldNames = new ArrayList<String>();
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
fieldNames.add("sum");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
@@ -159,12 +161,12 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be greater than or equal to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/RecallUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/RecallUDAF.java b/core/src/main/java/hivemall/evaluation/RecallUDAF.java
index fed9f71..30b1712 100644
--- a/core/src/main/java/hivemall/evaluation/RecallUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/RecallUDAF.java
@@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -80,7 +81,7 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -98,7 +99,7 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -159,12 +160,12 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be greater than or equal to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
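The `PrecisionUDAF` and `RecallUDAF` hunks make the same validation change, so a cutoff larger than the recommendation list is now passed through instead of rejected. The sketch below illustrates precision@k and recall@k with a clamped cutoff. The denominators are assumptions, because the bodies of Hivemall's `Precision` and `Recall` are not in this excerpt. Precision divides by the clamped k, recall divides by the ground-truth size, and both return 0 when the denominator is 0. `PrecisionRecallSketch` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical standalone sketch of precision@k / recall@k; not Hivemall code.
final class PrecisionRecallSketch {

    private PrecisionRecallSketch() {}

    // Relevant items among the first k ranked items, with k clamped to the list length.
    static int truePositives(List<?> rankedList, List<?> groundTruth, int recommendSize) {
        if (recommendSize < 0) {
            throw new IllegalArgumentException("recommendSize must be >= 0: " + recommendSize);
        }
        final int k = Math.min(rankedList.size(), recommendSize);
        int n = 0;
        for (int i = 0; i < k; i++) {
            if (groundTruth.contains(rankedList.get(i))) {
                n++;
            }
        }
        return n;
    }

    // ASSUMPTION: divide by the clamped cutoff; return 0 when the cutoff is 0.
    static double precision(List<?> rankedList, List<?> groundTruth, int recommendSize) {
        final int k = Math.min(rankedList.size(), recommendSize);
        return k == 0 ? 0.d : truePositives(rankedList, groundTruth, recommendSize) / (double) k;
    }

    // ASSUMPTION: divide by the ground-truth size; return 0 when the ground truth is empty.
    static double recall(List<?> rankedList, List<?> groundTruth, int recommendSize) {
        return groundTruth.isEmpty() ? 0.d
                : truePositives(rankedList, groundTruth, recommendSize) / (double) groundTruth.size();
    }

    public static void main(String[] args) {
        List<Integer> ranked = Arrays.asList(1, 2, 3, 4);
        List<Integer> truth = Arrays.asList(2, 4, 9);
        System.out.println(precision(ranked, truth, 2)); // 0.5
        System.out.println(recall(ranked, truth, 10));   // 2 of 3 relevant items found
        System.out.println(precision(ranked, truth, 0)); // 0.0
    }
}
```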
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
index d2232b2..f8eb02f 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
@@ -31,6 +31,8 @@ import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
+ * Compressed Sparse Column matrix optimized for column-major access.
+ *
* @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000
*/
public final class CSCMatrix extends ColumnMajorMatrix {
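The new `CSCMatrix` Javadoc describes the compressed sparse column layout. Nonzeros are stored column by column, so one column is a contiguous slice of the value array. The minimal sketch below builds that layout from a dense array and looks entries up with a binary search inside a column's slice. It is an illustration only. The field names `columnPointers`, `rowIndices`, and `values` are hypothetical and need not match Hivemall's `CSCMatrix`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical minimal CSC layout for illustration; not Hivemall's CSCMatrix.
final class CscSketch {
    final int[] columnPointers; // column j's entries live in [columnPointers[j], columnPointers[j + 1])
    final int[] rowIndices;     // row index of each stored entry, ascending within a column
    final double[] values;      // value of each stored entry

    CscSketch(double[][] dense) {
        final int numRows = dense.length;
        final int numCols = numRows == 0 ? 0 : dense[0].length;
        this.columnPointers = new int[numCols + 1];
        List<Integer> rows = new ArrayList<>();
        List<Double> vals = new ArrayList<>();
        for (int j = 0; j < numCols; j++) { // walk column by column
            for (int i = 0; i < numRows; i++) {
                if (dense[i][j] != 0.d) {
                    rows.add(i);
                    vals.add(dense[i][j]);
                }
            }
            columnPointers[j + 1] = rows.size();
        }
        this.rowIndices = rows.stream().mapToInt(Integer::intValue).toArray();
        this.values = vals.stream().mapToDouble(Double::doubleValue).toArray();
    }

    // Row indices inside a column slice are sorted, so a binary search finds the entry.
    double get(int row, int col) {
        final int pos = Arrays.binarySearch(rowIndices, columnPointers[col],
            columnPointers[col + 1], row);
        return pos >= 0 ? values[pos] : 0.d;
    }

    // Column-major access is cheap: a column's nonzero count is a pointer difference.
    int nnzInColumn(int col) {
        return columnPointers[col + 1] - columnPointers[col];
    }

    public static void main(String[] args) {
        double[][] dense = {{1, 0, 2}, {0, 0, 3}, {4, 5, 0}};
        CscSketch m = new CscSketch(dense);
        System.out.println(Arrays.toString(m.columnPointers)); // [0, 2, 3, 5]
        System.out.println(Arrays.toString(m.rowIndices));     // [0, 2, 2, 0, 1]
        System.out.println(m.get(2, 1) + " " + m.get(1, 1));   // 5.0 0.0
    }
}
```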
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
index dd89521..805bbd1 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
@@ -29,8 +29,8 @@ import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
- * Read-only CSR double Matrix.
- *
+ * Compressed Sparse Row matrix optimized for row-major access.
+ *
* @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
* @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html
*/
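CSR is the transpose of that layout. Nonzeros are stored row by row, and `rowPointers[i]` and `rowPointers[i + 1]` bound row `i`'s slice. A sparse matrix-vector product is the typical operation that benefits, because it walks each row's slice once. The class `CsrSketch` and its field names are hypothetical and are not Hivemall's `CSRMatrix` API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical minimal CSR layout for illustration; not Hivemall's CSRMatrix.
final class CsrSketch {
    final int[] rowPointers;   // row i's entries live in [rowPointers[i], rowPointers[i + 1])
    final int[] columnIndices; // column index of each stored entry
    final double[] values;     // value of each stored entry
    final int numColumns;

    CsrSketch(double[][] dense) {
        final int numRows = dense.length;
        this.numColumns = numRows == 0 ? 0 : dense[0].length;
        this.rowPointers = new int[numRows + 1];
        List<Integer> cols = new ArrayList<>();
        List<Double> vals = new ArrayList<>();
        for (int i = 0; i < numRows; i++) { // walk row by row
            for (int j = 0; j < numColumns; j++) {
                if (dense[i][j] != 0.d) {
                    cols.add(j);
                    vals.add(dense[i][j]);
                }
            }
            rowPointers[i + 1] = cols.size();
        }
        this.columnIndices = cols.stream().mapToInt(Integer::intValue).toArray();
        this.values = vals.stream().mapToDouble(Double::doubleValue).toArray();
    }

    // y = A * x, touching only stored nonzeros; each row is one contiguous slice.
    double[] multiply(double[] x) {
        if (x.length != numColumns) {
            throw new IllegalArgumentException("Expected length " + numColumns + ": " + x.length);
        }
        final double[] y = new double[rowPointers.length - 1];
        for (int i = 0; i < y.length; i++) {
            double acc = 0.d;
            for (int p = rowPointers[i]; p < rowPointers[i + 1]; p++) {
                acc += values[p] * x[columnIndices[p]];
            }
            y[i] = acc;
        }
        return y;
    }

    public static void main(String[] args) {
        double[][] dense = {{1, 0, 2}, {0, 0, 3}, {4, 5, 0}};
        CsrSketch m = new CsrSketch(dense);
        System.out.println(Arrays.toString(m.rowPointers));                       // [0, 2, 3, 5]
        System.out.println(Arrays.toString(m.multiply(new double[] {1, 2, 3}))); // [7.0, 9.0, 14.0]
    }
}
```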
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
new file mode 100644
index 0000000..16b4b64
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import hivemall.annotations.Experimental;
+import hivemall.math.matrix.AbstractMatrix;
+import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.DoKMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Long2FloatOpenHashTable;
+import hivemall.utils.collections.maps.Long2FloatOpenHashTable.IMapIterator;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Dictionary Of Keys based sparse matrix.
+ *
+ * This is an efficient structure for constructing a sparse matrix incrementally.
+ */
+@Experimental
+public final class DoKFloatMatrix extends AbstractMatrix {
+
+ @Nonnull
+ private final Long2FloatOpenHashTable elements;
+ @Nonnegative
+ private int numRows;
+ @Nonnegative
+ private int numColumns;
+ @Nonnegative
+ private int nnz;
+
+ public DoKFloatMatrix() {
+ this(0, 0);
+ }
+
+ public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
+ this(numRows, numCols, 0.05f);
+ }
+
+ public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols,
+ @Nonnegative float sparsity) {
+ super();
+ Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: "
+ + sparsity);
+ int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity));
+ this.elements = new Long2FloatOpenHashTable(initialCapacity);
+ elements.defaultReturnValue(0.f);
+ this.numRows = numRows;
+ this.numColumns = numCols;
+ this.nnz = 0;
+ }
+
+ public DoKFloatMatrix(@Nonnegative int initSize) {
+ super();
+ int initialCapacity = Math.max(initSize, 16384);
+ this.elements = new Long2FloatOpenHashTable(initialCapacity);
+ elements.defaultReturnValue(0.f);
+ this.numRows = 0;
+ this.numColumns = 0;
+ this.nnz = 0;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return true;
+ }
+
+ @Override
+ public boolean isRowMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean isColumnMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return false;
+ }
+
+ @Override
+ public boolean swappable() {
+ return true;
+ }
+
+ @Override
+ public int nnz() {
+ return nnz;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int numColumns(@Nonnegative final int row) {
+ int count = 0;
+ for (int j = 0; j < numColumns; j++) {
+ long index = index(row, j);
+ if (elements.containsKey(index)) {
+ count++;
+ }
+ }
+ return count;
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int index) {
+ double[] dst = row();
+ return getRow(index, dst);
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) {
+ checkRowIndex(row, numRows);
+
+ final int end = Math.min(dst.length, numColumns);
+ for (int col = 0; col < end; col++) {
+ long k = index(row, col);
+ float v = elements.get(k);
+ dst[col] = v;
+ }
+
+ return dst;
+ }
+
+ @Override
+ public void getRow(@Nonnegative final int index, @Nonnull final Vector row) {
+ checkRowIndex(index, numRows);
+ row.clear();
+
+ for (int col = 0; col < numColumns; col++) {
+ long k = index(index, col);
+ final float v = elements.get(k, 0.f);
+ if (v != 0.f) {
+ row.set(col, v);
+ }
+ }
+ }
+
+ @Override
+ public double get(@Nonnegative final int row, @Nonnegative final int col,
+ final double defaultValue) {
+ return get(row, col, (float) defaultValue);
+ }
+
+ public float get(@Nonnegative final int row, @Nonnegative final int col,
+ final float defaultValue) {
+ long index = index(row, col);
+ return elements.get(index, defaultValue);
+ }
+
+ @Override
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+ set(row, col, (float) value);
+ }
+
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) {
+ checkIndex(row, col);
+
+ final long index = index(row, col);
+ if (value == 0.f && elements.containsKey(index) == false) {
+ return;
+ }
+
+ if (elements.put(index, value, 0.f) == 0.f) {
+ nnz++;
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+ }
+
+ @Override
+ public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+ final double value) {
+ return getAndSet(row, col, (float) value);
+ }
+
+ public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) {
+ checkIndex(row, col);
+
+ final long index = index(row, col);
+ if (value == 0.f && elements.containsKey(index) == false) {
+ return 0.f;
+ }
+
+ final float old = elements.put(index, value, 0.f);
+ if (old == 0.f) {
+ nnz++;
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+ return old;
+ }
+
+ @Override
+ public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
+ checkRowIndex(row1, numRows);
+ checkRowIndex(row2, numRows);
+
+ for (int j = 0; j < numColumns; j++) {
+ final long i1 = index(row1, j);
+ final long i2 = index(row2, j);
+
+ final int k1 = elements._findKey(i1);
+ final int k2 = elements._findKey(i2);
+
+ if (k1 >= 0) {
+ if (k2 >= 0) {
+ float v1 = elements._get(k1);
+ float v2 = elements._set(k2, v1);
+ elements._set(k1, v2);
+ } else { // k1 >= 0 and k2 < 0
+ float v1 = elements._remove(k1);
+ elements.put(i2, v1);
+ }
+ } else if (k2 >= 0) { // k2 >= 0 and k1 < 0
+ float v2 = elements._remove(k2);
+ elements.put(i1, v2);
+ } else { // k1 < 0 and k2 < 0
+ continue;
+ }
+ }
+ }
+
+ @Override
+ public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(col, 0.d);
+ }
+ } else {
+ float v = elements._get(key);
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final float v = elements.get(i, 0.f);
+ if (v != 0.f) {
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachColumnIndexInRow(int row, VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key != -1) {
+ procedure.apply(col);
+ }
+ }
+ }
+
+ @Override
+ public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(row, 0.d);
+ }
+ } else {
+ float v = elements._get(key);
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(@Nonnegative final int col,
+ @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final float v = elements.get(i, 0.f);
+ if (v != 0.f) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
+ if (nnz == 0) {
+ return;
+ }
+ final IMapIterator itor = elements.entries();
+ while (itor.next() != -1) {
+ long k = itor.getKey();
+ int row = Primitives.getHigh(k);
+ int col = Primitives.getLow(k);
+ float value = itor.getValue();
+ procedure.apply(row, col, value);
+ }
+ }
+
+ @Override
+ public RowMajorMatrix toRowMajorMatrix() {
+ throw new UnsupportedOperationException("Not yet supported");
+ }
+
+ @Override
+ public ColumnMajorMatrix toColumnMajorMatrix() {
+ throw new UnsupportedOperationException("Not yet supported");
+ }
+
+ @Override
+ public DoKMatrixBuilder builder() {
+ return new DoKMatrixBuilder(elements.size());
+ }
+
+ @Nonnegative
+ private static long index(@Nonnegative final int row, @Nonnegative final int col) {
+ return Primitives.toLong(row, col);
+ }
+
+}
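A minimal sketch of the dictionary-of-keys layout behind `DoKFloatMatrix`, assuming (as the `Primitives.getHigh`/`getLow` decoding in `eachNonZeroCell` suggests) that `Primitives.toLong(row, col)` puts the row in the high 32 bits and the column in the low 32 bits. `DokSketch` and its methods are hypothetical, and a `java.util.HashMap<Long, Float>` stands in for `Long2FloatOpenHashTable`, so this illustrates the idea rather than Hivemall's implementation.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical dictionary-of-keys sketch; not Hivemall's DoKFloatMatrix. */
final class DokSketch {

    // Assumption: row in the high 32 bits, column in the low 32 bits.
    static long toKey(int row, int col) {
        return ((long) row << 32) | (col & 0xFFFFFFFFL);
    }

    static int rowOf(long key) {
        return (int) (key >>> 32); // high 32 bits
    }

    static int colOf(long key) {
        return (int) key; // low 32 bits
    }

    // HashMap stands in for the primitive Long2FloatOpenHashTable
    private final Map<Long, Float> elements = new HashMap<>();
    private int numRows;
    private int numColumns;

    void set(int row, int col, float value) {
        long key = toKey(row, col);
        if (value == 0.f && !elements.containsKey(key)) {
            return; // never materialize a zero for a cell that was not stored
        }
        elements.put(key, value);
        // dimensions grow on demand as cells are written
        numRows = Math.max(numRows, row + 1);
        numColumns = Math.max(numColumns, col + 1);
    }

    float get(int row, int col) {
        return elements.getOrDefault(toKey(row, col), 0.f); // absent cells read as 0
    }

    int numRows() {
        return numRows;
    }

    int numColumns() {
        return numColumns;
    }

    /** Sums one row by visiting only stored entries and decoding each packed key. */
    float rowSum(int row) {
        float sum = 0.f;
        for (Map.Entry<Long, Float> e : elements.entrySet()) {
            if (rowOf(e.getKey()) == row) {
                sum += e.getValue();
            }
        }
        return sum;
    }
}
```

`rowSum` visits only stored entries, which is the same trade-off `eachNonZeroCell` makes: its cost follows the number of stored entries rather than `numRows * numColumns`.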
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
index bcfd152..054d62a 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
@@ -26,12 +26,18 @@ import hivemall.math.matrix.builders.DoKMatrixBuilder;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
import hivemall.utils.collections.maps.Long2DoubleOpenHashTable;
+import hivemall.utils.collections.maps.Long2DoubleOpenHashTable.IMapIterator;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
+/**
+ * Dictionary-of-Keys (DoK) based sparse matrix.
+ *
+ * This is an efficient structure for constructing a sparse matrix incrementally.
+ */
@Experimental
public final class DoKMatrix extends AbstractMatrix {
@@ -163,8 +169,6 @@ public final class DoKMatrix extends AbstractMatrix {
@Override
public double get(@Nonnegative final int row, @Nonnegative final int col,
final double defaultValue) {
- checkIndex(row, col, numRows, numColumns);
-
long index = index(row, col);
return elements.get(index, defaultValue);
}
@@ -173,11 +177,11 @@ public final class DoKMatrix extends AbstractMatrix {
public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
checkIndex(row, col);
- if (value == 0.d) {
+ final long index = index(row, col);
+ if (value == 0.d && elements.containsKey(index) == false) {
return;
}
- long index = index(row, col);
if (elements.put(index, value, 0.d) == 0.d) {
nnz++;
this.numRows = Math.max(numRows, row + 1);
@@ -190,8 +194,12 @@ public final class DoKMatrix extends AbstractMatrix {
final double value) {
checkIndex(row, col);
- long index = index(row, col);
- double old = elements.put(index, value, 0.d);
+ final long index = index(row, col);
+ if (value == 0.d && elements.containsKey(index) == false) {
+ return 0.d;
+ }
+
+ final double old = elements.put(index, value, 0.d);
if (old == 0.d) {
nnz++;
this.numRows = Math.max(numRows, row + 1);
@@ -309,6 +317,20 @@ public final class DoKMatrix extends AbstractMatrix {
}
}
+ public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
+ if (nnz == 0) {
+ return;
+ }
+ final IMapIterator itor = elements.entries();
+ while (itor.next() != -1) {
+ long k = itor.getKey();
+ int row = Primitives.getHigh(k);
+ int col = Primitives.getLow(k);
+ double value = itor.getValue();
+ procedure.apply(row, col, value);
+ }
+ }
+
@Override
public RowMajorMatrix toRowMajorMatrix() {
throw new UnsupportedOperationException("Not yet supported");
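The `set` hunk above changes how zero writes are handled: previously every zero was dropped, so overwriting an existing cell with `0.d` silently kept the old value. The hypothetical `ZeroWriteSketch` below (a `HashMap<Long, Double>` stand-in, not Hivemall code) contrasts the old and new guards.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in contrasting the old and new zero-write guards. */
final class ZeroWriteSketch {
    private final Map<Long, Double> elements = new HashMap<>();
    private final boolean patched;

    ZeroWriteSketch(boolean patched) {
        this.patched = patched;
    }

    void set(long key, double value) {
        if (patched) {
            // new guard: skip a zero only if the cell was never stored
            if (value == 0.d && !elements.containsKey(key)) {
                return;
            }
        } else if (value == 0.d) {
            return; // old guard: every zero write is dropped
        }
        elements.put(key, value);
    }

    double get(long key) {
        return elements.getOrDefault(key, 0.d);
    }

    int stored() {
        return elements.size();
    }
}
```

As in the patched `DoKMatrix.set`, an overwritten cell keeps an explicit `0.d` entry, so the storage size (and `nnz` in the real class) is not decremented.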
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/vector/VectorProcedure.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/VectorProcedure.java b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
index 266c531..3f3c390 100644
--- a/core/src/main/java/hivemall/math/vector/VectorProcedure.java
+++ b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
@@ -24,6 +24,12 @@ public abstract class VectorProcedure {
public VectorProcedure() {}
+ public void apply(@Nonnegative int i, @Nonnegative int j, float value) {
+ apply(i, j, (double) value);
+ }
+
+ public void apply(@Nonnegative int i, @Nonnegative int j, double value) {}
+
public void apply(@Nonnegative int i, double value) {}
public void apply(@Nonnegative int i, int value) {}
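`VectorProcedure` now has `(i, j, value)` callbacks, and the `float` overload widens its argument and forwards to the `double` one, so a subclass that overrides only the `double` variant also receives cells from `DoKFloatMatrix.eachNonZeroCell`. A self-contained sketch of that delegation, using hypothetical `CellProcedure` and `SquaredSum` classes rather than Hivemall's own:

```java
/** Hypothetical stand-in for the (i, j, value) callbacks added to VectorProcedure. */
abstract class CellProcedure {

    // float cells are widened and forwarded, so subclasses override one method
    public void apply(int i, int j, float value) {
        apply(i, j, (double) value);
    }

    // no-op by default, like the empty body in the diff
    public void apply(int i, int j, double value) {}
}

/** Accumulates the sum of squared cell values it is handed. */
final class SquaredSum extends CellProcedure {
    double total = 0.d;
    int visited = 0;

    @Override
    public void apply(int i, int j, double value) {
        total += value * value;
        visited++;
    }
}
```

Java picks the `float` overload for a `float` argument because it is the most specific applicable method; the forwarded call then dispatches dynamically to the subclass override.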
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
index 141b261..0f9b5fd 100644
--- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
@@ -512,9 +512,8 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements
// write training examples in buffer to a temporary file
if (inputBuf.position() > 0) {
writeBuffer(inputBuf, fileIO, lastWritePos);
- } else if (lastWritePos == 0) {
- return; // no training example
}
+
try {
fileIO.flush();
} catch (IOException e) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
index 66ec60d..ee549c5 100644
--- a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
@@ -148,7 +148,7 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl
this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1);
if (iterations < 1) {
throw new UDFArgumentException(
- "'-iterations' must be greater than or equals to 1: " + iterations);
+ "'-iterations' must be greater than or equal to 1: " + iterations);
}
conversionCheck = !cl.hasOption("disable_cvtest");
convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
@@ -239,7 +239,7 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl
}
int item = PrimitiveObjectInspectorUtils.getInt(args[1], itemOI);
if (item < 0) {
- throw new HiveException("Illegal item index: " + user);
+ throw new HiveException("Illegal item index: " + item);
}
double rating = PrimitiveObjectInspectorUtils.getDouble(args[2], ratingOI);
@@ -505,9 +505,8 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl
// write training examples in buffer to a temporary file
if (inputBuf.position() > 0) {
writeBuffer(inputBuf, fileIO, lastWritePos);
- } else if (lastWritePos == 0) {
- return; // no training example
}
+
try {
fileIO.flush();
} catch (IOException e) {