Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/20 12:02:36 UTC
[1/5] incubator-hivemall git commit: Close #66: [HIVEMALL-91] Implement Online LDA
Repository: incubator-hivemall
Updated Branches:
refs/heads/master bba252ac1 -> e4e1531e1
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
new file mode 100644
index 0000000..b4810a6
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Arrays;
+import java.util.StringTokenizer;
+import java.util.zip.GZIPInputStream;
+
+import hivemall.classifier.KernelExpansionPassiveAggressiveUDTFTest;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.annotation.Nonnull;
+
+public class OnlineLDAModelTest {
+ private static final boolean DEBUG = false;
+
+ @Test
+ public void test() {
+ int K = 2;
+ int it = 0;
+ float perplexityPrev;
+ float perplexity = Float.MAX_VALUE;
+
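+ // constructor args, as inferred from OnlineLDAModel's delegating constructor: (K, alpha, eta, numTotalDocs, tau0, kappa, delta)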
+ OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, 2, 80, 0.8, 1E-5d);
+
+ String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+ String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};
+
+ do {
+ perplexityPrev = perplexity;
+ perplexity = 0.f;
+
+ // online (i.e., one-by-one) updating
+ model.train(new String[][] {doc1});
+ perplexity += model.computePerplexity();
+
+ model.train(new String[][] {doc2});
+ perplexity += model.computePerplexity();
+
+ perplexity /= 2.f; // mean perplexity for the 2 docs
+
+ it++;
+ println("Iteration " + it + ": mean perplexity = " + perplexity);
+ } while(Math.abs(perplexityPrev - perplexity) >= 1E-6f);
+
+ SortedMap<Float, List<String>> topicWords;
+
+ println("Topic 0:");
+ println("========");
+ topicWords = model.getTopicWords(0);
+ for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+ List<String> words = e.getValue();
+ for (int i = 0; i < words.size(); i++) {
+ println(e.getKey() + " " + words.get(i));
+ }
+ }
+ println("========");
+
+ println("Topic 1:");
+ println("========");
+ topicWords = model.getTopicWords(1);
+ for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+ List<String> words = e.getValue();
+ for (int i = 0; i < words.size(); i++) {
+ println(e.getKey() + " " + words.get(i));
+ }
+ }
+ println("========");
+
+ int k1, k2;
+ float[] topicDistr = model.getTopicDistribution(doc1);
+ if (topicDistr[0] > topicDistr[1]) {
+ // topic 0 MUST represent doc#1
+ k1 = 0;
+ k2 = 1;
+ } else {
+ k1 = 1;
+ k2 = 0;
+ }
+ Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
+ model.getLambda("vegetables", k1) > model.getLambda("flu", k1));
+ Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
+ model.getLambda("avocados", k2) > model.getLambda("healthy", k2));
+ }
+
+ @Test
+ public void testPerplexity() {
+ int K = 2;
+ int it = 0;
+ float perplexityPrev;
+ float perplexity = Float.MAX_VALUE;
+
+ OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, 2, 80, 0.8, 1E-5d);
+
+ String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+ String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};
+
+ do {
+ perplexityPrev = perplexity;
+
+ model.train(new String[][] {doc1, doc2});
+ perplexity = model.computePerplexity();
+
+ it++;
+ } while(Math.abs(perplexityPrev - perplexity) >= 1E-6f);
+
+ println("Iterated " + it + " times, perplexity = " + perplexity);
+
+ // For the same data and hyperparameters,
+ // scikit-learn Python library (implemented based on Matthew D. Hoffman's onlineldavb code)
+ // returns perplexity=15 in a batch setting and perplexity=22 in an online setting.
+ // Hivemall needs to converge to a similar perplexity.
+ Assert.assertTrue("Perplexity SHOULD be in [12, 25]; "
+ + "converged perplexity is too small or large for some reason", 12.f <= perplexity && perplexity <= 25.f);
+ }
+
+ @Test
+ public void testNews20() throws IOException {
+ int K = 20;
+ int numTotalDocs = 2000;
+ int miniBatchSize = 2;
+
+ int cnt, it;
+
+ OnlineLDAModel model = new OnlineLDAModel(K, 1.f / K, 1.f / K, numTotalDocs, 80, 0.8, 1E-3d);
+
+ BufferedReader news20 = readFile("news20-multiclass.gz");
+
+ String[][] docs = new String[K][];
+
+ String line = news20.readLine();
+ List<String> doc = new ArrayList<String>();
+
+ cnt = 0;
+ while (line != null) {
+ StringTokenizer tokens = new StringTokenizer(line, " ");
+
+ int k = Integer.parseInt(tokens.nextToken()) - 1;
+
+ while (tokens.hasMoreTokens()) {
+ doc.add(tokens.nextToken());
+ }
+
+ // store first document in each of K classes
+ if (docs[k] == null) {
+ docs[k] = doc.toArray(new String[doc.size()]);
+ cnt++;
+ }
+
+ if (cnt == K) {
+ break;
+ }
+
+ doc.clear();
+ line = news20.readLine();
+ }
+ println("Stored " + cnt + " docs. Start training w/ mini-batch size: " + miniBatchSize);
+
+ float perplexityPrev;
+ float perplexity = Float.MAX_VALUE;
+
+ it = 0;
+ do {
+ perplexityPrev = perplexity;
+ perplexity = 0.f;
+
+ int head = 0;
+ cnt = 0;
+ while (head < K) {
+ int tail = head + miniBatchSize;
+ model.train(Arrays.copyOfRange(docs, head, tail));
+ perplexity += model.computePerplexity();
+ head = tail;
+ cnt++;
+ println("Processed mini-batch#" + cnt);
+ }
+
+ perplexity /= cnt;
+
+ it++;
+
+ println("Iteration " + it + ": mean perplexity = " + perplexity);
+ } while(Math.abs(perplexityPrev - perplexity) >= 1E-1f);
+
+ Set<Integer> topics = new HashSet<Integer>();
+ for (int k = 0; k < K; k++) {
+ topics.add(findMaxTopic(model.getTopicDistribution(docs[k])));
+ }
+
+ int n = topics.size();
+ Assert.assertTrue("At least 15 documents SHOULD be classified to different topics, "
+ + "but there are only " + n + " unique topics.", n >= 15);
+ }
+
+ private static void println(String msg) {
+ if (DEBUG) {
+ System.out.println(msg);
+ }
+ }
+
+ @Nonnull
+ private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+ // use data stored for KPA UDTF test
+ InputStream is = KernelExpansionPassiveAggressiveUDTFTest.class.getResourceAsStream(fileName);
+ if (fileName.endsWith(".gz")) {
+ is = new GZIPInputStream(is);
+ }
+ return new BufferedReader(new InputStreamReader(is));
+ }
+
+ private static int findMaxTopic(@Nonnull float[] topicDistr) {
+ int maxIdx = 0;
+ for (int i = 1; i < topicDistr.length; i++) {
+ if (topicDistr[maxIdx] < topicDistr[i]) {
+ maxIdx = i;
+ }
+ }
+ return maxIdx;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 4c6ed1b..78b1faa 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -150,7 +150,11 @@
* [Change-Point Detection using Singular Spectrum Transformation (SST)](anomaly/sst.md)
* [ChangeFinder: Detecting Outlier and Change-Point Simultaneously](anomaly/changefinder.md)
-## Part X - Hivemall on Spark
+## Part X - Clustering
+
+* [Latent Dirichlet Allocation](clustering/lda.md)
+
+## Part XI - Hivemall on Spark
* [Getting Started](spark/getting_started/README.md)
* [Installation](spark/getting_started/installation.md)
@@ -165,7 +169,7 @@
* [Top-k Join processing](spark/misc/topk_join.md)
* [Other utility functions](spark/misc/functions.md)
-## Part X - External References
+## Part XII - External References
* [Hivemall on Apache Spark](https://github.com/maropu/hivemall-spark)
* [Hivemall on Apache Pig](https://github.com/daijyc/hivemall/wiki/PigHome)
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/docs/gitbook/clustering/lda.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/clustering/lda.md b/docs/gitbook/clustering/lda.md
new file mode 100644
index 0000000..cc477da
--- /dev/null
+++ b/docs/gitbook/clustering/lda.md
@@ -0,0 +1,195 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+Topic modeling is a way to analyze massive documents by clustering them into some ***topics***. In particular, **Latent Dirichlet Allocation** (LDA) is one of the most popular topic modeling techniques; the papers introducing the method are as follows:
+
+- D. M. Blei, et al. [Latent Dirichlet Allocation](http://www.jmlr.org/papers/v3/blei03a.html). Journal of Machine Learning Research 3, pp. 993-1022, 2003.
+- M. D. Hoffman, et al. [Online Learning for Latent Dirichlet Allocation](https://papers.nips.cc/paper/3902-online-learning-for-latent-dirichlet-allocation). NIPS 2010.
+
+Hivemall enables you to analyze your data, such as (but not limited to) documents, based on LDA. This page gives usage instructions for the feature.
+
+<!-- toc -->
+
+> #### Note
+> This feature is supported in Hivemall v0.5-rc.1 or later.
+
+# Prepare document data
+
+Assume that we already have a table `docs` which contains many documents in string format:
+
+| docid | doc |
+|:---:|:---|
+| 1 | "Fruits and vegetables are healthy." |
+|2 | "I like apples, oranges, and avocados. I do not like the flu or colds." |
+| ... | ... |
+
+Hivemall has several functions which are particularly useful for text processing. More specifically, by using `tokenize()` and `is_stopword()`, you can immediately convert the documents to [bag-of-words](https://en.wikipedia.org/wiki/Bag-of-words_model)-like format:
+
+```sql
+with word_counts as (
+ select
+ docid,
+ feature(word, count(word)) as word_count
+ from docs t1 LATERAL VIEW explode(tokenize(doc, true)) t2 as word
+ where
+ not is_stopword(word)
+ group by
+ docid, word
+)
+select docid, collect_set(word_count) as feature
+from word_counts
+group by docid
+;
+```
+
+| docid | feature |
+|:---:|:---|
+|1 | ["fruits:1","healthy:1","vegetables:1"] |
+|2 | ["apples:1","avocados:1","colds:1","flu:1","like:2","oranges:1"] |
+
+> #### Note
+> It should be noted that, as long as your data can be represented in the feature format above, LDA can be applied to arbitrary data as a generic clustering technique; for example, a hypothetical feature vector like `["age50:1","male:1","engineer:1"]` is equally valid input.
+
+# Building Topic Models and Finding Topic Words
+
+Each feature vector is input to the `train_lda()` function:
+
+```sql
+with word_counts as (
+ select
+ docid,
+ feature(word, count(word)) as word_count
+ from docs t1 LATERAL VIEW explode(tokenize(doc, true)) t2 as word
+ where
+ not is_stopword(word)
+ group by
+ docid, word
+)
+select
+ train_lda(feature, "-topic 2 -iter 20") as (label, word, lambda)
+from (
+ select docid, collect_set(word_count) as feature
+ from word_counts
+ group by docid
+ order by docid
+) t
+;
+```
+
+Here, the option `-topic 2` specifies the number of topics we assume in the set of documents.
+
+Notice that `order by docid` ensures that the LDA model is built precisely on a single node. In case you would like to run `train_lda` in parallel, the following query returns a similar (but possibly slightly approximated) result:
+
+```sql
+with word_counts as (
+ -- same as above
+)
+select
+ label, word, avg(lambda) as lambda
+from (
+ select
+ train_lda(feature, "-topic 2 -iter 20") as (label, word, lambda)
+ from (
+ select docid, collect_set(word_count) as feature
+ from word_counts
+ group by docid
+ ) t1
+) t2
+group by label, word
+order by lambda desc
+;
+```
+
+As a result, a new table `lda_model` is generated as shown below:
+
+| label | word | lambda |
+|:---:|:---:|:---:|
+|0 | fruits | 0.33372128|
+|0 | vegetables | 0.33272517|
+|0 | healthy | 0.33246377|
+|0 | flu | 2.3617347E-4|
+|0 | apples | 2.1898883E-4|
+|0 | oranges | 1.8161473E-4|
+|0 | like | 1.7666373E-4|
+|0 | avocados | 1.726186E-4|
+|0 | colds | 1.037139E-4|
+|1 | colds | 0.16622013|
+|1 | avocados | 0.16618845|
+|1 | oranges | 0.1661859|
+|1 | like | 0.16618414|
+|1 | apples | 0.16616651|
+|1 | flu | 0.16615893|
+|1 | healthy | 0.0012059759|
+|1 | vegetables | 0.0010818697|
+|1 | fruits | 6.080827E-4|
+
+In the table, `label` indicates a topic index, and `lambda` represents how likely each word is to characterize a topic. That is, the top-N words in terms of `lambda` are the ***topic words*** of a topic.
+
+We can clearly observe that topic `0` corresponds to document `1`, and topic `1` represents the words in document `2`. The top-N topic words can be listed per topic as shown below.
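+
+For instance, here is a minimal sketch that lists the top-3 topic words per topic, assuming the model shown above is stored in the `lda_model` table and that Hive's standard `rank()` windowing function is available:
+
+```sql
+select label, word, lambda
+from (
+  select
+    label, word, lambda,
+    rank() over (partition by label order by lambda desc) as r
+  from lda_model
+) t
+where r <= 3 -- top-3 topic words per topic
+;
+```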
+
+# Predicting Topic Assignments of Documents
+
+Once you have built a topic model as described above, the `lda_predict()` function allows you to predict the topic assignments of documents.
+
+For example, given the `docs` table, i.e., exactly the same set of documents as used for training, the probability that a document is assigned to each topic can be computed by:
+
+```sql
+with test as (
+ select
+ docid,
+ word,
+ count(word) as value
+ from docs t1 LATERAL VIEW explode(tokenize(doc, true)) t2 as word
+ where
+ not is_stopword(word)
+ group by
+ docid, word
+)
+select
+ t.docid,
+ lda_predict(t.word, t.value, m.label, m.lambda, "-topic 2") as probabilities
+from
+ test t
+ JOIN lda_model m ON (t.word = m.word)
+group by
+ t.docid
+;
+```
+
+| docid | probabilities (sorted in descending order) |
+|:---:|:---|
+| 1 | [{"label":0,"probability":0.875},{"label":1,"probability":0.125}]|
+| 2 | [{"label":1,"probability":0.9375},{"label":0,"probability":0.0625}]|
+
+Importantly, the `-topic` option must be set to the same value as used for training.
+
+Since the probabilities are sorted in descending order, the label of the most likely topic is easily obtained as follows (assuming the result above is stored in a table `topic`):
+
+```sql
+select docid, probabilities[0].label
+from topic
+;
+```
+
+| docid | label |
+|:---:|:---:|
+| 1 | 0 |
+| 2 | 1 |
+
+Of course, using a different set of documents for prediction is also possible; predicting the topic assignments of newly observed documents is the more realistic scenario, as sketched below.
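+
+As a minimal sketch, assuming a hypothetical table `new_docs` with the same schema as `docs`, the prediction query is identical except for the source table. Note that words unseen during training are dropped by the inner join:
+
+```sql
+with test as (
+  select
+    docid,
+    word,
+    count(word) as value
+  from new_docs t1 LATERAL VIEW explode(tokenize(doc, true)) t2 as word -- `new_docs` is hypothetical
+  where
+    not is_stopword(word)
+  group by
+    docid, word
+)
+select
+  t.docid,
+  lda_predict(t.word, t.value, m.label, m.lambda, "-topic 2") as probabilities
+from
+  test t
+  JOIN lda_model m ON (t.word = m.word)
+group by
+  t.docid
+;
+```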
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index c6dda03..1eb9c82 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -616,6 +616,16 @@ CREATE FUNCTION changefinder as 'hivemall.anomaly.ChangeFinderUDF' USING JAR '${
DROP FUNCTION IF EXISTS sst;
CREATE FUNCTION sst as 'hivemall.anomaly.SingularSpectrumTransformUDF' USING JAR '${hivemall_jar}';
+--------------------
+-- Topic Modeling --
+--------------------
+
+DROP FUNCTION IF EXISTS train_lda;
+CREATE FUNCTION train_lda as 'hivemall.topicmodel.LDAUDTF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS lda_predict;
+CREATE FUNCTION lda_predict as 'hivemall.topicmodel.LDAPredictUDAF' USING JAR '${hivemall_jar}';
+
----------------------------
-- Smile related features --
----------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 8ea16c1..b503546 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -612,6 +612,16 @@ create temporary function changefinder as 'hivemall.anomaly.ChangeFinderUDF';
drop temporary function if exists sst;
create temporary function sst as 'hivemall.anomaly.SingularSpectrumTransformUDF';
+--------------------
+-- Topic Modeling --
+--------------------
+
+drop temporary function if exists train_lda;
+create temporary function train_lda as 'hivemall.topicmodel.LDAUDTF';
+
+drop temporary function if exists lda_predict;
+create temporary function lda_predict as 'hivemall.topicmodel.LDAPredictUDAF';
+
----------------------------
-- Smile related features --
----------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 0172cc8..b5239cf 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -597,6 +597,16 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS sst")
sqlContext.sql("CREATE TEMPORARY FUNCTION sst AS 'hivemall.anomaly.SingularSpectrumTransformUDF'")
/**
+ * Topic Modeling
+ */
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_lda")
+sqlContext.sql("CREATE TEMPORARY FUNCTION train_lda AS 'hivemall.topicmodel.LDAUDTF'")
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS lda_predict")
+sqlContext.sql("CREATE TEMPORARY FUNCTION lda_predict AS 'hivemall.topicmodel.LDAPredictUDAF'")
+
+/**
* Smile related features
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index cff0913..28d17ff 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -158,6 +158,8 @@ create temporary function guess_attribute_types as 'hivemall.smile.tools.GuessAt
-- since Hivemall v0.5-rc.1
create temporary function changefinder as 'hivemall.anomaly.ChangeFinderUDF';
create temporary function sst as 'hivemall.anomaly.SingularSpectrumTransformUDF';
+create temporary function train_lda as 'hivemall.topicmodel.LDAUDTF';
+create temporary function lda_predict as 'hivemall.topicmodel.LDAPredictUDAF';
-- NLP features
create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';
[4/5] incubator-hivemall git commit: Removed unused imports and formatted
Posted by my...@apache.org.
Removed unused imports and formatted
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9669c9d4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9669c9d4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9669c9d4
Branch: refs/heads/master
Commit: 9669c9d4304a8c440aae7cda8201a6116c4edbbf
Parents: 1f98970
Author: myui <yu...@gmail.com>
Authored: Thu Apr 20 21:00:06 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Thu Apr 20 21:00:06 2017 +0900
----------------------------------------------------------------------
.../java/hivemall/evaluation/AUCUDAFTest.java | 86 +++++++-------------
1 file changed, 31 insertions(+), 55 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9669c9d4/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
index df26175..2582ab4 100644
--- a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
+++ b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
@@ -18,6 +18,8 @@
*/
package hivemall.evaluation;
+import java.util.ArrayList;
+
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
@@ -26,10 +28,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -46,10 +44,8 @@ public class AUCUDAFTest {
auc = new AUCUDAF();
inputOIs = new ObjectInspector[] {
- PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
- PrimitiveObjectInspector.PrimitiveCategory.DOUBLE),
- PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
- PrimitiveObjectInspector.PrimitiveCategory.INT)};
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.DOUBLE),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT)};
evaluator = auc.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
@@ -120,7 +116,7 @@ public class AUCUDAFTest {
Assert.assertEquals(0.8125, agg.get(), 1e-5);
}
- @Test(expected=HiveException.class)
+ @Test(expected = HiveException.class)
public void testAllTruePositive() throws Exception {
final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
final int[] labels = new int[] {1, 1, 1, 1, 1};
@@ -136,7 +132,7 @@ public class AUCUDAFTest {
agg.get();
}
- @Test(expected=HiveException.class)
+ @Test(expected = HiveException.class)
public void testAllFalsePositive() throws Exception {
final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
final int[] labels = new int[] {0, 0, 0, 0, 0};
@@ -236,7 +232,8 @@ public class AUCUDAFTest {
// merge bins
// merge in a different order; e.g., <bin0, bin1>, <bin1, bin0> should return same value
- final int[][] orders = new int[][] {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}, {2, 0, 1}};
+ final int[][] orders = new int[][] { {0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0},
+ {2, 0, 1}};
for (int i = 0; i < orders.length; i++) {
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
evaluator.reset(agg);
@@ -251,28 +248,17 @@ public class AUCUDAFTest {
@Test
public void test100() throws Exception {
- final double[] scores = new double[] {
- 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8,
- 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8,
- 0.8, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
- 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6,
- 0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
- 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
- 0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3,
- 0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
- 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1,
+ final double[] scores = new double[] {0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8,
+ 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.7, 0.7, 0.7, 0.7,
+ 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 0.6,
+ 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4,
+ 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3,
+ 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1,
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
- final int[] labels = new int[] {
- 1, 1, 1, 1, 0, 0, 0, 0, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
- 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- 1, 0, 0, 1, 1, 1, 1, 1, 1, 0,
- 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
- 0, 0, 0, 0, 0, 1, 1, 1, 1, 0,
- 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
- 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- 1, 1, 1, 0, 0, 0, 0, 0, 0, 0};
+ final int[] labels = new int[] {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1,
+ 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1,
+ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0};
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
evaluator.reset(agg);
@@ -287,28 +273,17 @@ public class AUCUDAFTest {
@Test
public void testMerge100() throws Exception {
- final double[] scores = new double[] {
- 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8,
- 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8,
- 0.8, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
- 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6,
- 0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
- 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
- 0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3,
- 0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
- 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1,
- 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
- final int[] labels = new int[] {
- 1, 1, 1, 1, 0, 0, 0, 0, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
- 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- 1, 0, 0, 1, 1, 1, 1, 1, 1, 0,
- 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
- 0, 0, 0, 0, 0, 1, 1, 1, 1, 0,
- 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
- 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- 1, 1, 1, 0, 0, 0, 0, 0, 0, 0};
+ final double[] scores = new double[] {0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8,
+ 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.7, 0.7, 0.7, 0.7,
+ 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 0.6,
+ 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4,
+ 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3,
+ 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1,
+ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+ final int[] labels = new int[] {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1,
+ 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1,
+ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0};
Object[] partials = new Object[3];
@@ -342,7 +317,8 @@ public class AUCUDAFTest {
// merge bins
// merge in a different order; e.g., <bin0, bin1>, <bin1, bin0> should return same value
- final int[][] orders = new int[][] {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}, {2, 0, 1}};
+ final int[][] orders = new int[][] { {0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0},
+ {2, 0, 1}};
for (int j = 0; j < orders.length; j++) {
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
evaluator.reset(agg);
[5/5] incubator-hivemall git commit: Refactored LDA implementation
Posted by my...@apache.org.
Refactored LDA implementation
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e4e1531e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e4e1531e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e4e1531e
Branch: refs/heads/master
Commit: e4e1531e16de51bb934910ec7269895ca51fab4f
Parents: 9669c9d
Author: myui <yu...@gmail.com>
Authored: Thu Apr 20 21:01:36 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Thu Apr 20 21:01:36 2017 +0900
----------------------------------------------------------------------
.../main/java/hivemall/topicmodel/LDAUDTF.java | 7 +-
.../hivemall/topicmodel/OnlineLDAModel.java | 161 +++++++++----------
.../java/hivemall/utils/math/MathUtils.java | 84 +++++++---
3 files changed, 148 insertions(+), 104 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e4e1531e/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 91ee7a2..9aa15e2 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -196,8 +196,8 @@ public class LDAUDTF extends UDTFWithOptions {
initModel();
}
- int length = wordCountsOI.getListLength(args[0]);
- String[] wordCounts = new String[length];
+ final int length = wordCountsOI.getListLength(args[0]);
+ final String[] wordCounts = new String[length];
int j = 0;
for (int i = 0; i < length; i++) {
Object o = wordCountsOI.getListElement(args[0], i);
@@ -208,6 +208,9 @@ public class LDAUDTF extends UDTFWithOptions {
wordCounts[j] = s;
j++;
}
+ if (j == 0) {// avoid empty documents
+ return;
+ }
count++;
if (isAutoD) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e4e1531e/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 890adac..8fef10c 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -90,12 +90,12 @@ public final class OnlineLDAModel {
// for mini-batch
@Nonnull
- private final List<Map<String, Float>> _miniBatchMap;
+ private final List<Map<String, Float>> _miniBatchDocs;
private int _miniBatchSize;
// for computing perplexity
private float _docRatio = 1.f;
- private long _wordCount = 0L;
+ private double _valueSum = 0.d;
public OnlineLDAModel(int K, float alpha, double delta) { // for E step only instantiation
this(K, alpha, 1 / 20.f, -1L, 1020, 0.7, delta);
@@ -125,15 +125,13 @@ public final class OnlineLDAModel {
// initialize the parameters
this._lambda = new HashMap<String, float[]>(100);
- this._miniBatchMap = new ArrayList<Map<String, Float>>();
+ this._miniBatchDocs = new ArrayList<Map<String, Float>>();
}
/**
- * In a truly online setting, total number of documents corresponds to the number of documents
- * that have ever seen. In that case, users need to manually set the current max number of documents
- * via this method.
- * Note that, since the same set of documents could be repeatedly passed to `train()`,
- * simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
+ * In a truly online setting, the total number of documents corresponds to the number of documents that have ever been seen. In that case, users need to
+ * manually set the current max number of documents via this method. Note that, since the same set of documents could be repeatedly passed to
+ * `train()`, simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
*/
public void setNumTotalDocs(@Nonnegative long D) {
this._D = D;
@@ -161,34 +159,35 @@ public final class OnlineLDAModel {
}
private void preprocessMiniBatch(@Nonnull final String[][] miniBatch) {
- initMiniBatchMap(miniBatch, _miniBatchMap);
+ initMiniBatch(miniBatch, _miniBatchDocs);
- this._miniBatchSize = _miniBatchMap.size();
+ this._miniBatchSize = _miniBatchDocs.size();
// accumulate the number of words for each documents
- this._wordCount = 0L;
+ double valueSum = 0.d;
for (int d = 0; d < _miniBatchSize; d++) {
- for (float n : _miniBatchMap.get(d).values()) {
- this._wordCount += n;
+ for (Float n : _miniBatchDocs.get(d).values()) {
+ valueSum += n.floatValue();
}
}
+ this._valueSum = valueSum;
this._docRatio = (float) ((double) _D / _miniBatchSize);
}
- private static void initMiniBatchMap(@Nonnull final String[][] miniBatch,
- @Nonnull final List<Map<String, Float>> map) {
- map.clear();
+ private static void initMiniBatch(@Nonnull final String[][] miniBatch,
+ @Nonnull final List<Map<String, Float>> docs) {
+ docs.clear();
final FeatureValue probe = new FeatureValue();
// parse document
for (final String[] e : miniBatch) {
- if (e == null) {
+ if (e == null || e.length == 0) {
continue;
}
- final Map<String, Float> docMap = new HashMap<String, Float>();
+ final Map<String, Float> doc = new HashMap<String, Float>();
// parse features
for (String fv : e) {
@@ -198,10 +197,10 @@ public final class OnlineLDAModel {
FeatureValue.parseFeatureAsString(fv, probe);
String label = probe.getFeatureAsString();
float value = probe.getValueAsFloat();
- docMap.put(label, value);
+ doc.put(label, Float.valueOf(value));
}
- map.add(docMap);
+ docs.add(doc);
}
}
@@ -218,7 +217,7 @@ public final class OnlineLDAModel {
final Map<String, float[]> phi_d = new HashMap<String, float[]>();
phi.add(phi_d);
- for (final String label : _miniBatchMap.get(d).keySet()) {
+ for (final String label : _miniBatchDocs.get(d).keySet()) {
phi_d.put(label, new float[_K]);
if (!_lambda.containsKey(label)) { // lambda for newly observed word
_lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd));
@@ -233,19 +232,19 @@ public final class OnlineLDAModel {
private void eStep() {
// since lambda is invariant in the expectation step,
// `digamma`s of lambda values for Elogbeta are pre-computed
- final float[] lambdaSum = new float[_K];
+ final double[] lambdaSum = new double[_K];
final Map<String, float[]> digamma_lambda = new HashMap<String, float[]>();
for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
String label = e.getKey();
float[] lambda_label = e.getValue();
// for digamma(lambdaSum)
- MathUtils.add(lambdaSum, lambda_label, _K);
+ MathUtils.add(lambda_label, lambdaSum, _K);
digamma_lambda.put(label, MathUtils.digamma(lambda_label));
}
- final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+ final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
// for each of mini-batch documents, update gamma until convergence
float[] gamma_d, gammaPrev_d;
Map<String, float[]> eLogBeta_d;
@@ -265,11 +264,11 @@ public final class OnlineLDAModel {
@Nonnull
private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d,
@Nonnull final Map<String, float[]> digamma_lambda,
- @Nonnull final float[] digamma_lambdaSum) {
- // Dirichlet expectation (2d) for lambda
- final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>();
- final Map<String, Float> doc = _miniBatchMap.get(d);
+ @Nonnull final double[] digamma_lambdaSum) {
+ final Map<String, Float> doc = _miniBatchDocs.get(d);
+ // Dirichlet expectation (2d) for lambda
+ final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(doc.size());
for (final String label : doc.keySet()) {
float[] eLogBeta_label = eLogBeta_d.get(label);
if (eLogBeta_label == null) {
@@ -278,7 +277,7 @@ public final class OnlineLDAModel {
}
final float[] digamma_lambda_label = digamma_lambda.get(label);
for (int k = 0; k < _K; k++) {
- eLogBeta_label[k] = digamma_lambda_label[k] - digamma_lambdaSum[k];
+ eLogBeta_label[k] = (float) (digamma_lambda_label[k] - digamma_lambdaSum[k]);
}
}
@@ -288,28 +287,27 @@ public final class OnlineLDAModel {
private void updatePhiPerDoc(@Nonnegative final int d,
@Nonnull final Map<String, float[]> eLogBeta_d) {
// Dirichlet expectation (2d) for gamma
- final float[] eLogTheta_d = new float[_K];
final float[] gamma_d = _gamma[d];
- final float digamma_gammaSum_d = (float) Gamma.digamma(MathUtils.sum(gamma_d));
+ final double digamma_gammaSum_d = Gamma.digamma(MathUtils.sum(gamma_d));
+ final double[] eLogTheta_d = new double[_K];
for (int k = 0; k < _K; k++) {
- eLogTheta_d[k] = (float) Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
+ eLogTheta_d[k] = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
}
// updating phi w/ normalization
final Map<String, float[]> phi_d = _phi.get(d);
- final Map<String, Float> doc = _miniBatchMap.get(d);
+ final Map<String, Float> doc = _miniBatchDocs.get(d);
for (String label : doc.keySet()) {
final float[] phi_label = phi_d.get(label);
final float[] eLogBeta_label = eLogBeta_d.get(label);
- float normalizer = 0.f;
+ double normalizer = 0.d;
for (int k = 0; k < _K; k++) {
float phiVal = (float) Math.exp(eLogBeta_label[k] + eLogTheta_d[k]) + 1E-20f;
phi_label[k] = phiVal;
normalizer += phiVal;
}
- // normalize
for (int k = 0; k < _K; k++) {
phi_label[k] /= normalizer;
}
@@ -317,7 +315,7 @@ public final class OnlineLDAModel {
}
private void updateGammaPerDoc(@Nonnegative final int d) {
- final Map<String, Float> doc = _miniBatchMap.get(d);
+ final Map<String, Float> doc = _miniBatchDocs.get(d);
final Map<String, float[]> phi_d = _phi.get(d);
final float[] gamma_d = _gamma[d];
@@ -326,7 +324,7 @@ public final class OnlineLDAModel {
}
for (Map.Entry<String, Float> e : doc.entrySet()) {
final float[] phi_label = phi_d.get(e.getKey());
- final float val = e.getValue();
+ final float val = e.getValue().floatValue();
for (int k = 0; k < _K; k++) {
gamma_d[k] += phi_label[k] * val;
}
@@ -347,7 +345,7 @@ public final class OnlineLDAModel {
final Map<String, float[]> lambdaTilde = new HashMap<String, float[]>();
for (int d = 0; d < _miniBatchSize; d++) {
final Map<String, float[]> phi_d = _phi.get(d);
- for (String label : _miniBatchMap.get(d).keySet()) {
+ for (String label : _miniBatchDocs.get(d).keySet()) {
float[] lambdaTilde_label = lambdaTilde.get(label);
if (lambdaTilde_label == null) {
lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
@@ -382,73 +380,67 @@ public final class OnlineLDAModel {
* Calculate approximate perplexity for the current mini-batch.
*/
public float computePerplexity() {
- float bound = computeApproxBound();
- float perWordBound = bound / (_docRatio * _wordCount);
- return (float) Math.exp(-1.f * perWordBound);
+ double bound = computeApproxBound();
+ double perWordBound = bound / (_docRatio * _valueSum);
+ return (float) Math.exp(-1.d * perWordBound);
}
/**
* Estimates the variational bound over all documents using only the documents passed as mini-batch.
*/
- private float computeApproxBound() {
+ private double computeApproxBound() {
// prepare
- final float[] gammaSum = new float[_miniBatchSize];
+ final double[] gammaSum = new double[_miniBatchSize];
for (int d = 0; d < _miniBatchSize; d++) {
gammaSum[d] = MathUtils.sum(_gamma[d]);
}
- final float[] digamma_gammaSum = MathUtils.digamma(gammaSum);
+ final double[] digamma_gammaSum = MathUtils.digamma(gammaSum);
- final float[] lambdaSum = new float[_K];
+ final double[] lambdaSum = new double[_K];
for (float[] lambda_label : _lambda.values()) {
- MathUtils.add(lambdaSum, lambda_label, _K);
+ MathUtils.add(lambda_label, lambdaSum, _K);
}
- final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+ final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
- final float logGamma_alpha = (float) Gamma.logGamma(_alpha);
- final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha);
+ final double logGamma_alpha = Gamma.logGamma(_alpha);
+ final double logGamma_alphaSum = Gamma.logGamma(_K * _alpha);
- float score = 0.f;
+ double score = 0.d;
for (int d = 0; d < _miniBatchSize; d++) {
- final float digamma_gammaSum_d = digamma_gammaSum[d];
+ final double digamma_gammaSum_d = digamma_gammaSum[d];
+ final float[] gamma_d = _gamma[d];
// E[log p(doc | theta, beta)]
- for (Map.Entry<String, Float> e : _miniBatchMap.get(d).entrySet()) {
+ for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) {
final float[] lambda_label = _lambda.get(e.getKey());
// logsumexp( Elogthetad + Elogbetad )
- final float[] temp = new float[_K];
- float max = Float.MIN_VALUE;
+ final double[] temp = new double[_K];
+ double max = Double.MIN_VALUE;
for (int k = 0; k < _K; k++) {
- final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k])
- - digamma_gammaSum_d;
- final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k])
- - digamma_lambdaSum[k];
-
- temp[k] = eLogTheta_dk + eLogBeta_kw;
- if (temp[k] > max) {
- max = temp[k];
+ double eLogTheta_dk = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
+ double eLogBeta_kw = Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k];
+ final double tempK = eLogTheta_dk + eLogBeta_kw;
+ if (tempK > max) {
+ max = tempK;
}
+ temp[k] = tempK;
}
- float logsumexp = 0.f;
- for (int k = 0; k < _K; k++) {
- logsumexp += (float) Math.exp(temp[k] - max);
- }
- logsumexp = max + (float) Math.log(logsumexp);
+ double logsumexp = MathUtils.logsumexp(temp, max);
// sum( word count * logsumexp(...) )
- score += e.getValue() * logsumexp;
+ score += e.getValue().floatValue() * logsumexp;
}
// E[log p(theta | alpha) - log q(theta | gamma)]
for (int k = 0; k < _K; k++) {
- final float gamma_dk = _gamma[d][k];
+ float gamma_dk = gamma_d[k];
// sum( (alpha - gammad) * Elogthetad )
- score += (_alpha - gamma_dk)
- * ((float) Gamma.digamma(gamma_dk) - digamma_gammaSum_d);
+ score += (_alpha - gamma_dk) * (Gamma.digamma(gamma_dk) - digamma_gammaSum_d);
// sum( gammaln(gammad) - gammaln(alpha) )
- score += (float) Gamma.logGamma(gamma_dk) - logGamma_alpha;
+ score += Gamma.logGamma(gamma_dk) - logGamma_alpha;
}
score += logGamma_alphaSum; // gammaln(sum(alpha))
score -= Gamma.logGamma(gammaSum[d]); // gammaln(sum(gammad))
@@ -458,25 +450,25 @@ public final class OnlineLDAModel {
// (i.e., online setting); likelihood should be always roughly on the same scale
score *= _docRatio;
- final float logGamma_eta = (float) Gamma.logGamma(_eta);
- final float logGamma_etaSum = (float) Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta
+ final double logGamma_eta = Gamma.logGamma(_eta);
+ final double logGamma_etaSum = Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta
// E[log p(beta | eta) - log q (beta | lambda)]
- for (float[] lambda_label : _lambda.values()) {
+ for (final float[] lambda_label : _lambda.values()) {
for (int k = 0; k < _K; k++) {
- final float lambda_k = lambda_label[k];
+ float lambda_label_k = lambda_label[k];
// sum( (eta - lambda) * Elogbeta )
- score += (_eta - lambda_k)
- * (float) (Gamma.digamma(lambda_k) - digamma_lambdaSum[k]);
+ score += (_eta - lambda_label_k)
+ * (Gamma.digamma(lambda_label_k) - digamma_lambdaSum[k]);
// sum( gammaln(lambda) - gammaln(eta) )
- score += (float) Gamma.logGamma(lambda_k) - logGamma_eta;
+ score += Gamma.logGamma(lambda_label_k) - logGamma_eta;
}
}
for (int k = 0; k < _K; k++) {
// sum( gammaln(etaSum) - gammaln( lambdaSum_k )
- score += logGamma_etaSum - (float) Gamma.logGamma(lambdaSum[k]);
+ score += logGamma_etaSum - Gamma.logGamma(lambdaSum[k]);
}
return score;
@@ -513,7 +505,7 @@ public final class OnlineLDAModel {
@Nonnull
public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k,
@Nonnegative int topN) {
- float lambdaSum = 0.f;
+ double lambdaSum = 0.d;
final SortedMap<Float, List<String>> sortedLambda = new TreeMap<Float, List<String>>(
Collections.reverseOrder());
@@ -535,7 +527,8 @@ public final class OnlineLDAModel {
topN = Math.min(topN, _lambda.keySet().size());
int tt = 0;
for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) {
- ret.put(e.getKey() / lambdaSum, e.getValue());
+ float key = (float) (e.getKey().floatValue() / lambdaSum);
+ ret.put(Float.valueOf(key), e.getValue());
if (++tt == topN) {
break;
@@ -556,9 +549,9 @@ public final class OnlineLDAModel {
// normalize topic distribution
final float[] topicDistr = new float[_K];
final float[] gamma0 = _gamma[0];
- final float gammaSum = MathUtils.sum(gamma0);
+ final double gammaSum = MathUtils.sum(gamma0);
for (int k = 0; k < _K; k++) {
- topicDistr[k] = gamma0[k] / gammaSum;
+ topicDistr[k] = (float) (gamma0[k] / gammaSum);
}
return topicDistr;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e4e1531e/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 7fdea55..71d0270 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -314,44 +314,92 @@ public final class MathUtils {
return perm;
}
- public static float sum(@Nullable final float[] a) {
- if (a == null) {
- return 0.f;
+ public static double sum(@Nullable final float[] arr) {
+ if (arr == null) {
+ return 0.d;
}
- float sum = 0.f;
- for (float v : a) {
+ double sum = 0.d;
+ for (float v : arr) {
sum += v;
}
return sum;
}
- public static float sum(@Nullable final float[] a, @Nonnegative final int size) {
- if (a == null) {
- return 0.f;
- }
-
- float sum = 0.f;
+ public static void add(@Nonnull final float[] src, @Nonnull final float[] dst, final int size) {
for (int i = 0; i < size; i++) {
- sum += a[i];
+ dst[i] += src[i];
}
- return sum;
}
- public static void add(@Nonnull final float[] dst, @Nonnull final float[] toAdd, final int size) {
+ public static void add(@Nonnull final float[] src, @Nonnull final double[] dst, final int size) {
for (int i = 0; i < size; i++) {
- dst[i] += toAdd[i];
+ dst[i] += src[i];
}
}
@Nonnull
- public static float[] digamma(@Nonnull final float[] a) {
- final int k = a.length;
+ public static float[] digamma(@Nonnull final float[] arr) {
+ final int k = arr.length;
final float[] ret = new float[k];
for (int i = 0; i < k; i++) {
- ret[i] = (float) Gamma.digamma(a[i]);
+ ret[i] = (float) Gamma.digamma(arr[i]);
}
return ret;
}
+ @Nonnull
+ public static double[] digamma(@Nonnull final double[] arr) {
+ final int k = arr.length;
+ final double[] ret = new double[k];
+ for (int i = 0; i < k; i++) {
+ ret[i] = Gamma.digamma(arr[i]);
+ }
+ return ret;
+ }
+
+ public static float logsumexp(@Nonnull final float[] arr) {
+ if (arr.length == 0) {
+ return 0.f;
+ }
+ float max = 0.f;
+ for (final float v : arr) {
+ if (v > max) {
+ max = v;
+ }
+ }
+ return logsumexp(arr, max);
+ }
+
+ public static float logsumexp(@Nonnull final float[] arr, final float max) {
+ double logsumexp = 0.d;
+ for (final float v : arr) {
+ logsumexp += Math.exp(v - max);
+ }
+ logsumexp = Math.log(logsumexp) + max;
+ return (float) logsumexp;
+ }
+
+ public static double logsumexp(@Nonnull final double[] arr) {
+ if (arr.length == 0) {
+ return 0.d;
+ }
+ double max = 0.d;
+ for (final double v : arr) {
+ if (v > max) {
+ max = v;
+ }
+ }
+ return logsumexp(arr, max);
+ }
+
+ public static double logsumexp(@Nonnull final double[] arr, final double max) {
+ double logsumexp = 0.d;
+ for (final double v : arr) {
+ logsumexp += Math.exp(v - max);
+ }
+ logsumexp = Math.log(logsumexp) + max;
+ return logsumexp;
+ }
+
}
[3/5] incubator-hivemall git commit: Refactored LDA implementation
Posted by my...@apache.org.
Refactored LDA implementation
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1f98970b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1f98970b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1f98970b
Branch: refs/heads/master
Commit: 1f98970bb5f010936bdee7a9610a156e20473696
Parents: 9b2ddcc
Author: myui <yu...@gmail.com>
Authored: Thu Apr 20 17:14:37 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Thu Apr 20 17:14:37 2017 +0900
----------------------------------------------------------------------
.../hivemall/topicmodel/OnlineLDAModel.java | 86 +++++++++++---------
.../java/hivemall/utils/lang/ArrayUtils.java | 4 +-
2 files changed, 51 insertions(+), 39 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1f98970b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 3e7ad10..890adac 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -39,6 +39,9 @@ import org.apache.commons.math3.special.Gamma;
public final class OnlineLDAModel {
+ // ---------------------------------
+ // HyperParameters
+
// number of topics
private final int _K;
@@ -52,25 +55,25 @@ public final class OnlineLDAModel {
// in the truly online setting, this can be an estimate of the maximum number of documents that could ever be seen
private long _D = -1L;
- // defined by (tau0 + updateCount)^(-kappa_)
- // controls how much old lambda is forgotten
- private double _rhot;
-
// positive value which downweights early iterations
@Nonnegative
private final double _tau0;
// exponential decay rate (i.e., learning rate) which must be in (0.5, 1] to guarantee convergence
+ @Nonnegative
private final double _kappa;
+ // check convergence in the expectation (E) step
+ private final double _delta;
+
+ // ---------------------------------
+
// how many times EM steps are launched; later EM steps do not drastically forget old lambda
private long _updateCount = 0L;
- // random number generator
- @Nonnull
- private final GammaDistribution _gd;
- private static final double SHAPE = 100.d;
- private static final double SCALE = 1.d / SHAPE;
+ // defined by (tau0 + updateCount)^(-kappa_)
+ // controls how much old lambda is forgotten
+ private double _rhot;
// parameters
@Nonnull
@@ -79,9 +82,13 @@ public final class OnlineLDAModel {
@Nonnull
private final Map<String, float[]> _lambda;
- // check convergence in the expectation (E) step
- private final double _delta;
+ // random number generator
+ @Nonnull
+ private final GammaDistribution _gd;
+ private static final double SHAPE = 100.d;
+ private static final double SCALE = 1.d / SHAPE;
+ // for mini-batch
@Nonnull
private final List<Map<String, Float>> _miniBatchMap;
private int _miniBatchSize;
@@ -134,7 +141,8 @@ public final class OnlineLDAModel {
public void train(@Nonnull final String[][] miniBatch) {
if (_D <= 0L) {
- throw new RuntimeException("Total number of documents MUST be set via `setNumTotalDocs()`");
+ throw new IllegalStateException(
+ "Total number of documents MUST be set via `setNumTotalDocs()`");
}
preprocessMiniBatch(miniBatch);
@@ -165,7 +173,7 @@ public final class OnlineLDAModel {
}
}
- this._docRatio = (float)((double) _D / _miniBatchSize);
+ this._docRatio = (float) ((double) _D / _miniBatchSize);
}
private static void initMiniBatchMap(@Nonnull final String[][] miniBatch,
@@ -197,26 +205,29 @@ public final class OnlineLDAModel {
}
}
- private void initParams(boolean gammaWithRandom) {
- _phi = new ArrayList<Map<String, float[]>>();
- _gamma = new float[_miniBatchSize][];
+ private void initParams(final boolean gammaWithRandom) {
+ final List<Map<String, float[]>> phi = new ArrayList<Map<String, float[]>>();
+ final float[][] gamma = new float[_miniBatchSize][];
for (int d = 0; d < _miniBatchSize; d++) {
if (gammaWithRandom) {
- _gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd);
+ gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd);
} else {
- _gamma[d] = ArrayUtils.newInstance(_K, 1.f);
+ gamma[d] = ArrayUtils.newFloatArray(_K, 1.f);
}
final Map<String, float[]> phi_d = new HashMap<String, float[]>();
- _phi.add(phi_d);
- for (String label : _miniBatchMap.get(d).keySet()) {
+ phi.add(phi_d);
+ for (final String label : _miniBatchMap.get(d).keySet()) {
phi_d.put(label, new float[_K]);
if (!_lambda.containsKey(label)) { // lambda for newly observed word
_lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd));
}
}
}
+
+ this._phi = phi;
+ this._gamma = gamma;
}
private void eStep() {
@@ -231,22 +242,19 @@ public final class OnlineLDAModel {
// for digamma(lambdaSum)
MathUtils.add(lambdaSum, lambda_label, _K);
- float[] digamma_lambda_label = new float[_K];
digamma_lambda.put(label, MathUtils.digamma(lambda_label));
}
- final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+ final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+ // for each of mini-batch documents, update gamma until convergence
float[] gamma_d, gammaPrev_d;
Map<String, float[]> eLogBeta_d;
-
- // for each of mini-batch documents, update gamma until convergence
for (int d = 0; d < _miniBatchSize; d++) {
gamma_d = _gamma[d];
eLogBeta_d = computeElogBetaPerDoc(d, digamma_lambda, digamma_lambdaSum);
do {
- // (deep) copy the last gamma values
- gammaPrev_d = gamma_d.clone();
+ gammaPrev_d = gamma_d.clone(); // deep copy the last gamma values
updatePhiPerDoc(d, eLogBeta_d);
updateGammaPerDoc(d);
@@ -256,12 +264,13 @@ public final class OnlineLDAModel {
@Nonnull
private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d,
- @Nonnull Map<String, float[]> digamma_lambda, @Nonnull float[] digamma_lambdaSum) {
+ @Nonnull final Map<String, float[]> digamma_lambda,
+ @Nonnull final float[] digamma_lambdaSum) {
// Dirichlet expectation (2d) for lambda
final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>();
final Map<String, Float> doc = _miniBatchMap.get(d);
- for (String label : doc.keySet()) {
+ for (final String label : doc.keySet()) {
float[] eLogBeta_label = eLogBeta_d.get(label);
if (eLogBeta_label == null) {
eLogBeta_label = new float[_K];
@@ -276,7 +285,8 @@ public final class OnlineLDAModel {
return eLogBeta_d;
}
- private void updatePhiPerDoc(@Nonnegative final int d, @Nonnull Map<String, float[]> eLogBeta_d) {
+ private void updatePhiPerDoc(@Nonnegative final int d,
+ @Nonnull final Map<String, float[]> eLogBeta_d) {
// Dirichlet expectation (2d) for gamma
final float[] eLogTheta_d = new float[_K];
final float[] gamma_d = _gamma[d];
@@ -288,7 +298,7 @@ public final class OnlineLDAModel {
// updating phi w/ normalization
final Map<String, float[]> phi_d = _phi.get(d);
final Map<String, Float> doc = _miniBatchMap.get(d);
- for (String label : doc.keySet()) {
+ for (String label : doc.keySet()) {
final float[] phi_label = phi_d.get(label);
final float[] eLogBeta_label = eLogBeta_d.get(label);
@@ -340,7 +350,7 @@ public final class OnlineLDAModel {
for (String label : _miniBatchMap.get(d).keySet()) {
float[] lambdaTilde_label = lambdaTilde.get(label);
if (lambdaTilde_label == null) {
- lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
+ lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
lambdaTilde.put(label, lambdaTilde_label);
}
@@ -358,7 +368,7 @@ public final class OnlineLDAModel {
float[] lambdaTilde_label = lambdaTilde.get(label);
if (lambdaTilde_label == null) {
- lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
+ lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
}
for (int k = 0; k < _K; k++) {
@@ -381,8 +391,6 @@ public final class OnlineLDAModel {
* Estimates the variational bound over all documents using only the documents passed as mini-batch.
*/
private float computeApproxBound() {
- float score = 0.f;
-
// prepare
final float[] gammaSum = new float[_miniBatchSize];
for (int d = 0; d < _miniBatchSize; d++) {
@@ -399,6 +407,7 @@ public final class OnlineLDAModel {
final float logGamma_alpha = (float) Gamma.logGamma(_alpha);
final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha);
+ float score = 0.f;
for (int d = 0; d < _miniBatchSize; d++) {
final float digamma_gammaSum_d = digamma_gammaSum[d];
@@ -410,8 +419,10 @@ public final class OnlineLDAModel {
final float[] temp = new float[_K];
float max = Float.MIN_VALUE;
for (int k = 0; k < _K; k++) {
- final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k]) - digamma_gammaSum_d;
- final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k];
+ final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k])
+ - digamma_gammaSum_d;
+ final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k])
+ - digamma_lambdaSum[k];
temp[k] = eLogTheta_dk + eLogBeta_kw;
if (temp[k] > max) {
@@ -484,7 +495,8 @@ public final class OnlineLDAModel {
return lambda_label[k];
}
- public void setLambda(@Nonnull final String label, @Nonnegative final int k, final float lambda_k) {
+ public void setLambda(@Nonnull final String label, @Nonnegative final int k,
+ final float lambda_k) {
float[] lambda_label = _lambda.get(label);
if (lambda_label == null) {
lambda_label = ArrayUtils.newRandomFloatArray(_K, _gd);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1f98970b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 711aac7..c20c363 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -719,12 +719,12 @@ public final class ArrayUtils {
}
@Nonnull
- public static float[] newInstance(@Nonnegative int size, float filledValue) {
+ public static float[] newFloatArray(@Nonnegative int size, float filledValue) {
final float[] a = new float[size];
Arrays.fill(a, filledValue);
return a;
}
-
+
@Nonnull
public static float[] newRandomFloatArray(@Nonnegative final int size,
@Nonnull final GammaDistribution gd) {
[2/5] incubator-hivemall git commit: Close #66: [HIVEMALL-91]
Implement Online LDA
Posted by my...@apache.org.
Close #66: [HIVEMALL-91] Implement Online LDA
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9b2ddcc7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9b2ddcc7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9b2ddcc7
Branch: refs/heads/master
Commit: 9b2ddcc76b0950124373a30c1dbc56acff664ebf
Parents: bba252a
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Thu Apr 20 16:33:20 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Thu Apr 20 16:33:20 2017 +0900
----------------------------------------------------------------------
.../main/java/hivemall/model/FeatureValue.java | 4 +
.../hivemall/topicmodel/LDAPredictUDAF.java | 476 ++++++++++++++++
.../main/java/hivemall/topicmodel/LDAUDTF.java | 567 +++++++++++++++++++
.../hivemall/topicmodel/OnlineLDAModel.java | 554 ++++++++++++++++++
.../java/hivemall/utils/lang/ArrayUtils.java | 20 +
.../java/hivemall/utils/math/MathUtils.java | 43 ++
.../hivemall/topicmodel/LDAPredictUDAFTest.java | 228 ++++++++
.../java/hivemall/topicmodel/LDAUDTFTest.java | 104 ++++
.../hivemall/topicmodel/OnlineLDAModelTest.java | 252 +++++++++
docs/gitbook/SUMMARY.md | 8 +-
docs/gitbook/clustering/lda.md | 195 +++++++
resources/ddl/define-all-as-permanent.hive | 10 +
resources/ddl/define-all.hive | 10 +
resources/ddl/define-all.spark | 10 +
resources/ddl/define-udfs.td.hql | 2 +
15 files changed, 2481 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/model/FeatureValue.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java
index 39fadaf..11aa8f0 100644
--- a/core/src/main/java/hivemall/model/FeatureValue.java
+++ b/core/src/main/java/hivemall/model/FeatureValue.java
@@ -54,6 +54,10 @@ public final class FeatureValue {
return ((Integer) feature).intValue();
}
+ public String getFeatureAsString() {
+ return feature.toString();
+ }
+
public double getValue() {
return value;
}
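
For reference, the new accessor pairs with the existing feature-parsing path (see initMiniBatchMap() in OnlineLDAModel below, which calls parseFeatureAsString() and getValueAsFloat()). A minimal sketch of that usage, assuming the Hivemall classes from this commit are on the classpath; the wrapper class here is illustrative only:

    import hivemall.model.FeatureValue;

    public class FeatureParseSketch {
        public static void main(String[] args) {
            final FeatureValue probe = new FeatureValue();
            // parse a "word:count" feature, as in the LDA mini-batch preprocessing
            FeatureValue.parseFeatureAsString("apples:1", probe);
            String label = probe.getFeatureAsString(); // "apples"
            float value = probe.getValueAsFloat();     // 1.0
            System.out.println(label + " -> " + value);
        }
    }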
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
new file mode 100644
index 0000000..811af2e
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.CommandLineUtils;
+import hivemall.utils.lang.Primitives;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+
+@Description(name = "lda_predict",
+ value = "_FUNC_(string word, float value, int label, float lambda[, const string options])"
+ + " - Returns a list which consists of <int label, float prob>")
+public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
+
+ @Override
+ public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
+ if (typeInfo.length != 4 && typeInfo.length != 5) {
+ throw new UDFArgumentLengthException(
+ "Expected argument length is 4 or 5 but given argument length was "
+ + typeInfo.length);
+ }
+
+ if (!HiveUtils.isStringTypeInfo(typeInfo[0])) {
+ throw new UDFArgumentTypeException(0,
+ "String type is expected for the first argument word: " + typeInfo[0].getTypeName());
+ }
+ if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) {
+ throw new UDFArgumentTypeException(1,
+ "Number type is expected for the second argument value: "
+ + typeInfo[1].getTypeName());
+ }
+ if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) {
+ throw new UDFArgumentTypeException(2,
+ "Integer type is expected for the third argument label: "
+ + typeInfo[2].getTypeName());
+ }
+ if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
+ throw new UDFArgumentTypeException(3,
+ "Number type is expected for the forth argument lambda: "
+ + typeInfo[3].getTypeName());
+ }
+
+ if (typeInfo.length == 5) {
+ if (!HiveUtils.isStringTypeInfo(typeInfo[4])) {
+ throw new UDFArgumentTypeException(4,
+ "String type is expected for the fifth argument lambda: "
+ + typeInfo[4].getTypeName());
+ }
+ }
+
+ return new Evaluator();
+ }
+
+ public static class Evaluator extends GenericUDAFEvaluator {
+
+ // input OI
+ private PrimitiveObjectInspector wordOI;
+ private PrimitiveObjectInspector valueOI;
+ private PrimitiveObjectInspector labelOI;
+ private PrimitiveObjectInspector lambdaOI;
+
+ // Hyperparameters
+ private int topic;
+ private float alpha;
+ private double delta;
+
+ // merge OI
+ private StructObjectInspector internalMergeOI;
+ private StructField wcListField;
+ private StructField lambdaMapField;
+ private StructField topicOptionField;
+ private StructField alphaOptionField;
+ private StructField deltaOptionField;
+ private PrimitiveObjectInspector wcListElemOI;
+ private StandardListObjectInspector wcListOI;
+ private StandardMapObjectInspector lambdaMapOI;
+ private PrimitiveObjectInspector lambdaMapKeyOI;
+ private StandardListObjectInspector lambdaMapValueOI;
+ private PrimitiveObjectInspector lambdaMapValueElemOI;
+
+ public Evaluator() {}
+
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("k", "topic", true, "The number of topics [required]");
+ opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
+ opts.addOption("delta", true,
+ "Check convergence in the expectation step [default: 1E-5]");
+ return opts;
+ }
+
+ @Nonnull
+ protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException {
+ String[] args = optionValue.split("\\s+");
+ Options opts = getOptions();
+ opts.addOption("help", false, "Show function help");
+ CommandLine cl = CommandLineUtils.parseOptions(args, opts);
+
+ if (cl.hasOption("help")) {
+ Description funcDesc = getClass().getAnnotation(Description.class);
+ final String cmdLineSyntax;
+ if (funcDesc == null) {
+ cmdLineSyntax = getClass().getSimpleName();
+ } else {
+ String funcName = funcDesc.name();
+ cmdLineSyntax = funcName == null ? getClass().getSimpleName()
+ : funcDesc.value().replace("_FUNC_", funcDesc.name());
+ }
+ StringWriter sw = new StringWriter();
+ sw.write('\n');
+ PrintWriter pw = new PrintWriter(sw);
+ HelpFormatter formatter = new HelpFormatter();
+ formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts,
+ HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true);
+ pw.flush();
+ String helpMsg = sw.toString();
+ throw new UDFArgumentException(helpMsg);
+ }
+
+ return cl;
+ }
+
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = null;
+
+ if (argOIs.length != 5) {
+ throw new UDFArgumentException("At least 1 option `-topic` MUST be specified");
+ }
+
+ String rawArgs = HiveUtils.getConstString(argOIs[4]);
+ cl = parseOptions(rawArgs);
+
+ this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 0);
+ if (topic < 1) {
+ throw new UDFArgumentException(
+ "A positive integer MUST be set to an option `-topic`: " + topic);
+ }
+
+ this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
+ this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d);
+
+ return cl;
+ }
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+ assert (parameters.length == 4 || parameters.length == 5);
+ super.init(mode, parameters);
+
+ // initialize input
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+ processOptions(parameters);
+ this.wordOI = HiveUtils.asStringOI(parameters[0]);
+ this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]);
+ this.labelOI = HiveUtils.asIntegerOI(parameters[2]);
+ this.lambdaOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
+ } else {// from partial aggregation
+ StructObjectInspector soi = (StructObjectInspector) parameters[0];
+ this.internalMergeOI = soi;
+ this.wcListField = soi.getStructFieldRef("wcList");
+ this.lambdaMapField = soi.getStructFieldRef("lambdaMap");
+ this.topicOptionField = soi.getStructFieldRef("topic");
+ this.alphaOptionField = soi.getStructFieldRef("alpha");
+ this.deltaOptionField = soi.getStructFieldRef("delta");
+ this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+ this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(wcListElemOI);
+ this.lambdaMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+ this.lambdaMapValueElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+ this.lambdaMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(lambdaMapValueElemOI);
+ this.lambdaMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ lambdaMapKeyOI, lambdaMapValueOI);
+ }
+
+ // initialize output
+ final ObjectInspector outputOI;
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+ outputOI = internalMergeOI();
+ } else {
+ final ArrayList<String> fieldNames = new ArrayList<String>();
+ final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ fieldNames.add("label");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("probability");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector(
+ fieldNames, fieldOIs));
+ }
+ return outputOI;
+ }
+
+ private static StructObjectInspector internalMergeOI() {
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+ fieldNames.add("wcList");
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+
+ fieldNames.add("lambdaMap");
+ fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+
+ fieldNames.add("topic");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+ fieldNames.add("alpha");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ fieldNames.add("delta");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @SuppressWarnings("deprecation")
+ @Override
+ public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+ AggregationBuffer myAggr = new OnlineLDAPredictAggregationBuffer();
+ reset(myAggr);
+ return myAggr;
+ }
+
+ @Override
+ public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
+ myAggr.reset();
+ myAggr.setOptions(topic, alpha, delta);
+ }
+
+ @Override
+ public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+ Object[] parameters) throws HiveException {
+ OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
+
+ if (parameters[0] == null || parameters[1] == null || parameters[2] == null
+ || parameters[3] == null) {
+ return;
+ }
+
+ String word = PrimitiveObjectInspectorUtils.getString(parameters[0], wordOI);
+ float value = HiveUtils.getFloat(parameters[1], valueOI);
+ int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], labelOI);
+ float lambda = HiveUtils.getFloat(parameters[3], lambdaOI);
+
+ myAggr.iterate(word, value, label, lambda);
+ }
+
+ @Override
+ public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
+ if (myAggr.wcList.size() == 0) {
+ return null;
+ }
+
+ Object[] partialResult = new Object[5];
+ partialResult[0] = myAggr.wcList;
+ partialResult[1] = myAggr.lambdaMap;
+ partialResult[2] = new IntWritable(myAggr.topic);
+ partialResult[3] = new FloatWritable(myAggr.alpha);
+ partialResult[4] = new DoubleWritable(myAggr.delta);
+
+ return partialResult;
+ }
+
+ @Override
+ public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
+ throws HiveException {
+ if (partial == null) {
+ return;
+ }
+
+ Object wcListObj = internalMergeOI.getStructFieldData(partial, wcListField);
+
+ List<?> wcListRaw = wcListOI.getList(HiveUtils.castLazyBinaryObject(wcListObj));
+
+ // fix list elements to Java String objects
+ int wcListSize = wcListRaw.size();
+ List<String> wcList = new ArrayList<String>();
+ for (int i = 0; i < wcListSize; i++) {
+ wcList.add(PrimitiveObjectInspectorUtils.getString(wcListRaw.get(i), wcListElemOI));
+ }
+
+ Object lambdaMapObj = internalMergeOI.getStructFieldData(partial, lambdaMapField);
+ Map<?, ?> lambdaMapRaw = lambdaMapOI.getMap(HiveUtils.castLazyBinaryObject(lambdaMapObj));
+
+ Map<String, List<Float>> lambdaMap = new HashMap<String, List<Float>>();
+ for (Map.Entry<?, ?> e : lambdaMapRaw.entrySet()) {
+ // fix map keys to Java String objects
+ String word = PrimitiveObjectInspectorUtils.getString(e.getKey(), lambdaMapKeyOI);
+
+ Object lambdaMapValueObj = e.getValue();
+ List<?> lambdaMapValueRaw = lambdaMapValueOI.getList(HiveUtils.castLazyBinaryObject(lambdaMapValueObj));
+
+ // fix map values to lists of Java Float objects
+ int lambdaMapValueSize = lambdaMapValueRaw.size();
+ List<Float> lambda_word = new ArrayList<Float>();
+ for (int i = 0; i < lambdaMapValueSize; i++) {
+ lambda_word.add(HiveUtils.getFloat(lambdaMapValueRaw.get(i),
+ lambdaMapValueElemOI));
+ }
+
+ lambdaMap.put(word, lambda_word);
+ }
+
+ // restore options from partial result
+ Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField);
+ this.topic = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
+
+ Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField);
+ this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);
+
+ Object deltaObj = internalMergeOI.getStructFieldData(partial, deltaOptionField);
+ this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj);
+
+ OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
+ myAggr.setOptions(topic, alpha, delta);
+ myAggr.merge(wcList, lambdaMap);
+ }
+
+ @Override
+ public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
+ float[] topicDistr = myAggr.get();
+
+ SortedMap<Float, Integer> sortedDistr = new TreeMap<Float, Integer>(
+ Collections.reverseOrder());
+ for (int i = 0; i < topicDistr.length; i++) {
+ sortedDistr.put(topicDistr[i], i);
+ }
+
+ List<Object[]> result = new ArrayList<Object[]>();
+ for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) {
+ Object[] struct = new Object[2];
+ struct[0] = new IntWritable(e.getValue()); // label
+ struct[1] = new FloatWritable(e.getKey()); // probability
+ result.add(struct);
+ }
+ return result;
+ }
+
+ }
+
+ public static class OnlineLDAPredictAggregationBuffer extends
+ GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+ private List<String> wcList;
+ private Map<String, List<Float>> lambdaMap;
+
+ private int topic;
+ private float alpha;
+ private double delta;
+
+ OnlineLDAPredictAggregationBuffer() {
+ super();
+ }
+
+ void setOptions(int topic, float alpha, double delta) {
+ this.topic = topic;
+ this.alpha = alpha;
+ this.delta = delta;
+ }
+
+ void reset() {
+ this.wcList = new ArrayList<String>();
+ this.lambdaMap = new HashMap<String, List<Float>>();
+ }
+
+ void iterate(String word, float value, int label, float lambda) {
+ wcList.add(word + ":" + value);
+
+ // for an unforeseen word, initialize its lambdas w/ -1s
+ if (!lambdaMap.containsKey(word)) {
+ List<Float> lambdaEmpty_word = new ArrayList<Float>(
+ Collections.nCopies(topic, -1.f));
+ lambdaMap.put(word, lambdaEmpty_word);
+ }
+
+ // set the given lambda value
+ List<Float> lambda_word = lambdaMap.get(word);
+ lambda_word.set(label, lambda);
+ lambdaMap.put(word, lambda_word);
+ }
+
+ void merge(List<String> o_wcList, Map<String, List<Float>> o_lambdaMap) {
+ wcList.addAll(o_wcList);
+
+ for (Map.Entry<String, List<Float>> e : o_lambdaMap.entrySet()) {
+ String o_word = e.getKey();
+ List<Float> o_lambda_word = e.getValue();
+
+ if (!lambdaMap.containsKey(o_word)) { // for an unforeseen word
+ lambdaMap.put(o_word, o_lambda_word);
+ } else { // for a partially observed word
+ List<Float> lambda_word = lambdaMap.get(o_word);
+ for (int k = 0; k < topic; k++) {
+ if (o_lambda_word.get(k) != -1.f) { // not default value
+ lambda_word.set(k, o_lambda_word.get(k)); // set the partial lambda value
+ }
+ }
+ lambdaMap.put(o_word, lambda_word);
+ }
+ }
+ }
+
+ float[] get() {
+ OnlineLDAModel model = new OnlineLDAModel(topic, alpha, delta);
+
+ for (String word : lambdaMap.keySet()) {
+ List<Float> lambda_word = lambdaMap.get(word);
+ for (int k = 0; k < topic; k++) {
+ model.setLambda(word, k, lambda_word.get(k));
+ }
+ }
+
+ String[] wcArray = wcList.toArray(new String[wcList.size()]);
+ return model.getTopicDistribution(wcArray);
+ }
+ }
+
+}
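
A note on the merge logic above: each partial aggregate records only the lambda values it has observed, initializing unseen topic slots with the sentinel -1, and merge() lets any concrete value overwrite a sentinel. A self-contained sketch of that idea with plain Java collections (class and method names are illustrative, not part of the patch):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class PartialLambdaMergeSketch {
        static final float UNSEEN = -1.f;

        // merge `other` into `acc`; sentinel entries never overwrite concrete values
        static void merge(Map<String, List<Float>> acc, Map<String, List<Float>> other, int topics) {
            for (Map.Entry<String, List<Float>> e : other.entrySet()) {
                final List<Float> accRow = acc.get(e.getKey());
                if (accRow == null) { // unforeseen word
                    acc.put(e.getKey(), new ArrayList<Float>(e.getValue()));
                    continue;
                }
                for (int k = 0; k < topics; k++) {
                    final float v = e.getValue().get(k);
                    if (v != UNSEEN) { // not the default value
                        accRow.set(k, v);
                    }
                }
            }
        }

        public static void main(String[] args) {
            Map<String, List<Float>> a = new HashMap<String, List<Float>>();
            a.put("apples", new ArrayList<Float>(Arrays.asList(0.9f, UNSEEN)));
            Map<String, List<Float>> b = new HashMap<String, List<Float>>();
            b.put("apples", new ArrayList<Float>(Arrays.asList(UNSEEN, 0.1f)));
            merge(a, b, 2);
            System.out.println(a); // {apples=[0.9, 0.1]}
        }
    }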
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
new file mode 100644
index 0000000..91ee7a2
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -0,0 +1,567 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+@Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])"
+ + " - Returns a relation consists of <int topic, string word, float score>")
+public class LDAUDTF extends UDTFWithOptions {
+ private static final Log logger = LogFactory.getLog(LDAUDTF.class);
+
+ // Options
+ protected int topic;
+ protected float alpha;
+ protected float eta;
+ protected long numDocs;
+ protected double tau0;
+ protected double kappa;
+ protected int iterations;
+ protected double delta;
+ protected double eps;
+ protected int miniBatchSize;
+
+ // if the `num_docs` option is not given, this flag will be true;
+ // in that case, the UDTF automatically sets `count` as the _D parameter of the online LDA model
+ protected boolean isAutoD;
+
+ // number of processed training samples
+ protected long count;
+
+ protected String[][] miniBatch;
+ protected int miniBatchCount;
+
+ protected transient OnlineLDAModel model;
+
+ protected ListObjectInspector wordCountsOI;
+
+ // for iterations
+ protected NioStatefullSegment fileIO;
+ protected ByteBuffer inputBuf;
+
+ public LDAUDTF() {
+ this.topic = 10;
+ this.alpha = 1.f / topic;
+ this.eta = 1.f / topic;
+ this.numDocs = -1L;
+ this.tau0 = 64.d;
+ this.kappa = 0.7;
+ this.iterations = 10;
+ this.delta = 1E-3d;
+ this.eps = 1E-1d;
+ this.miniBatchSize = 128; // if 1, truly online setting
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("k", "topic", true, "The number of topics [default: 10]");
+ opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
+ opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]");
+ opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]");
+ opts.addOption("tau", "tau0", true,
+ "The parameter which downweights early iterations [default: 64.0]");
+ opts.addOption("kappa", true, "Exponential decay rate (i.e., learning rate) [default: 0.7]");
+ opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
+ opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
+ opts.addOption("eps", "epsilon", true,
+ "Check convergence based on the difference of perplexity [default: 1E-1]");
+ opts.addOption("s", "mini_batch_size", true,
+ "Repeat model updating per mini-batch [default: 128]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = null;
+
+ if (argOIs.length >= 2) {
+ String rawArgs = HiveUtils.getConstString(argOIs[1]);
+ cl = parseOptions(rawArgs);
+ this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 10);
+ this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
+ this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topic);
+ this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L);
+ this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d);
+ if (tau0 <= 0.d) {
+ throw new UDFArgumentException("'-tau0' must be positive: " + tau0);
+ }
+ this.kappa = Primitives.parseDouble(cl.getOptionValue("kappa"), 0.7d);
+ if (kappa <= 0.5 || kappa > 1.d) {
+ throw new UDFArgumentException("'-kappa' must be in (0.5, 1.0]: " + kappa);
+ }
+ this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
+ if (iterations < 1) {
+ throw new UDFArgumentException(
+ "'-iterations' must be greater than or equals to 1: " + iterations);
+ }
+ this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-3d);
+ this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
+ this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
+ }
+
+ return cl;
+ }
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if (argOIs.length < 1) {
+ throw new UDFArgumentException(
+ "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
+ }
+
+ this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
+ HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
+
+ processOptions(argOIs);
+
+ this.model = null;
+ this.count = 0L;
+ this.isAutoD = (numDocs < 0L);
+ this.miniBatch = new String[miniBatchSize][];
+ this.miniBatchCount = 0;
+
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ fieldNames.add("topic");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("word");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ fieldNames.add("score");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ protected void initModel() {
+ this.model = new OnlineLDAModel(topic, alpha, eta, numDocs, tau0, kappa, delta);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ initModel();
+ }
+
+ int length = wordCountsOI.getListLength(args[0]);
+ String[] wordCounts = new String[length];
+ int j = 0;
+ for (int i = 0; i < length; i++) {
+ Object o = wordCountsOI.getListElement(args[0], i);
+ if (o == null) {
+ throw new HiveException("Given feature vector contains invalid elements");
+ }
+ String s = o.toString();
+ wordCounts[j] = s;
+ j++;
+ }
+
+ count++;
+ if (isAutoD) {
+ model.setNumTotalDocs(count);
+ }
+
+ recordTrainSampleToTempFile(wordCounts);
+
+ miniBatch[miniBatchCount] = wordCounts;
+ miniBatchCount++;
+
+ if (miniBatchCount == miniBatchSize) {
+ model.train(miniBatch);
+ Arrays.fill(miniBatch, null); // clear
+ miniBatchCount = 0;
+ }
+ }
+
+ protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
+ throws HiveException {
+ if (iterations == 1) {
+ return;
+ }
+
+ ByteBuffer buf = inputBuf;
+ NioStatefullSegment dst = fileIO;
+
+ if (buf == null) {
+ final File file;
+ try {
+ file = File.createTempFile("hivemall_lda", ".sgmt");
+ file.deleteOnExit();
+ if (!file.canWrite()) {
+ throw new UDFArgumentException("Cannot write a temporary file: "
+ + file.getAbsolutePath());
+ }
+ logger.info("Record training samples to a file: " + file.getAbsolutePath());
+ } catch (IOException ioe) {
+ throw new UDFArgumentException(ioe);
+ } catch (Throwable e) {
+ throw new UDFArgumentException(e);
+ }
+ this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
+ this.fileIO = dst = new NioStatefullSegment(file, false);
+ }
+
+ int wcLength = 0;
+ for (String wc : wordCounts) {
+ if (wc == null) {
+ continue;
+ }
+ wcLength += wc.getBytes().length;
+ }
+ // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
+ int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wordCounts.length) / 8 + wcLength;
+ int remain = buf.remaining();
+ if (remain < recordBytes) {
+ writeBuffer(buf, dst);
+ }
+
+ buf.putInt(recordBytes);
+ buf.putInt(wordCounts.length);
+ for (String wc : wordCounts) {
+ if (wc == null) {
+ continue;
+ }
+ final byte[] b = wc.getBytes();
+ buf.putInt(b.length);
+ buf.put(b);
+ }
+ }
+
+ private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
+ throws HiveException {
+ srcBuf.flip();
+ try {
+ dst.write(srcBuf);
+ } catch (IOException e) {
+ throw new HiveException("Exception causes while writing a buffer to file", e);
+ }
+ srcBuf.clear();
+ }
+
+ @Override
+ public void close() throws HiveException {
+ if (count == 0) {
+ this.model = null;
+ return;
+ }
+ if (miniBatchCount > 0) { // update for remaining samples
+ model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+ }
+ if (iterations > 1) {
+ runIterativeTraining(iterations);
+ }
+ forwardModel();
+ this.model = null;
+ }
+
+ protected final void runIterativeTraining(@Nonnegative final int iterations)
+ throws HiveException {
+ final ByteBuffer buf = this.inputBuf;
+ final NioStatefullSegment dst = this.fileIO;
+ assert (buf != null);
+ assert (dst != null);
+ final long numTrainingExamples = count;
+
+ final Reporter reporter = getReporter();
+ final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
+ "hivemall.lda.OnlineLDA$Counter", "iteration");
+
+ try {
+ if (dst.getPosition() == 0L) {// run iterations w/o temporary file
+ if (buf.position() == 0) {
+ return; // no training example
+ }
+ buf.flip();
+
+ int iter = 2;
+ float perplexityPrev = Float.MAX_VALUE;
+ float perplexity;
+ int numTrain;
+ for (; iter <= iterations; iter++) {
+ perplexity = 0.f;
+ numTrain = 0;
+
+ reportProgress(reporter);
+ setCounterValue(iterCounter, iter);
+
+ Arrays.fill(miniBatch, null); // clear
+ miniBatchCount = 0;
+
+ while (buf.remaining() > 0) {
+ int recordBytes = buf.getInt();
+ assert (recordBytes > 0) : recordBytes;
+ int wcLength = buf.getInt();
+ final String[] wordCounts = new String[wcLength];
+ for (int j = 0; j < wcLength; j++) {
+ int len = buf.getInt();
+ byte[] bytes = new byte[len];
+ buf.get(bytes);
+ wordCounts[j] = new String(bytes);
+ }
+
+ miniBatch[miniBatchCount] = wordCounts;
+ miniBatchCount++;
+
+ if (miniBatchCount == miniBatchSize) {
+ model.train(miniBatch);
+ perplexity += model.computePerplexity();
+ numTrain++;
+
+ Arrays.fill(miniBatch, null); // clear
+ miniBatchCount = 0;
+ }
+ }
+ buf.rewind();
+
+ // update for remaining samples
+ if (miniBatchCount > 0) { // update for remaining samples
+ model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+ perplexity += model.computePerplexity();
+ numTrain++;
+ }
+
+ logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
+ perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
+ if (Math.abs(perplexityPrev - perplexity) < eps) {
+ break;
+ }
+ perplexityPrev = perplexity;
+ }
+ logger.info("Performed "
+ + Math.min(iter, iterations)
+ + " iterations of "
+ + NumberUtils.formatNumber(numTrainingExamples)
+ + " training examples on memory (thus "
+ + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+ + " training updates in total) ");
+ } else {// read training examples in the temporary file and invoke train for each example
+
+ // write training examples in buffer to a temporary file
+ if (buf.remaining() > 0) {
+ writeBuffer(buf, dst);
+ }
+ try {
+ dst.flush();
+ } catch (IOException e) {
+ throw new HiveException("Failed to flush a file: "
+ + dst.getFile().getAbsolutePath(), e);
+ }
+ if (logger.isInfoEnabled()) {
+ File tmpFile = dst.getFile();
+ logger.info("Wrote " + numTrainingExamples
+ + " records to a temporary file for iterative training: "
+ + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
+ + ")");
+ }
+
+ // run iterations
+ int iter = 2;
+ float perplexityPrev = Float.MAX_VALUE;
+ float perplexity;
+ int numTrain;
+ for (; iter <= iterations; iter++) {
+ perplexity = 0.f;
+ numTrain = 0;
+
+ Arrays.fill(miniBatch, null); // clear
+ miniBatchCount = 0;
+
+ setCounterValue(iterCounter, iter);
+
+ buf.clear();
+ dst.resetPosition();
+ while (true) {
+ reportProgress(reporter);
+ // TODO prefetch
+ // writes training examples to a buffer in the temporary file
+ final int bytesRead;
+ try {
+ bytesRead = dst.read(buf);
+ } catch (IOException e) {
+ throw new HiveException("Failed to read a file: "
+ + dst.getFile().getAbsolutePath(), e);
+ }
+ if (bytesRead == 0) { // reached file EOF
+ break;
+ }
+ assert (bytesRead > 0) : bytesRead;
+
+ // reads training examples from a buffer
+ buf.flip();
+ int remain = buf.remaining();
+ if (remain < SizeOf.INT) {
+ throw new HiveException("Illegal file format was detected");
+ }
+ while (remain >= SizeOf.INT) {
+ int pos = buf.position();
+ int recordBytes = buf.getInt();
+ remain -= SizeOf.INT;
+ if (remain < recordBytes) {
+ buf.position(pos);
+ break;
+ }
+
+ int wcLength = buf.getInt();
+ final String[] wordCounts = new String[wcLength];
+ for (int j = 0; j < wcLength; j++) {
+ int len = buf.getInt();
+ byte[] bytes = new byte[len];
+ buf.get(bytes);
+ wordCounts[j] = new String(bytes);
+ }
+
+ miniBatch[miniBatchCount] = wordCounts;
+ miniBatchCount++;
+
+ if (miniBatchCount == miniBatchSize) {
+ model.train(miniBatch);
+ perplexity += model.computePerplexity();
+ numTrain++;
+
+ Arrays.fill(miniBatch, null); // clear
+ miniBatchCount = 0;
+ }
+
+ remain -= recordBytes;
+ }
+ buf.compact();
+ }
+
+ // update for remaining samples
+ if (miniBatchCount > 0) { // update for remaining samples
+ model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+ perplexity += model.computePerplexity();
+ numTrain++;
+ }
+
+ logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
+ perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
+ if (Math.abs(perplexityPrev - perplexity) < eps) {
+ break;
+ }
+ perplexityPrev = perplexity;
+ }
+ logger.info("Performed "
+ + Math.min(iter, iterations)
+ + " iterations of "
+ + NumberUtils.formatNumber(numTrainingExamples)
+ + " training examples on a secondary storage (thus "
+ + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+ + " training updates in total)");
+ }
+ } finally {
+ // delete the temporary file and release resources
+ try {
+ dst.close(true);
+ } catch (IOException e) {
+ throw new HiveException("Failed to close a file: "
+ + dst.getFile().getAbsolutePath(), e);
+ }
+ this.inputBuf = null;
+ this.fileIO = null;
+ }
+ }
+
+ protected void forwardModel() throws HiveException {
+ final IntWritable topicIdx = new IntWritable();
+ final Text word = new Text();
+ final FloatWritable score = new FloatWritable();
+
+ final Object[] forwardObjs = new Object[3];
+ forwardObjs[0] = topicIdx;
+ forwardObjs[1] = word;
+ forwardObjs[2] = score;
+
+ for (int k = 0; k < topic; k++) {
+ topicIdx.set(k);
+
+ final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
+ for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+ score.set(e.getKey());
+ List<String> words = e.getValue();
+ for (int i = 0; i < words.size(); i++) {
+ word.set(words.get(i));
+ forward(forwardObjs);
+ }
+ }
+ }
+
+ logger.info("Forwarded topic words each of " + topic + " topics");
+ }
+
+ /*
+ * For testing:
+ */
+
+ @VisibleForTesting
+ double getLambda(String label, int k) {
+ return model.getLambda(label, k);
+ }
+
+ @VisibleForTesting
+ SortedMap<Float, List<String>> getTopicWords(int k) {
+ return model.getTopicWords(k);
+ }
+
+ @VisibleForTesting
+ SortedMap<Float, List<String>> getTopicWords(int k, int topN) {
+ return model.getTopicWords(k, topN);
+ }
+
+ @VisibleForTesting
+ float[] getTopicDistribution(@Nonnull String[] doc) {
+ return model.getTopicDistribution(doc);
+ }
+}
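
The temp-file framing used by recordTrainSampleToTempFile() and re-read in runIterativeTraining() is: [recordBytes:int][numWords:int] followed by [len:int][bytes] per word. Below is a self-contained sketch of that framing with a plain ByteBuffer; here recordBytes counts the bytes after the length field itself, which simplifies the patch's bookkeeping, and the class name is illustrative:

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;
    import java.util.Arrays;

    public class RecordFramingSketch {

        static void writeRecord(ByteBuffer buf, String[] words) {
            int recordBytes = 4; // the numWords field
            final byte[][] encoded = new byte[words.length][];
            for (int i = 0; i < words.length; i++) {
                encoded[i] = words[i].getBytes(StandardCharsets.UTF_8);
                recordBytes += 4 + encoded[i].length; // per-word length prefix + payload
            }
            buf.putInt(recordBytes);
            buf.putInt(words.length);
            for (byte[] b : encoded) {
                buf.putInt(b.length);
                buf.put(b);
            }
        }

        static String[] readRecord(ByteBuffer buf) {
            buf.getInt(); // recordBytes; the caller has already checked the full record is buffered
            final String[] words = new String[buf.getInt()];
            for (int i = 0; i < words.length; i++) {
                final byte[] b = new byte[buf.getInt()];
                buf.get(b);
                words[i] = new String(b, StandardCharsets.UTF_8);
            }
            return words;
        }

        public static void main(String[] args) {
            ByteBuffer buf = ByteBuffer.allocate(1024);
            writeRecord(buf, new String[] {"apples:1", "oranges:2"});
            buf.flip();
            System.out.println(Arrays.toString(readRecord(buf)));
        }
    }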
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
new file mode 100644
index 0000000..3e7ad10
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -0,0 +1,554 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.annotations.VisibleForTesting;
+import hivemall.model.FeatureValue;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.math.MathUtils;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.math3.distribution.GammaDistribution;
+import org.apache.commons.math3.special.Gamma;
+
+public final class OnlineLDAModel {
+
+ // number of topics
+ private final int _K;
+
+ // prior on weight vectors "theta ~ Dir(alpha_)"
+ private final float _alpha;
+
+ // prior on topics "beta"
+ private final float _eta;
+
+ // total number of documents
+ // in the truly online setting, this can be an estimate of the maximum number of documents that could ever be seen
+ private long _D = -1L;
+
+ // defined by (tau0 + updateCount)^(-kappa_)
+ // controls how much old lambda is forgotten
+ private double _rhot;
+
+ // positive value which downweights early iterations
+ @Nonnegative
+ private final double _tau0;
+
+ // exponential decay rate (i.e., learning rate) which must be in (0.5, 1] to guarantee convergence
+ private final double _kappa;
+
+ // how many times EM steps are launched; later EM steps do not drastically forget old lambda
+ private long _updateCount = 0L;
+
+ // random number generator
+ @Nonnull
+ private final GammaDistribution _gd;
+ private static final double SHAPE = 100.d;
+ private static final double SCALE = 1.d / SHAPE;
+
+ // parameters
+ @Nonnull
+ private List<Map<String, float[]>> _phi;
+ private float[][] _gamma;
+ @Nonnull
+ private final Map<String, float[]> _lambda;
+
+ // check convergence in the expectation (E) step
+ private final double _delta;
+
+ @Nonnull
+ private final List<Map<String, Float>> _miniBatchMap;
+ private int _miniBatchSize;
+
+ // for computing perplexity
+ private float _docRatio = 1.f;
+ private long _wordCount = 0L;
+
+ public OnlineLDAModel(int K, float alpha, double delta) { // for E step only instantiation
+ this(K, alpha, 1 / 20.f, -1L, 1020, 0.7, delta);
+ }
+
+ public OnlineLDAModel(int K, float alpha, float eta, long D, double tau0, double kappa,
+ double delta) {
+ if (tau0 < 0.d) {
+ throw new IllegalArgumentException("tau0 MUST be positive: " + tau0);
+ }
+ if (kappa <= 0.5 || 1.d < kappa) {
+ throw new IllegalArgumentException("kappa MUST be in (0.5, 1.0]: " + kappa);
+ }
+
+ this._K = K;
+ this._alpha = alpha;
+ this._eta = eta;
+ this._D = D;
+ this._tau0 = tau0;
+ this._kappa = kappa;
+ this._delta = delta;
+
+ // initialize a random number generator
+ this._gd = new GammaDistribution(SHAPE, SCALE);
+ _gd.reseedRandomGenerator(1001);
+
+ // initialize the parameters
+ this._lambda = new HashMap<String, float[]>(100);
+
+ this._miniBatchMap = new ArrayList<Map<String, Float>>();
+ }
+
+ /**
+ * In a truly online setting, total number of documents corresponds to the number of documents
+ * that have ever seen. In that case, users need to manually set the current max number of documents
+ * via this method.
+ * Note that, since the same set of documents could be repeatedly passed to `train()`,
+ * simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
+ */
+ public void setNumTotalDocs(@Nonnegative long D) {
+ this._D = D;
+ }
+
+ public void train(@Nonnull final String[][] miniBatch) {
+ if (_D <= 0L) {
+ throw new RuntimeException("Total number of documents MUST be set via `setNumTotalDocs()`");
+ }
+
+ preprocessMiniBatch(miniBatch);
+
+ initParams(true);
+
+ // Expectation
+ eStep();
+
+ this._rhot = Math.pow(_tau0 + _updateCount, -_kappa);
+
+ // Maximization
+ mStep();
+
+ _updateCount++;
+ }
+
+ private void preprocessMiniBatch(@Nonnull final String[][] miniBatch) {
+ initMiniBatchMap(miniBatch, _miniBatchMap);
+
+ this._miniBatchSize = _miniBatchMap.size();
+
+ // accumulate the number of words over all mini-batch documents
+ this._wordCount = 0L;
+ for (int d = 0; d < _miniBatchSize; d++) {
+ for (float n : _miniBatchMap.get(d).values()) {
+ this._wordCount += n;
+ }
+ }
+
+ this._docRatio = (float)((double) _D / _miniBatchSize);
+ }
+
+ private static void initMiniBatchMap(@Nonnull final String[][] miniBatch,
+ @Nonnull final List<Map<String, Float>> map) {
+ map.clear();
+
+ final FeatureValue probe = new FeatureValue();
+
+ // parse document
+ for (final String[] e : miniBatch) {
+ if (e == null) {
+ continue;
+ }
+
+ final Map<String, Float> docMap = new HashMap<String, Float>();
+
+ // parse features
+ for (String fv : e) {
+ if (fv == null) {
+ continue;
+ }
+ FeatureValue.parseFeatureAsString(fv, probe);
+ String label = probe.getFeatureAsString();
+ float value = probe.getValueAsFloat();
+ docMap.put(label, value);
+ }
+
+ map.add(docMap);
+ }
+ }
+
+ private void initParams(boolean gammaWithRandom) {
+ _phi = new ArrayList<Map<String, float[]>>();
+ _gamma = new float[_miniBatchSize][];
+
+ for (int d = 0; d < _miniBatchSize; d++) {
+ if (gammaWithRandom) {
+ _gamma[d] = ArrayUtils.newRandomFloatArray(_K, _gd);
+ } else {
+ _gamma[d] = ArrayUtils.newInstance(_K, 1.f);
+ }
+
+ final Map<String, float[]> phi_d = new HashMap<String, float[]>();
+ _phi.add(phi_d);
+ for (String label : _miniBatchMap.get(d).keySet()) {
+ phi_d.put(label, new float[_K]);
+ if (!_lambda.containsKey(label)) { // lambda for newly observed word
+ _lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd));
+ }
+ }
+ }
+ }
+
+ private void eStep() {
+ // since lambda is invariant in the expectation step,
+ // `digamma`s of lambda values for Elogbeta are pre-computed
+ final float[] lambdaSum = new float[_K];
+ final Map<String, float[]> digamma_lambda = new HashMap<String, float[]>();
+ for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
+ String label = e.getKey();
+ float[] lambda_label = e.getValue();
+
+ // for digamma(lambdaSum)
+ MathUtils.add(lambdaSum, lambda_label, _K);
+
+ final float[] digamma_lambda_label = MathUtils.digamma(lambda_label);
+ digamma_lambda.put(label, digamma_lambda_label);
+ }
+ final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+
+ float[] gamma_d, gammaPrev_d;
+ Map<String, float[]> eLogBeta_d;
+
+ // for each of mini-batch documents, update gamma until convergence
+ for (int d = 0; d < _miniBatchSize; d++) {
+ gamma_d = _gamma[d];
+ eLogBeta_d = computeElogBetaPerDoc(d, digamma_lambda, digamma_lambdaSum);
+
+ do {
+ // (deep) copy the last gamma values
+ gammaPrev_d = gamma_d.clone();
+
+ updatePhiPerDoc(d, eLogBeta_d);
+ updateGammaPerDoc(d);
+ } while (!checkGammaDiff(gammaPrev_d, gamma_d));
+ }
+ }
+
+ @Nonnull
+ private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d,
+ @Nonnull Map<String, float[]> digamma_lambda, @Nonnull float[] digamma_lambdaSum) {
+ // Dirichlet expectation (2d) for lambda
+ final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>();
+ final Map<String, Float> doc = _miniBatchMap.get(d);
+
+ for (String label : doc.keySet()) {
+ float[] eLogBeta_label = eLogBeta_d.get(label);
+ if (eLogBeta_label == null) {
+ eLogBeta_label = new float[_K];
+ eLogBeta_d.put(label, eLogBeta_label);
+ }
+ final float[] digamma_lambda_label = digamma_lambda.get(label);
+ for (int k = 0; k < _K; k++) {
+ eLogBeta_label[k] = digamma_lambda_label[k] - digamma_lambdaSum[k];
+ }
+ }
+
+ return eLogBeta_d;
+ }
+
+ private void updatePhiPerDoc(@Nonnegative final int d, @Nonnull Map<String, float[]> eLogBeta_d) {
+ // Dirichlet expectation (2d) for gamma
+ final float[] eLogTheta_d = new float[_K];
+ final float[] gamma_d = _gamma[d];
+ final float digamma_gammaSum_d = (float) Gamma.digamma(MathUtils.sum(gamma_d));
+ for (int k = 0; k < _K; k++) {
+ eLogTheta_d[k] = (float) Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
+ }
+
+ // updating phi w/ normalization
+ final Map<String, float[]> phi_d = _phi.get(d);
+ final Map<String, Float> doc = _miniBatchMap.get(d);
+ for (String label : doc.keySet()) {
+ final float[] phi_label = phi_d.get(label);
+ final float[] eLogBeta_label = eLogBeta_d.get(label);
+
+ float normalizer = 0.f;
+ for (int k = 0; k < _K; k++) {
+ float phiVal = (float) Math.exp(eLogBeta_label[k] + eLogTheta_d[k]) + 1E-20f;
+ phi_label[k] = phiVal;
+ normalizer += phiVal;
+ }
+
+ // normalize
+ for (int k = 0; k < _K; k++) {
+ phi_label[k] /= normalizer;
+ }
+ }
+ }
+
+ private void updateGammaPerDoc(@Nonnegative final int d) {
+ final Map<String, Float> doc = _miniBatchMap.get(d);
+ final Map<String, float[]> phi_d = _phi.get(d);
+
+ final float[] gamma_d = _gamma[d];
+ for (int k = 0; k < _K; k++) {
+ gamma_d[k] = _alpha;
+ }
+ for (Map.Entry<String, Float> e : doc.entrySet()) {
+ final float[] phi_label = phi_d.get(e.getKey());
+ final float val = e.getValue();
+ for (int k = 0; k < _K; k++) {
+ gamma_d[k] += phi_label[k] * val;
+ }
+ }
+ }
+
+ private boolean checkGammaDiff(@Nonnull final float[] gammaPrev,
+ @Nonnull final float[] gammaNext) {
+ double diff = 0.d;
+ for (int k = 0; k < _K; k++) {
+ diff += Math.abs(gammaPrev[k] - gammaNext[k]);
+ }
+ return (diff / _K) < _delta;
+ }
+
+ private void mStep() {
+ // calculate lambdaTilde for vocabularies in the current mini-batch
+ final Map<String, float[]> lambdaTilde = new HashMap<String, float[]>();
+ for (int d = 0; d < _miniBatchSize; d++) {
+ final Map<String, float[]> phi_d = _phi.get(d);
+ for (String label : _miniBatchMap.get(d).keySet()) {
+ float[] lambdaTilde_label = lambdaTilde.get(label);
+ if (lambdaTilde_label == null) {
+ lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
+ lambdaTilde.put(label, lambdaTilde_label);
+ }
+
+ final float[] phi_label = phi_d.get(label);
+ for (int k = 0; k < _K; k++) {
+ lambdaTilde_label[k] += _docRatio * phi_label[k];
+ }
+ }
+ }
+
+ // update lambda for all vocabularies
+ for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
+ String label = e.getKey();
+ final float[] lambda_label = e.getValue();
+
+ float[] lambdaTilde_label = lambdaTilde.get(label);
+ if (lambdaTilde_label == null) {
+ lambdaTilde_label = ArrayUtils.newInstance(_K, _eta);
+ }
+
+ for (int k = 0; k < _K; k++) {
+ lambda_label[k] = (float) ((1.d - _rhot) * lambda_label[k] + _rhot
+ * lambdaTilde_label[k]);
+ }
+ }
+ }
+
+ /**
+ * Calculate approximate perplexity for the current mini-batch.
+ */
+ public float computePerplexity() {
+ float bound = computeApproxBound();
+ float perWordBound = bound / (_docRatio * _wordCount);
+ return (float) Math.exp(-1.f * perWordBound);
+ }
+
+ /**
+ * Estimates the variational bound over all documents using only the documents passed as mini-batch.
+ */
+ private float computeApproxBound() {
+ float score = 0.f;
+
+ // prepare
+ final float[] gammaSum = new float[_miniBatchSize];
+ for (int d = 0; d < _miniBatchSize; d++) {
+ gammaSum[d] = MathUtils.sum(_gamma[d]);
+ }
+ final float[] digamma_gammaSum = MathUtils.digamma(gammaSum);
+
+ final float[] lambdaSum = new float[_K];
+ for (float[] lambda_label : _lambda.values()) {
+ MathUtils.add(lambdaSum, lambda_label, _K);
+ }
+ final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+
+ final float logGamma_alpha = (float) Gamma.logGamma(_alpha);
+ final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha);
+
+ for (int d = 0; d < _miniBatchSize; d++) {
+ final float digamma_gammaSum_d = digamma_gammaSum[d];
+
+ // E[log p(doc | theta, beta)]
+ for (Map.Entry<String, Float> e : _miniBatchMap.get(d).entrySet()) {
+ final float[] lambda_label = _lambda.get(e.getKey());
+
+ // logsumexp( Elogthetad + Elogbetad )
+ final float[] temp = new float[_K];
+ float max = Float.NEGATIVE_INFINITY;
+ for (int k = 0; k < _K; k++) {
+ final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k]) - digamma_gammaSum_d;
+ final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k];
+
+ temp[k] = eLogTheta_dk + eLogBeta_kw;
+ if (temp[k] > max) {
+ max = temp[k];
+ }
+ }
+ float logsumexp = 0.f;
+ for (int k = 0; k < _K; k++) {
+ logsumexp += (float) Math.exp(temp[k] - max);
+ }
+ logsumexp = max + (float) Math.log(logsumexp);
+
+ // sum( word count * logsumexp(...) )
+ score += e.getValue() * logsumexp;
+ }
+
+ // E[log p(theta | alpha) - log q(theta | gamma)]
+ for (int k = 0; k < _K; k++) {
+ final float gamma_dk = _gamma[d][k];
+
+ // sum( (alpha - gammad) * Elogthetad )
+ score += (_alpha - gamma_dk)
+ * ((float) Gamma.digamma(gamma_dk) - digamma_gammaSum_d);
+
+ // sum( gammaln(gammad) - gammaln(alpha) )
+ score += (float) Gamma.logGamma(gamma_dk) - logGamma_alpha;
+ }
+ score += logGamma_alphaSum; // gammaln(sum(alpha))
+ score -= Gamma.logGamma(gammaSum[d]); // gammaln(sum(gammad))
+ }
+
+ // the mini-batch documents are only a subset of the whole corpus (i.e., online setting),
+ // so scale the score so that the likelihood is always roughly on the same scale
+ score *= _docRatio;
+
+ final float logGamma_eta = (float) Gamma.logGamma(_eta);
+ final float logGamma_etaSum = (float) Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta
+
+ // E[log p(beta | eta) - log q (beta | lambda)]
+ for (float[] lambda_label : _lambda.values()) {
+ for (int k = 0; k < _K; k++) {
+ final float lambda_k = lambda_label[k];
+
+ // sum( (eta - lambda) * Elogbeta )
+ score += (_eta - lambda_k)
+ * (float) (Gamma.digamma(lambda_k) - digamma_lambdaSum[k]);
+
+ // sum( gammaln(lambda) - gammaln(eta) )
+ score += (float) Gamma.logGamma(lambda_k) - logGamma_eta;
+ }
+ }
+ for (int k = 0; k < _K; k++) {
+ // sum( gammaln(etaSum) - gammaln(lambdaSum_k) )
+ score += logGamma_etaSum - (float) Gamma.logGamma(lambdaSum[k]);
+ }
+
+ return score;
+ }
+
+ @VisibleForTesting
+ double getLambda(@Nonnull final String label, @Nonnegative final int k) {
+ final float[] lambda_label = _lambda.get(label);
+ if (lambda_label == null) {
+ throw new IllegalArgumentException("Word `" + label + "` is not in the corpus.");
+ }
+ if (k >= lambda_label.length) {
+ throw new IllegalArgumentException("Topic index must be in [0, "
+ + _lambda.get(label).length + "]");
+ }
+ return lambda_label[k];
+ }
+
+ public void setLambda(@Nonnull final String label, @Nonnegative final int k, final float lambda_k) {
+ float[] lambda_label = _lambda.get(label);
+ if (lambda_label == null) {
+ lambda_label = ArrayUtils.newRandomFloatArray(_K, _gd);
+ _lambda.put(label, lambda_label);
+ }
+ lambda_label[k] = lambda_k;
+ }
+
+ @Nonnull
+ public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
+ return getTopicWords(k, _lambda.keySet().size());
+ }
+
+ @Nonnull
+ public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k,
+ @Nonnegative int topN) {
+ float lambdaSum = 0.f;
+ final SortedMap<Float, List<String>> sortedLambda = new TreeMap<Float, List<String>>(
+ Collections.reverseOrder());
+
+ for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
+ final float lambda_k = e.getValue()[k];
+ lambdaSum += lambda_k;
+
+ List<String> labels = sortedLambda.get(lambda_k);
+ if (labels == null) {
+ labels = new ArrayList<String>();
+ sortedLambda.put(lambda_k, labels);
+ }
+ labels.add(e.getKey());
+ }
+
+ final SortedMap<Float, List<String>> ret = new TreeMap<Float, List<String>>(
+ Collections.reverseOrder());
+
+ topN = Math.min(topN, _lambda.keySet().size());
+ int tt = 0;
+ for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) {
+ ret.put(e.getKey() / lambdaSum, e.getValue());
+
+ if (++tt == topN) {
+ break;
+ }
+ }
+
+ return ret;
+ }
+
+ @Nonnull
+ public float[] getTopicDistribution(@Nonnull final String[] doc) {
+ preprocessMiniBatch(new String[][] {doc});
+
+ initParams(false);
+
+ eStep();
+
+ // normalize topic distribution
+ final float[] topicDistr = new float[_K];
+ final float[] gamma0 = _gamma[0];
+ final float gammaSum = MathUtils.sum(gamma0);
+ for (int k = 0; k < _K; k++) {
+ topicDistr[k] = gamma0[k] / gammaSum;
+ }
+ return topicDistr;
+ }
+
+}
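As a usage note: a minimal sketch of querying a trained model for an unseen document,
assuming an `OnlineLDAModel` instance `model` trained on `word:count` features with
`K` = 2 topics:

    String[] unseen = new String[] {"fruits:1", "healthy:1"};
    float[] distr = model.getTopicDistribution(unseen); // normalized; sums to ~1.f
    int topTopic = (distr[0] > distr[1]) ? 0 : 1;       // index of the most likely topic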
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index e8e337d..711aac7 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -23,9 +23,12 @@ import java.util.Arrays;
import java.util.List;
import java.util.Random;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import org.apache.commons.math3.distribution.GammaDistribution;
+
public final class ArrayUtils {
/**
@@ -715,4 +718,21 @@ public final class ArrayUtils {
return cnt;
}
+ @Nonnull
+ public static float[] newInstance(@Nonnegative int size, float filledValue) {
+ final float[] a = new float[size];
+ Arrays.fill(a, filledValue);
+ return a;
+ }
+
+ @Nonnull
+ public static float[] newRandomFloatArray(@Nonnegative final int size,
+ @Nonnull final GammaDistribution gd) {
+ final float[] ret = new float[size];
+ for (int i = 0; i < size; i++) {
+ ret[i] = (float) gd.sample();
+ }
+ return ret;
+ }
+
}
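A usage illustration for `newRandomFloatArray`: online LDA commonly draws the initial
per-word lambda values from a Gamma distribution, as `setLambda` above does via `_gd`.
A minimal sketch, assuming the Gamma(shape=100, scale=0.01) initialization used in
Hoffman et al.'s reference code:

    import org.apache.commons.math3.distribution.GammaDistribution;

    GammaDistribution gd = new GammaDistribution(100.d, 1.d / 100.d);
    float[] lambda_w = ArrayUtils.newRandomFloatArray(K, gd); // K = number of topics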
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index b71d165..7fdea55 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -38,6 +38,9 @@ import java.util.Random;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.math3.special.Gamma;
public final class MathUtils {
@@ -311,4 +314,44 @@ public final class MathUtils {
return perm;
}
+ public static float sum(@Nullable final float[] a) {
+ if (a == null) {
+ return 0.f;
+ }
+
+ float sum = 0.f;
+ for (float v : a) {
+ sum += v;
+ }
+ return sum;
+ }
+
+ public static float sum(@Nullable final float[] a, @Nonnegative final int size) {
+ if (a == null) {
+ return 0.f;
+ }
+
+ float sum = 0.f;
+ for (int i = 0; i < size; i++) {
+ sum += a[i];
+ }
+ return sum;
+ }
+
+ public static void add(@Nonnull final float[] dst, @Nonnull final float[] toAdd, final int size) {
+ for (int i = 0; i < size; i++) {
+ dst[i] += toAdd[i];
+ }
+ }
+
+ @Nonnull
+ public static float[] digamma(@Nonnull final float[] a) {
+ final int k = a.length;
+ final float[] ret = new float[k];
+ for (int i = 0; i < k; i++) {
+ ret[i] = (float) Gamma.digamma(a[i]);
+ }
+ return ret;
+ }
+
}
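These helpers are the building blocks of the variational E-step: for one document,
E[log theta_dk] = digamma(gamma_dk) - digamma(sum_k gamma_dk). A minimal sketch of that
computation, with illustrative gamma values:

    import org.apache.commons.math3.special.Gamma;

    float[] gamma_d = new float[] {1.2f, 0.8f}; // variational Dirichlet parameters of one doc
    float digammaSum = (float) Gamma.digamma(MathUtils.sum(gamma_d));
    float[] elogTheta = MathUtils.digamma(gamma_d); // element-wise digamma
    for (int k = 0; k < elogTheta.length; k++) {
        elogTheta[k] -= digammaSum; // E[log theta_dk]
    }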
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
new file mode 100644
index 0000000..a23d917
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class LDAPredictUDAFTest {
+ LDAPredictUDAF udaf;
+ GenericUDAFEvaluator evaluator;
+ ObjectInspector[] inputOIs;
+ ObjectInspector[] partialOI;
+ LDAPredictUDAF.OnlineLDAPredictAggregationBuffer agg;
+
+ String[] words;
+ int[] labels;
+ float[] lambdas;
+
+ @Test(expected=UDFArgumentException.class)
+ public void testWithoutOption() throws Exception {
+ udaf = new LDAPredictUDAF();
+
+ inputOIs = new ObjectInspector[] {
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.STRING),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.INT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+
+ evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ }
+
+ @Test(expected=UDFArgumentException.class)
+ public void testWithoutTopicOption() throws Exception {
+ udaf = new LDAPredictUDAF();
+
+ inputOIs = new ObjectInspector[] {
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.STRING),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.INT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-alpha 0.1")};
+
+ evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ }
+
+ @Before
+ public void setUp() throws Exception {
+ udaf = new LDAPredictUDAF();
+
+ inputOIs = new ObjectInspector[] {
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.STRING),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.INT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2")};
+
+ evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+ fieldNames.add("wcList");
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+
+ fieldNames.add("lambdaMap");
+ fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+
+ fieldNames.add("topic");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+ fieldNames.add("alpha");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ fieldNames.add("delta");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+ partialOI = new ObjectInspector[4];
+ partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+
+ agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
+ words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges", "like", "avocados", "colds",
+ "colds", "avocados", "oranges", "like", "apples", "flu", "healthy", "vegetables", "fruits"};
+ labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+ lambdas = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 3.2804057E-4f, 3.0303953E-4f, 2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 1.352576E-4f,
+ 0.1660153f, 0.16596903f, 0.1659654f, 0.1659627f, 0.16593699f, 0.1659259f, 0.0017611005f, 0.0015791848f, 8.84464E-4f};
+ }
+
+ @Test
+ public void test() throws Exception {
+ final Map<String, Float> doc1 = new HashMap<String, Float>();
+ doc1.put("fruits", 1.f);
+ doc1.put("healthy", 1.f);
+ doc1.put("vegetables", 1.f);
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ for (int i = 0; i < words.length; i++) {
+ String word = words[i];
+ evaluator.iterate(agg, new Object[] {word, doc1.get(word), labels[i], lambdas[i]});
+ }
+ float[] doc1Distr = agg.get();
+
+ final Map<String, Float> doc2 = new HashMap<String, Float>();
+ doc2.put("apples", 1.f);
+ doc2.put("avocados", 1.f);
+ doc2.put("colds", 1.f);
+ doc2.put("flu", 1.f);
+ doc2.put("like", 2.f);
+ doc2.put("oranges", 1.f);
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ for (int i = 0; i < words.length; i++) {
+ String word = words[i];
+ evaluator.iterate(agg, new Object[] {word, doc2.get(word), labels[i], lambdas[i]});
+ }
+ float[] doc2Distr = agg.get();
+
+ Assert.assertTrue(doc1Distr[0] > doc2Distr[0]);
+ Assert.assertTrue(doc1Distr[1] < doc2Distr[1]);
+ }
+
+ @Test
+ public void testMerge() throws Exception {
+ final Map<String, Float> doc = new HashMap<String, Float>();
+ doc.put("apples", 1.f);
+ doc.put("avocados", 1.f);
+ doc.put("colds", 1.f);
+ doc.put("flu", 1.f);
+ doc.put("like", 2.f);
+ doc.put("oranges", 1.f);
+
+ Object[] partials = new Object[3];
+
+ // bin #1
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ for (int i = 0; i < 6; i++) {
+ evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]});
+ }
+ partials[0] = evaluator.terminatePartial(agg);
+
+ // bin #2
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ for (int i = 6; i < 12; i++) {
+ evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]});
+ }
+ partials[1] = evaluator.terminatePartial(agg);
+
+ // bin #3
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+ for (int i = 12; i < 18; i++) {
+ evaluator.iterate(agg, new Object[]{words[i], doc.get(words[i]), labels[i], lambdas[i]});
+ }
+
+ partials[2] = evaluator.terminatePartial(agg);
+
+ // merge the partials in different orders
+ final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}};
+ for (int i = 0; i < orders.length; i++) {
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
+ evaluator.reset(agg);
+
+ evaluator.merge(agg, partials[orders[i][0]]);
+ evaluator.merge(agg, partials[orders[i][1]]);
+ evaluator.merge(agg, partials[orders[i][2]]);
+
+ float[] distr = agg.get();
+ Assert.assertTrue(distr[0] < distr[1]);
+ }
+ }
+}
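The `testMerge` cases above exercise order-independence of the partial aggregates. A
minimal sketch of why such merges commute, assuming (the actual aggregation buffer is
not shown in this diff) that partials are combined by unioning word keys and summing
rows element-wise:

    // hypothetical buffer: a Map<String, float[]> of per-word, per-topic values
    static void merge(Map<String, float[]> dst, Map<String, float[]> partial) {
        for (Map.Entry<String, float[]> e : partial.entrySet()) {
            final float[] row = dst.get(e.getKey());
            if (row == null) {
                dst.put(e.getKey(), e.getValue().clone()); // first occurrence of the word
            } else {
                for (int k = 0; k < row.length; k++) {
                    row[k] += e.getValue()[k]; // element-wise sum: commutative and associative
                }
            }
        }
    }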
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9b2ddcc7/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
new file mode 100644
index 0000000..d1e3f81
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.Arrays;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class LDAUDTFTest {
+ private static final boolean DEBUG = false;
+
+ @Test
+ public void test() throws HiveException {
+ LDAUDTF udtf = new LDAUDTF();
+
+ ObjectInspector[] argOIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2 -num_docs 2 -s 1")};
+
+ udtf.initialize(argOIs);
+
+ String[] doc1 = new String[]{"fruits:1", "healthy:1", "vegetables:1"};
+ String[] doc2 = new String[]{"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};
+ for (int it = 0; it < 5; it++) {
+ udtf.process(new Object[]{ Arrays.asList(doc1) });
+ udtf.process(new Object[]{ Arrays.asList(doc2) });
+ }
+
+ SortedMap<Float, List<String>> topicWords;
+
+ println("Topic 0:");
+ println("========");
+ topicWords = udtf.getTopicWords(0);
+ for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+ List<String> words = e.getValue();
+ for (int i = 0; i < words.size(); i++) {
+ println(e.getKey() + " " + words.get(i));
+ }
+ }
+ println("========");
+
+ println("Topic 1:");
+ println("========");
+ topicWords = udtf.getTopicWords(1);
+ for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+ List<String> words = e.getValue();
+ for (int i = 0; i < words.size(); i++) {
+ println(e.getKey() + " " + words.get(i));
+ }
+ }
+ println("========");
+
+ int k1, k2;
+ float[] topicDistr = udtf.getTopicDistribution(doc1);
+ if (topicDistr[0] > topicDistr[1]) {
+ // topic 0 MUST represent doc#1
+ k1 = 0;
+ k2 = 1;
+ } else {
+ k1 = 1;
+ k2 = 0;
+ }
+
+ Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
+ udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1));
+ Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
+ udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2));
+ }
+
+ private static void println(String msg) {
+ if (DEBUG) {
+ System.out.println(msg);
+ }
+ }
+}