You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/27 13:45:43 UTC

[1/2] incubator-hivemall git commit: Close #71: [HIVEMALL-74] Implement pLSA

Repository: incubator-hivemall
Updated Branches:
  refs/heads/master bffd2c78d -> f2bf3a72b


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
new file mode 100644
index 0000000..456dd1d
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class PLSAPredictUDAFTest {
+    PLSAPredictUDAF udaf;
+    GenericUDAFEvaluator evaluator;
+    ObjectInspector[] inputOIs;
+    ObjectInspector[] partialOI;
+    PLSAPredictUDAF.PLSAPredictAggregationBuffer agg;
+
+    String[] words;
+    int[] labels;
+    float[] probs;
+
+    @Test(expected = UDFArgumentException.class)
+    public void testWithoutOption() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+    }
+
+    @Test(expected = UDFArgumentException.class)
+    public void testWithoutTopicOption() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-alpha 0.1")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+    }
+
+    @Before
+    public void setUp() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        ArrayList<String> fieldNames = new ArrayList<String>();
+        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+        fieldNames.add("wcList");
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+
+        fieldNames.add("probMap");
+        fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+
+        fieldNames.add("topics");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+        fieldNames.add("alpha");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        fieldNames.add("delta");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+        partialOI = new ObjectInspector[4];
+        partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+
+        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
+        words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges",
+                "like", "avocados", "colds", "colds", "avocados", "oranges", "like", "apples",
+                "flu", "healthy", "vegetables", "fruits"};
+        labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+        probs = new float[] {0.3339331f, 0.3324783f, 0.33209667f, 3.2804057E-4f, 3.0303953E-4f,
+                2.4860457E-4f, 2.41481E-4f, 2.3554532E-4f, 1.352576E-4f, 0.1660153f, 0.16596903f,
+                0.1659654f, 0.1659627f, 0.16593699f, 0.1659259f, 0.0017611005f, 0.0015791848f,
+                8.84464E-4f};
+    }
+
+    @Test
+    public void test() throws Exception {
+        final Map<String, Float> doc1 = new HashMap<String, Float>();
+        doc1.put("fruits", 1.f);
+        doc1.put("healthy", 1.f);
+        doc1.put("vegetables", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc1.get(word), labels[i], probs[i]});
+        }
+        float[] doc1Distr = agg.get();
+
+        final Map<String, Float> doc2 = new HashMap<String, Float>();
+        doc2.put("apples", 1.f);
+        doc2.put("avocados", 1.f);
+        doc2.put("colds", 1.f);
+        doc2.put("flu", 1.f);
+        doc2.put("like", 2.f);
+        doc2.put("oranges", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc2.get(word), labels[i], probs[i]});
+        }
+        float[] doc2Distr = agg.get();
+
+        Assert.assertTrue(doc1Distr[0] > doc2Distr[0]);
+        Assert.assertTrue(doc1Distr[1] < doc2Distr[1]);
+    }
+
+    @Test
+    public void testMerge() throws Exception {
+        final Map<String, Float> doc = new HashMap<String, Float>();
+        doc.put("apples", 1.f);
+        doc.put("avocados", 1.f);
+        doc.put("colds", 1.f);
+        doc.put("flu", 1.f);
+        doc.put("like", 2.f);
+        doc.put("oranges", 1.f);
+
+        Object[] partials = new Object[3];
+
+        // bin #1
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < 6; i++) {
+            evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
+        }
+        partials[0] = evaluator.terminatePartial(agg);
+
+        // bin #2
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 6; i < 12; i++) {
+            evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
+        }
+        partials[1] = evaluator.terminatePartial(agg);
+
+        // bin #3
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 12; i < 18; i++) {
+            evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
+        }
+
+        partials[2] = evaluator.terminatePartial(agg);
+
+        // merge in a different order
+        final int[][] orders = new int[][] { {0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}};
+        for (int i = 0; i < orders.length; i++) {
+            evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
+            evaluator.reset(agg);
+
+            evaluator.merge(agg, partials[orders[i][0]]);
+            evaluator.merge(agg, partials[orders[i][1]]);
+            evaluator.merge(agg, partials[orders[i][2]]);
+
+            float[] distr = agg.get();
+            Assert.assertTrue(distr[0] < distr[1]);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
new file mode 100644
index 0000000..76795bc
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.Arrays;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Trains {@link PLSAUDTF} on two small documents and checks that each learned topic favors the
+ * words of the document it represents.
+ */
+public class PLSAUDTFTest {
+    private static final boolean DEBUG = false;
+
+    @Test
+    public void test() throws HiveException {
+        PLSAUDTF udtf = new PLSAUDTF();
+
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    "-topics 2 -alpha 0.1 -delta 0.00001")};
+
+        udtf.initialize(argOIs);
+
+        String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
+                "oranges:1"};
+        for (int it = 0; it < 10000; it++) {
+            udtf.process(new Object[] {Arrays.asList(doc1)});
+            udtf.process(new Object[] {Arrays.asList(doc2)});
+        }
+
+        printTopicWords(udtf, 0);
+        printTopicWords(udtf, 1);
+
+        // the topic doc#1 leans to MUST represent doc#1; the other one represents doc#2
+        float[] topicDistr = udtf.getTopicDistribution(doc1);
+        final int k1 = (topicDistr[0] > topicDistr[1]) ? 0 : 1;
+        final int k2 = 1 - k1;
+
+        Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+                + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
+            udtf.getProbability("vegetables", k1) > udtf.getProbability("flu", k1));
+        // topicDistr is doc1's distribution, so report doc1's (small) share of topic k2 as such
+        Assert.assertTrue("doc2 is expected in topic " + k2 + " (doc1 belongs to it with only "
+                + (topicDistr[k2] * 100) + "%), "
+                + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
+            udtf.getProbability("avocados", k2) > udtf.getProbability("healthy", k2));
+    }
+
+    /** Prints the words of topic {@code k}, keyed by probability, when {@link #DEBUG} is on. */
+    private static void printTopicWords(final PLSAUDTF udtf, final int k) throws HiveException {
+        println("Topic " + k + ":");
+        println("========");
+        final SortedMap<Float, List<String>> topicWords = udtf.getTopicWords(k);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            for (String word : e.getValue()) {
+                println(e.getKey() + " " + word);
+            }
+        }
+        println("========");
+    }
+
+    private static void println(String msg) {
+        if (DEBUG) {
+            System.out.println(msg);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 695119a..3d035d7 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -153,6 +153,7 @@
 ## Part X - Clustering
 
 * [Latent Dirichlet Allocation](clustering/lda.md)
+* [Probabilistic Latent Semantic Analysis](clustering/plsa.md)
 
 ## Part XI - GeoSpatial functions
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/docs/gitbook/clustering/plsa.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/clustering/plsa.md b/docs/gitbook/clustering/plsa.md
new file mode 100644
index 0000000..456dfe7
--- /dev/null
+++ b/docs/gitbook/clustering/plsa.md
@@ -0,0 +1,154 @@
+<!--
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+As described in [our user guide for Latent Dirichlet Allocation (LDA)](lda.md), Hivemall enables you to apply clustering for your data based on a topic modeling technique. While LDA is one of the most popular techniques, there is another approach named **Probabilistic Latent Semantic Analysis** (pLSA). In fact, pLSA is the predecessor of LDA, but it has an advantage in terms of running time.
+
+- T. Hofmann. [Probabilistic Latent Semantic Indexing](http://dl.acm.org/citation.cfm?id=312649). SIGIR 1999, pp. 50-57.
+- T. Hofmann. [Probabilistic Latent Semantic Analysis](http://www.iro.umontreal.ca/~nie/IFT6255/Hofmann-UAI99.pdf). UAI 1999, pp. 289-296.
+
+In order to efficiently handle large-scale data, our pLSA implementation is based on the following incremental variant of the original pLSA algorithm:
+
+- H. Wu, et al. [Incremental Probabilistic Latent Semantic Analysis for Automatic Question Recommendation](http://dl.acm.org/citation.cfm?id=1454026). RecSys 2008, pp. 99-106.
+
+<!-- toc -->
+
+> #### Note
+> This feature is supported from Hivemall v0.5-rc.1 or later.
+
+# Usage
+
+Basically, you can use our pLSA function in a similar way to LDA.
+
+In particular, we have two pLSA functions, `train_plsa()` and `plsa_predict()`. These functions can be used almost interchangeably with `train_lda()` and `lda_predict()`. Thus, reading [our user guide for LDA](lda.md) should be helpful before trying pLSA.
+
+In short, for the sample `docs` table we introduced in the LDA tutorial:
+
+| docid | doc  |
+|:---:|:---|
+| 1  | "Fruits and vegetables are healthy." |
+|2 | "I like apples, oranges, and avocados. I do not like the flu or colds." |
+| ... | ... |
+
+a pLSA model can be built as follows:
+
+```sql
+with word_counts as (
+  select
+    docid,
+    feature(word, count(word)) as f
+  from docs t1 lateral view explode(tokenize(doc, true)) t2 as word
+  where
+    not is_stopword(word)
+  group by
+    docid, word
+)
+select
+	train_plsa(feature, "-topics 2 -eps 0.00001 -iter 2048 -alpha 0.01") as (label, word, prob)
+from (
+  select docid, collect_set(f) as feature
+  from word_counts
+  group by docid
+) t
+;
+```
+
+|label |  word  |  prob|
+|:---:|:---:|:---:|
+|0|       like   | 0.28549945|
+|0|       colds  | 0.14294468|
+|0|       apples | 0.14291435|
+|0|       avocados|        0.1428958|
+|0|       flu    | 0.14287639|
+|0|       oranges| 0.1428691|
+|0|       healthy| 1.2605103E-7|
+|0|       fruits | 4.772253E-8|
+|0|       vegetables |     1.929087E-8|
+|1|       vegetables  |    0.32713377|
+|1|       fruits | 0.32713372|
+|1|       healthy| 0.3271335|
+|1|       like   | 0.006977764|
+|1|       oranges| 0.0025642214|
+|1|       flu    | 0.002507711|
+|1|       avocados|        0.0023572792|
+|1|       apples | 0.002213457|
+|1|       colds  | 0.001978546|
+
+
+
+Assuming the output of the above query is stored in a table named `plsa_model`, prediction can be done as follows:
+
+```sql
+with test as (
+  select
+    docid,
+    word,
+    count(word) as value
+  from docs t1 LATERAL VIEW explode(tokenize(doc, true)) t2 as word
+  where
+    not is_stopword(word)
+  group by
+    docid, word
+),
+topic as (
+  select
+    t.docid,
+    plsa_predict(t.word, t.value, m.label, m.prob, "-topics 2") as probabilities
+  from
+    test t
+    JOIN plsa_model m ON (t.word = m.word)
+  group by
+    t.docid
+)
+select docid, probabilities, probabilities[0].label, m.words -- topic each document should be assigned
+from topic t
+join (
+  select label, collect_set(feature(word, prob)) as words
+  from plsa_model
+  group by label
+) m on t.probabilities[0].label = m.label
+;
+```
+
+
+|docid  | probabilities |  label |  m.words |
+|:---:|:---|:---:|:---|
+|1      | [{"label":1,"probability":0.72298235},{"label":0,"probability":0.27701768}]   |  1 |      ["vegetables:0.32713377","fruits:0.32713372","healthy:0.3271335","like:0.006977764","oranges:0.0025642214","flu:0.002507711","avocados:0.0023572792","apples:0.002213457","colds:0.001978546"]|
+|2  |     [{"label":0,"probability":0.7052526},{"label":1,"probability":0.2947474}]     |  0     |  ["like:0.28549945","colds:0.14294468","apples:0.14291435","avocados:0.1428958","flu:0.14287639","oranges:0.1428691","healthy:1.2605103E-7","fruits:4.772253E-8","vegetables:1.929087E-8"]|
+
+# Difference with LDA
+
+The main advantage of using pLSA is its efficiency. Since its mathematical formulation and optimization logic are much simpler than those of LDA, pLSA generally requires much shorter running time.
+
+In terms of accuracy, LDA could be better than pLSA. For example, the word `like`, which appears twice in sample document #2 above, gets relatively high probabilities in both topic 0 and topic 1: in topic 1, which mainly represents document #1, it is ranked right after the words of document #1, even though document #1 does not contain it. By contrast, LDA results (i.e., *lambda* values) are more clearly separated, as shown in [the LDA page](lda.md). Thus, a pLSA model is more likely to be biased.
+
+For these reasons, we recommend trying LDA first. If you then encounter problems such as slow running time or undesirable clustering results, try the alternative pLSA approach.
+
+# Setting hyper-parameter `alpha`
+
+For training pLSA, we set a hyper-parameter `alpha` in the above example:
+
+```sql
+SELECT train_plsa(feature, "-topics 2 -eps 0.00001 -iter 2048 -alpha 0.01") 
+```
+
+This value controls **how much iterative model update is affected by the old results**.
+
+From an algorithmic point of view, training pLSA (and LDA) iteratively repeats certain operations and updates the target values (i.e., the probabilities output by `train_plsa()`). This iterative procedure gradually makes the probabilities more accurate, and `alpha` controls how much the probabilities change in each step.
+
+Normally, `alpha` is set to a small value between 0.0 and 0.5 (the default is 0.5).
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index 435466d..425d8ff 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -626,6 +626,12 @@ CREATE FUNCTION train_lda as 'hivemall.topicmodel.LDAUDTF' USING JAR '${hivemall
 DROP FUNCTION IF EXISTS lda_predict;
 CREATE FUNCTION lda_predict as 'hivemall.topicmodel.LDAPredictUDAF' USING JAR '${hivemall_jar}';
 
+DROP FUNCTION IF EXISTS train_plsa;
+CREATE FUNCTION train_plsa as 'hivemall.topicmodel.PLSAUDTF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS plsa_predict;
+CREATE FUNCTION plsa_predict as 'hivemall.topicmodel.PLSAPredictUDAF' USING JAR '${hivemall_jar}';
+
 ---------------------------
 -- Geo-Spatial functions --
 ---------------------------

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 8982ef4..d283812 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -622,6 +622,12 @@ create temporary function train_lda as 'hivemall.topicmodel.LDAUDTF';
 drop temporary function if exists lda_predict;
 create temporary function lda_predict as 'hivemall.topicmodel.LDAPredictUDAF';
 
+drop temporary function if exists train_plsa;
+create temporary function train_plsa as 'hivemall.topicmodel.PLSAUDTF';
+
+drop temporary function if exists plsa_predict;
+create temporary function plsa_predict as 'hivemall.topicmodel.PLSAPredictUDAF';
+
 ---------------------------
 -- Geo-Spatial functions --
 ---------------------------

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index a6473db..1b90c9b 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -606,6 +606,12 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION train_lda AS 'hivemall.topicmodel.LDAU
 sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS lda_predict")
 sqlContext.sql("CREATE TEMPORARY FUNCTION lda_predict AS 'hivemall.topicmodel.LDAPredictUDAF'")
 
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_plsa")
+sqlContext.sql("CREATE TEMPORARY FUNCTION train_plsa AS 'hivemall.topicmodel.PLSAUDTF'")
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS plsa_predict")
+sqlContext.sql("CREATE TEMPORARY FUNCTION plsa_predict AS 'hivemall.topicmodel.PLSAPredictUDAF'")
+
 /**
  * Geo Spatial Functions
  */

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index a2e5838..e549649 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -160,6 +160,8 @@ create temporary function changefinder as 'hivemall.anomaly.ChangeFinderUDF';
 create temporary function sst as 'hivemall.anomaly.SingularSpectrumTransformUDF';
 create temporary function train_lda as 'hivemall.topicmodel.LDAUDTF';
 create temporary function lda_predict as 'hivemall.topicmodel.LDAPredictUDAF';
+create temporary function train_plsa as 'hivemall.topicmodel.PLSAUDTF';
+create temporary function plsa_predict as 'hivemall.topicmodel.PLSAPredictUDAF';
 create temporary function tile as 'hivemall.geospatial.TileUDF';
 create temporary function map_url as 'hivemall.geospatial.MapURLUDF';
 


[2/2] incubator-hivemall git commit: Close #71: [HIVEMALL-74] Implement pLSA

Posted by my...@apache.org.
Close #71: [HIVEMALL-74] Implement pLSA


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f2bf3a72
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f2bf3a72
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f2bf3a72

Branch: refs/heads/master
Commit: f2bf3a72b2f8deb0835feed649369c885a23053c
Parents: bffd2c7
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Thu Apr 27 22:44:44 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Thu Apr 27 22:44:44 2017 +0900

----------------------------------------------------------------------
 .../topicmodel/IncrementalPLSAModel.java        | 316 +++++++++++
 .../hivemall/topicmodel/LDAPredictUDAF.java     |  24 +-
 .../main/java/hivemall/topicmodel/LDAUDTF.java  |  22 +-
 .../hivemall/topicmodel/PLSAPredictUDAF.java    | 480 +++++++++++++++++
 .../main/java/hivemall/topicmodel/PLSAUDTF.java | 535 +++++++++++++++++++
 .../java/hivemall/utils/lang/ArrayUtils.java    |  10 +
 .../java/hivemall/utils/math/MathUtils.java     |  12 +
 .../topicmodel/IncrementalPLSAModelTest.java    | 291 ++++++++++
 .../hivemall/topicmodel/LDAPredictUDAFTest.java |   4 +-
 .../java/hivemall/topicmodel/LDAUDTFTest.java   |   2 +-
 .../topicmodel/PLSAPredictUDAFTest.java         | 217 ++++++++
 .../java/hivemall/topicmodel/PLSAUDTFTest.java  | 106 ++++
 docs/gitbook/SUMMARY.md                         |   1 +
 docs/gitbook/clustering/plsa.md                 | 154 ++++++
 resources/ddl/define-all-as-permanent.hive      |   6 +
 resources/ddl/define-all.hive                   |   6 +
 resources/ddl/define-all.spark                  |   6 +
 resources/ddl/define-udfs.td.hql                |   2 +
 18 files changed, 2168 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
new file mode 100644
index 0000000..745e510
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
+import static hivemall.utils.math.MathUtils.l1normalize;
+import hivemall.model.FeatureValue;
+import hivemall.utils.math.MathUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Incremental probabilistic latent semantic analysis (pLSA) model.
+ * <p>
+ * {@link #train(String[][])} parses a mini-batch of documents and, for each document in turn,
+ * alternates an E-step (update P(z|d,w)) and an M-step (update P(z|d) and the shared table P(w|z))
+ * until P(z|d) stabilizes. P(z|d) and P(z|d,w) are re-initialized for every mini-batch, while
+ * P(w|z) is carried over between calls so that the model is refined incrementally.
+ * <p>
+ * Instances hold unsynchronized mutable state.
+ */
+public final class IncrementalPLSAModel {
+
+    // Upper bound on EM iterations per document. It guarantees termination even when the mean
+    // change of P(z|d) never drops below _delta (for example, when the difference becomes NaN).
+    private static final int MAX_ITERATIONS_PER_DOC = 100;
+
+    // ---------------------------------
+    // HyperParameters
+
+    // number of topics
+    private final int _K;
+
+    // control how much P(w|z) update is affected by the last value
+    private final float _alpha;
+
+    // convergence threshold on the mean absolute per-topic change of P(z|d) for a document
+    private final double _delta;
+
+    // ---------------------------------
+
+    // random number generator (fixed seed, so initialization is reproducible)
+    @Nonnull
+    private final Random _rnd;
+
+    // optimized in the E step
+    @Nonnull
+    private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair
+
+    // optimized in the M step
+    @Nonnull
+    private List<float[]> _p_dz; // P(z|d) probability of topics for documents
+    @Nonnull
+    private final Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
+
+    @Nonnull
+    private final List<Map<String, Float>> _miniBatchDocs;
+    private int _miniBatchSize;
+
+    /**
+     * @param K number of topics
+     * @param alpha weight of the previous P(w|z) value in the M-step update
+     * @param delta convergence threshold on the mean absolute change of P(z|d)
+     */
+    public IncrementalPLSAModel(int K, float alpha, double delta) {
+        this._K = K;
+        this._alpha = alpha;
+        this._delta = delta;
+
+        this._rnd = new Random(1001);
+
+        // start empty (not null) so the @Nonnull contract holds before the first train()
+        this._p_dwz = new ArrayList<Map<String, float[]>>();
+        this._p_dz = new ArrayList<float[]>();
+        this._p_zw = new HashMap<String, float[]>();
+
+        this._miniBatchDocs = new ArrayList<Map<String, Float>>();
+    }
+
+    /**
+     * Fits the model to a mini-batch of documents.
+     *
+     * @param miniBatch documents, each an array of feature strings; {@code null} or empty
+     *        documents and {@code null} features are skipped
+     */
+    public void train(@Nonnull final String[][] miniBatch) {
+        initMiniBatch(miniBatch, _miniBatchDocs);
+
+        this._miniBatchSize = _miniBatchDocs.size();
+
+        initParams();
+
+        final float[] pPrev_dz_d = new float[_K];
+
+        for (int d = 0; d < _miniBatchSize; d++) {
+            final float[] p_dz_d = _p_dz.get(d);
+            int iter = 0;
+            do {
+                // Snapshot P(z|d) by value: mStep() overwrites this array in place, so keeping only
+                // a reference would make the convergence check compare the array with itself.
+                System.arraycopy(p_dz_d, 0, pPrev_dz_d, 0, _K);
+
+                // Expectation
+                eStep(d);
+
+                // Maximization
+                mStep(d);
+
+                iter++;
+            } while (!isPdzConverged(pPrev_dz_d, p_dz_d) && iter < MAX_ITERATIONS_PER_DOC); // until get stable value of P(z|d)
+        }
+    }
+
+    /**
+     * Replaces the contents of {@code docs} with one word-to-value map per non-empty document.
+     * Features are parsed by {@link FeatureValue#parseFeatureAsString}; when a word repeats within
+     * a document, the last value wins.
+     */
+    private static void initMiniBatch(@Nonnull final String[][] miniBatch,
+            @Nonnull final List<Map<String, Float>> docs) {
+        docs.clear();
+
+        final FeatureValue probe = new FeatureValue();
+
+        // parse document
+        for (final String[] e : miniBatch) {
+            if (e == null || e.length == 0) {
+                continue;
+            }
+
+            final Map<String, Float> doc = new HashMap<String, Float>();
+
+            // parse features
+            for (String fv : e) {
+                if (fv == null) {
+                    continue;
+                }
+                FeatureValue.parseFeatureAsString(fv, probe);
+                String word = probe.getFeatureAsString();
+                float value = probe.getValueAsFloat();
+                doc.put(word, Float.valueOf(value));
+            }
+
+            docs.add(doc);
+        }
+    }
+
+    /**
+     * Randomly initializes P(z|d) and P(z|d,w) for the current mini-batch. Words not seen before
+     * are added to P(w|z) with random weights, and then every topic of P(w|z) is renormalized so
+     * that it sums to 1 over words.
+     */
+    private void initParams() {
+        final List<float[]> p_dz = new ArrayList<float[]>();
+        final List<Map<String, float[]>> p_dwz = new ArrayList<Map<String, float[]>>();
+
+        for (int d = 0; d < _miniBatchSize; d++) {
+            // init P(z|d)
+            float[] p_dz_d = l1normalize(newRandomFloatArray(_K, _rnd));
+            p_dz.add(p_dz_d);
+
+            final Map<String, float[]> p_dwz_d = new HashMap<String, float[]>();
+            p_dwz.add(p_dwz_d);
+
+            for (final String w : _miniBatchDocs.get(d).keySet()) {
+                // init P(z|d,w)
+                float[] p_dwz_dw = l1normalize(newRandomFloatArray(_K, _rnd));
+                p_dwz_d.put(w, p_dwz_dw);
+
+                // insert new labels to P(w|z)
+                if (!_p_zw.containsKey(w)) {
+                    _p_zw.put(w, newRandomFloatArray(_K, _rnd));
+                }
+            }
+        }
+
+        // ensure \sum_w P(w|z) = 1
+        final double[] sums = new double[_K];
+        for (float[] p_zw_w : _p_zw.values()) {
+            MathUtils.add(p_zw_w, sums, _K);
+        }
+        for (float[] p_zw_w : _p_zw.values()) {
+            for (int z = 0; z < _K; z++) {
+                p_zw_w[z] /= sums[z];
+            }
+        }
+
+        this._p_dz = p_dz;
+        this._p_dwz = p_dwz;
+    }
+
+    /**
+     * E-step for document {@code d}: P(z|d,w) is set proportional to P(z|d) * P(w|z) and
+     * normalized over topics for each word of the document.
+     */
+    private void eStep(@Nonnegative final int d) {
+        final Map<String, float[]> p_dwz_d = _p_dwz.get(d);
+        final float[] p_dz_d = _p_dz.get(d);
+
+        // update P(z|d,w) = P(z|d) * P(w|z)
+        for (final String w : _miniBatchDocs.get(d).keySet()) {
+            final float[] p_dwz_dw = p_dwz_d.get(w);
+            final float[] p_zw_w = _p_zw.get(w);
+            for (int z = 0; z < _K; z++) {
+                p_dwz_dw[z] = p_dz_d[z] * p_zw_w[z];
+            }
+            l1normalize(p_dwz_dw);
+        }
+    }
+
+    /**
+     * M-step for document {@code d}. P(z|d) is recomputed in place from n(d,w) * P(z|d,w). Then
+     * P(w|z) is updated for the words of {@code d}, and all words are renormalized per topic.
+     * Words absent from {@code d} keep their previous P(w|z) value before that renormalization.
+     */
+    private void mStep(@Nonnegative final int d) {
+        final Map<String, Float> doc = _miniBatchDocs.get(d);
+        final Map<String, float[]> p_dwz_d = _p_dwz.get(d);
+
+        // update P(z|d) = n(d,w) * P(z|d,w)
+        final float[] p_dz_d = _p_dz.get(d);
+        Arrays.fill(p_dz_d, 0.f); // zero-fill w/ keeping pointer to _p_dz.get(d)
+        for (Map.Entry<String, Float> e : doc.entrySet()) {
+            final float[] p_dwz_dw = p_dwz_d.get(e.getKey());
+            final float n = e.getValue().floatValue();
+            for (int z = 0; z < _K; z++) {
+                p_dz_d[z] += n * p_dwz_dw[z];
+            }
+        }
+        l1normalize(p_dz_d);
+
+        // update P(w|z) = n(d,w) * P(z|d,w) + alpha * P(w|z)^(n-1)
+        final double[] sums = new double[_K];
+        for (Map.Entry<String, float[]> e : _p_zw.entrySet()) {
+            String w = e.getKey();
+            final float[] p_zw_w = e.getValue();
+
+            Float w_value = doc.get(w);
+            if (w_value != null) { // all words in the document
+                final float n = w_value.floatValue();
+                final float[] p_dwz_dw = p_dwz_d.get(w);
+
+                for (int z = 0; z < _K; z++) {
+                    p_zw_w[z] = n * p_dwz_dw[z] + _alpha * p_zw_w[z];
+                }
+            }
+
+            MathUtils.add(p_zw_w, sums, _K);
+        }
+        // normalize to ensure \sum_w P(w|z) = 1
+        for (float[] p_zw_w : _p_zw.values()) {
+            for (int z = 0; z < _K; z++) {
+                p_zw_w[z] /= sums[z];
+            }
+        }
+    }
+
+    /**
+     * Returns true when the mean absolute per-topic change between the previous and current P(z|d)
+     * is below {@code _delta}. A NaN difference is treated as not converged.
+     */
+    private boolean isPdzConverged(@Nonnull final float[] pPrev_dz_d, @Nonnull final float[] p_dz_d) {
+        double diff = 0.d;
+        for (int z = 0; z < _K; z++) {
+            diff += Math.abs(pPrev_dz_d[z] - p_dz_d[z]);
+        }
+        return (diff / _K) < _delta;
+    }
+
+    /**
+     * Computes the perplexity of the current mini-batch:
+     * exp(-\sum n(d,w) log P(w|d) / \sum n(d,w)), where P(w|d) = \sum_z P(w|z) P(z|d).
+     * The result is NaN when the mini-batch contains no words.
+     */
+    public float computePerplexity() {
+        double numer = 0.d;
+        double denom = 0.d;
+
+        for (int d = 0; d < _miniBatchSize; d++) {
+            final float[] p_dz_d = _p_dz.get(d);
+            for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) {
+                String w = e.getKey();
+                float w_value = e.getValue().floatValue();
+
+                final float[] p_zw_w = _p_zw.get(w);
+                double p_dw = 0.d;
+                for (int z = 0; z < _K; z++) {
+                    p_dw += (double) p_zw_w[z] * p_dz_d[z];
+                }
+
+                numer += w_value * Math.log(p_dw);
+                denom += w_value;
+            }
+        }
+
+        return (float) Math.exp(-1.d * (numer / denom));
+    }
+
+    /**
+     * Returns every known word grouped by its P(w|z) for topic {@code z}. Keys are sorted in
+     * descending order of probability.
+     */
+    @Nonnull
+    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
+        final SortedMap<Float, List<String>> res = new TreeMap<Float, List<String>>(
+            Collections.reverseOrder());
+
+        for (Map.Entry<String, float[]> e : _p_zw.entrySet()) {
+            final String w = e.getKey();
+            final float prob = e.getValue()[z];
+
+            List<String> words = res.get(prob);
+            if (words == null) {
+                words = new ArrayList<String>();
+                res.put(prob, words);
+            }
+            words.add(w);
+        }
+
+        return res;
+    }
+
+    /**
+     * Trains on {@code doc} as a one-document mini-batch and returns its P(z|d). The training
+     * also updates P(w|z). The returned array is the model's internal array.
+     *
+     * @throws IndexOutOfBoundsException if {@code doc} is empty, because empty documents are
+     *         skipped
+     */
+    @Nonnull
+    public float[] getTopicDistribution(@Nonnull final String[] doc) {
+        train(new String[][] {doc});
+        return _p_dz.get(0);
+    }
+
+    /**
+     * Returns P(w|z).
+     *
+     * @throws NullPointerException if {@code w} is not in the model
+     */
+    public float getProbability(@Nonnull final String w, @Nonnegative final int z) {
+        return _p_zw.get(w)[z];
+    }
+
+    /**
+     * Sets P(w|z), adding {@code w} with random weights if it is unknown. Afterwards every topic
+     * is renormalized over all words. This costs O(|V| * K) per call and rescales values set
+     * earlier.
+     */
+    public void setProbability(@Nonnull final String w, @Nonnegative final int z, final float prob) {
+        float[] prob_label = _p_zw.get(w);
+        if (prob_label == null) {
+            prob_label = newRandomFloatArray(_K, _rnd);
+            _p_zw.put(w, prob_label);
+        }
+        prob_label[z] = prob;
+
+        // ensure \sum_w P(w|z) = 1
+        final double[] sums = new double[_K];
+        for (float[] p_zw_w : _p_zw.values()) {
+            MathUtils.add(p_zw_w, sums, _K);
+        }
+        for (float[] p_zw_w : _p_zw.values()) {
+            for (int zi = 0; zi < _K; zi++) {
+                p_zw_w[zi] /= sums[zi];
+            }
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index 811af2e..a4076b6 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -112,7 +112,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
         private PrimitiveObjectInspector lambdaOI;
 
         // Hyperparameters
-        private int topic;
+        private int topics;
         private float alpha;
         private double delta;
 
@@ -134,7 +134,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
 
         protected Options getOptions() {
             Options opts = new Options();
-            opts.addOption("k", "topic", true, "The number of topics [required]");
+            opts.addOption("k", "topics", true, "The number of topics [required]");
             opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
             opts.addOption("delta", true,
                 "Check convergence in the expectation step [default: 1E-5]");
@@ -176,19 +176,19 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
             CommandLine cl = null;
 
             if (argOIs.length != 5) {
-                throw new UDFArgumentException("At least 1 option `-topic` MUST be specified");
+                throw new UDFArgumentException("At least 1 option `-topics` MUST be specified");
             }
 
             String rawArgs = HiveUtils.getConstString(argOIs[4]);
             cl = parseOptions(rawArgs);
 
-            this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 0);
-            if (topic < 1) {
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 0);
+            if (topics < 1) {
                 throw new UDFArgumentException(
-                    "A positive integer MUST be set to an option `-topic`: " + topic);
+                    "A positive integer MUST be set to an option `-topics`: " + topics);
             }
 
-            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
             this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d);
 
             return cl;
@@ -211,7 +211,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 this.internalMergeOI = soi;
                 this.wcListField = soi.getStructFieldRef("wcList");
                 this.lambdaMapField = soi.getStructFieldRef("lambdaMap");
-                this.topicOptionField = soi.getStructFieldRef("topic");
+                this.topicOptionField = soi.getStructFieldRef("topics");
                 this.alphaOptionField = soi.getStructFieldRef("alpha");
                 this.deltaOptionField = soi.getStructFieldRef("delta");
                 this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
@@ -253,7 +253,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                 ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
 
-            fieldNames.add("topic");
+            fieldNames.add("topics");
             fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
 
             fieldNames.add("alpha");
@@ -278,7 +278,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 throws HiveException {
             OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
             myAggr.reset();
-            myAggr.setOptions(topic, alpha, delta);
+            myAggr.setOptions(topics, alpha, delta);
         }
 
         @Override
@@ -359,7 +359,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
 
             // restore options from partial result
             Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField);
-            this.topic = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
+            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
 
             Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField);
             this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);
@@ -368,7 +368,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
             this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj);
 
             OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
-            myAggr.setOptions(topic, alpha, delta);
+            myAggr.setOptions(topics, alpha, delta);
             myAggr.merge(wcList, lambdaMap);
         }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 9aa15e2..1e28a30 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -63,7 +63,7 @@ public class LDAUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(LDAUDTF.class);
 
     // Options
-    protected int topic;
+    protected int topics;
     protected float alpha;
     protected float eta;
     protected long numDocs;
@@ -93,9 +93,9 @@ public class LDAUDTF extends UDTFWithOptions {
     protected ByteBuffer inputBuf;
 
     public LDAUDTF() {
-        this.topic = 10;
-        this.alpha = 1.f / topic;
-        this.eta = 1.f / topic;
+        this.topics = 10;
+        this.alpha = 1.f / topics;
+        this.eta = 1.f / topics;
         this.numDocs = -1L;
         this.tau0 = 64.d;
         this.kappa = 0.7;
@@ -108,7 +108,7 @@ public class LDAUDTF extends UDTFWithOptions {
     @Override
     protected Options getOptions() {
         Options opts = new Options();
-        opts.addOption("k", "topic", true, "The number of topics [default: 10]");
+        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
         opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
         opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]");
         opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]");
@@ -131,9 +131,9 @@ public class LDAUDTF extends UDTFWithOptions {
         if (argOIs.length >= 2) {
             String rawArgs = HiveUtils.getConstString(argOIs[1]);
             cl = parseOptions(rawArgs);
-            this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 10);
-            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
-            this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topic);
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 10);
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
+            this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topics);
             this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L);
             this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d);
             if (tau0 <= 0.d) {
@@ -187,7 +187,7 @@ public class LDAUDTF extends UDTFWithOptions {
     }
 
     protected void initModel() {
-        this.model = new OnlineLDAModel(topic, alpha, eta, numDocs, tau0, kappa, delta);
+        this.model = new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta);
     }
 
     @Override
@@ -527,7 +527,7 @@ public class LDAUDTF extends UDTFWithOptions {
         forwardObjs[1] = word;
         forwardObjs[2] = score;
 
-        for (int k = 0; k < topic; k++) {
+        for (int k = 0; k < topics; k++) {
             topicIdx.set(k);
 
             final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
@@ -541,7 +541,7 @@ public class LDAUDTF extends UDTFWithOptions {
             }
         }
 
-        logger.info("Forwarded topic words each of " + topic + " topics");
+        logger.info("Forwarded topic words each of " + topics + " topics");
     }
 
     /*

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
new file mode 100644
index 0000000..c0b60fc
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.CommandLineUtils;
+import hivemall.utils.lang.Primitives;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+
+@Description(name = "plsa_predict",
+        value = "_FUNC_(string word, float value, int label, float prob[, const string options])"
+                + " - Returns a list which consists of <int label, float prob>")
+public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
+
+    /**
+     * Validates the argument types of {@code plsa_predict} and returns a new {@link Evaluator}.
+     *
+     * @param typeInfo argument types: word (string), value (number), label (integer), prob
+     *        (number), and an optional constant options string
+     * @throws SemanticException if the arity is not 4 or 5, or if an argument has an unexpected
+     *         type
+     */
+    @Override
+    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
+        if (typeInfo.length != 4 && typeInfo.length != 5) {
+            throw new UDFArgumentLengthException(
+                "Expected argument length is 4 or 5 but given argument length was "
+                        + typeInfo.length);
+        }
+
+        if (!HiveUtils.isStringTypeInfo(typeInfo[0])) {
+            throw new UDFArgumentTypeException(0,
+                "String type is expected for the first argument word: " + typeInfo[0].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Number type is expected for the second argument value: "
+                        + typeInfo[1].getTypeName());
+        }
+        if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) {
+            throw new UDFArgumentTypeException(2,
+                "Integer type is expected for the third argument label: "
+                        + typeInfo[2].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
+            throw new UDFArgumentTypeException(3,
+                "Number type is expected for the fourth argument prob: "
+                        + typeInfo[3].getTypeName());
+        }
+
+        // the optional fifth argument is the options string, not a probability
+        if (typeInfo.length == 5 && !HiveUtils.isStringTypeInfo(typeInfo[4])) {
+            throw new UDFArgumentTypeException(4,
+                "String type is expected for the fifth argument options: "
+                        + typeInfo[4].getTypeName());
+        }
+
+        return new Evaluator();
+    }
+
+    public static class Evaluator extends GenericUDAFEvaluator {
+
+        // ---- object inspectors for raw input rows (PARTIAL1 / COMPLETE) ----
+        private PrimitiveObjectInspector wordOI;
+        private PrimitiveObjectInspector valueOI;
+        private PrimitiveObjectInspector labelOI;
+        private PrimitiveObjectInspector probOI;
+
+        // ---- hyperparameters taken from the -topics / -alpha / -delta options ----
+        private int topics;
+        private float alpha;
+        private double delta;
+
+        // ---- partial-aggregation struct and its field references (PARTIAL2 / FINAL) ----
+        private StructObjectInspector internalMergeOI;
+        private StructField wcListField;
+        private StructField probMapField;
+        private StructField topicOptionField;
+        private StructField alphaOptionField;
+        private StructField deltaOptionField;
+
+        // ---- inspectors for the collections nested inside the partial struct ----
+        private PrimitiveObjectInspector wcListElemOI;
+        private StandardListObjectInspector wcListOI;
+        private PrimitiveObjectInspector probMapKeyOI;
+        private PrimitiveObjectInspector probMapValueElemOI;
+        private StandardListObjectInspector probMapValueOI;
+        private StandardMapObjectInspector probMapOI;
+
+        /**
+         * Creates an evaluator. Its state is set up later by {@code init}, which calls
+         * {@code processOptions} in PARTIAL1/COMPLETE mode.
+         */
+        public Evaluator() {
+            // intentionally empty
+        }
+
+        protected Options getOptions() {
+            Options opts = new Options();
+            opts.addOption("k", "topics", true, "The number of topics [default: 10]");
+            opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]");
+            opts.addOption("delta", true,
+                "Check convergence in the expectation step [default: 1E-5]");
+            return opts;
+        }
+
+        /**
+         * Parses a whitespace-separated options string against {@link #getOptions()}, plus a
+         * {@code -help} flag.
+         *
+         * @param optionValue the raw options string
+         * @return the parsed command line
+         * @throws UDFArgumentException carrying the usage text when {@code -help} is given
+         *         (CommandLineUtils may also throw it for malformed options)
+         */
+        @Nonnull
+        protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException {
+            String[] args = optionValue.split("\\s+");
+            Options opts = getOptions();
+            opts.addOption("help", false, "Show function help");
+            CommandLine cl = CommandLineUtils.parseOptions(args, opts);
+
+            if (cl.hasOption("help")) {
+                // @Description is declared on the enclosing PLSAPredictUDAF resolver, not on this
+                // Evaluator, so getClass().getAnnotation(...) would never find it.
+                Description funcDesc = PLSAPredictUDAF.class.getAnnotation(Description.class);
+                final String cmdLineSyntax;
+                if (funcDesc == null) {
+                    cmdLineSyntax = getClass().getSimpleName();
+                } else {
+                    cmdLineSyntax = funcDesc.value().replace("_FUNC_", funcDesc.name());
+                }
+                StringWriter sw = new StringWriter();
+                sw.write('\n');
+                PrintWriter pw = new PrintWriter(sw);
+                HelpFormatter formatter = new HelpFormatter();
+                formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts,
+                    HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true);
+                pw.flush();
+                String helpMsg = sw.toString();
+                throw new UDFArgumentException(helpMsg);
+            }
+
+            return cl;
+        }
+
+        /**
+         * Reads the hyperparameters from the optional constant options string (5th argument).
+         * When no options string is given, the {@code PLSAUDTF} defaults are used.
+         *
+         * @return the parsed command line, or {@code null} when no options string was given
+         * @throws UDFArgumentException if {@code -topics} is not a positive integer, or if the
+         *         options string is rejected by {@link #parseOptions(String)}
+         */
+        @Nullable
+        protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+            if (argOIs.length != 5) {
+                // No options: apply the documented defaults. Otherwise the fields would stay at
+                // their Java zero values and reset() would configure buffers with zero topics.
+                this.topics = PLSAUDTF.DEFAULT_TOPICS;
+                this.alpha = PLSAUDTF.DEFAULT_ALPHA;
+                this.delta = PLSAUDTF.DEFAULT_DELTA;
+                return null;
+            }
+
+            String rawArgs = HiveUtils.getConstString(argOIs[4]);
+            CommandLine cl = parseOptions(rawArgs);
+
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), PLSAUDTF.DEFAULT_TOPICS);
+            if (topics < 1) {
+                throw new UDFArgumentException(
+                    "A positive integer MUST be set to an option `-topics`: " + topics);
+            }
+
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), PLSAUDTF.DEFAULT_ALPHA);
+            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), PLSAUDTF.DEFAULT_DELTA);
+
+            return cl;
+        }
+
+        /**
+         * Sets up input inspectors for the given mode and returns the output inspector.
+         * <p>
+         * In PARTIAL1/COMPLETE mode, {@code parameters} are the 4 or 5 raw arguments. In the
+         * merge modes, only {@code parameters[0]} (the partial-aggregation struct) is read.
+         * PARTIAL1/PARTIAL2 emit the partial struct; the other modes emit a list of
+         * {@code struct<label:int, probability:float>}.
+         */
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+            super.init(mode, parameters);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+                // the 4/5-argument arity applies only to raw input, not to merge modes
+                assert (parameters.length == 4 || parameters.length == 5) : parameters.length;
+                processOptions(parameters);
+                this.wordOI = HiveUtils.asStringOI(parameters[0]);
+                this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]);
+                this.labelOI = HiveUtils.asIntegerOI(parameters[2]);
+                this.probOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
+            } else {// from partial aggregation
+                StructObjectInspector soi = (StructObjectInspector) parameters[0];
+                this.internalMergeOI = soi;
+                this.wcListField = soi.getStructFieldRef("wcList");
+                this.probMapField = soi.getStructFieldRef("probMap");
+                this.topicOptionField = soi.getStructFieldRef("topics");
+                this.alphaOptionField = soi.getStructFieldRef("alpha");
+                this.deltaOptionField = soi.getStructFieldRef("delta");
+                this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+                this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(wcListElemOI);
+                this.probMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+                // must match the list<float> value type declared for "probMap" in internalMergeOI()
+                this.probMapValueElemOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
+                this.probMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(probMapValueElemOI);
+                this.probMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(probMapKeyOI,
+                    probMapValueOI);
+            }
+
+            // initialize output
+            final ObjectInspector outputOI;
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+                outputOI = internalMergeOI();
+            } else {
+                final ArrayList<String> fieldNames = new ArrayList<String>();
+                final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+                fieldNames.add("label");
+                fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+                fieldNames.add("probability");
+                fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+                outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector(
+                    fieldNames, fieldOIs));
+            }
+            return outputOI;
+        }
+
+        private static StructObjectInspector internalMergeOI() {
+            ArrayList<String> fieldNames = new ArrayList<String>();
+            ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+            fieldNames.add("wcList");
+            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+
+            fieldNames.add("probMap");
+            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+
+            fieldNames.add("topics");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+            fieldNames.add("alpha");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+            fieldNames.add("delta");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+        }
+
+        /**
+         * Allocates a fresh {@link PLSAPredictAggregationBuffer} already reset and configured
+         * with this evaluator's topics / alpha / delta options.
+         */
+        @SuppressWarnings("deprecation")
+        @Override
+        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+            final PLSAPredictAggregationBuffer buffer = new PLSAPredictAggregationBuffer();
+            reset(buffer);
+            return buffer;
+        }
+
+        /**
+         * Clears the collected words and probabilities of {@code agg} and re-applies the
+         * evaluator's current options to it.
+         */
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) agg;
+            buffer.reset();
+            buffer.setOptions(topics, alpha, delta);
+        }
+
+        /**
+         * Feeds one (word, value, label, prob) row into the aggregation buffer. Rows in which
+         * any of the four arguments is NULL are ignored.
+         */
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            final PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) agg;
+
+            for (int i = 0; i < 4; i++) {
+                if (parameters[i] == null) {
+                    return; // incomplete row: nothing to record
+                }
+            }
+
+            final String word = PrimitiveObjectInspectorUtils.getString(parameters[0], wordOI);
+            final float value = PrimitiveObjectInspectorUtils.getFloat(parameters[1], valueOI);
+            final int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], labelOI);
+            final float prob = PrimitiveObjectInspectorUtils.getFloat(parameters[3], probOI);
+            buffer.iterate(word, value, label, prob);
+        }
+
+        /**
+         * Emits the partial aggregation as a struct laid out like {@code internalMergeOI()}, or
+         * {@code null} when no word has been collected.
+         */
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) agg;
+            if (buffer.wcList.isEmpty()) {
+                return null;
+            }
+
+            // field order must match internalMergeOI(): wcList, probMap, topics, alpha, delta
+            return new Object[] {buffer.wcList, buffer.probMap, new IntWritable(buffer.topics),
+                    new FloatWritable(buffer.alpha), new DoubleWritable(buffer.delta)};
+        }
+
+        /**
+         * Merges a partial aggregation produced by {@code terminatePartial} into {@code agg}.
+         * Serialized (possibly lazy-binary) values are converted to plain Java {@code String} /
+         * {@code Float} objects first, and the evaluator's options are overwritten with the
+         * ones carried by the partial result.
+         */
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
+                throws HiveException {
+            if (partial == null) {
+                return;
+            }
+
+            // collected "word:value" entries as Java Strings
+            final Object wcListObj = internalMergeOI.getStructFieldData(partial, wcListField);
+            final List<?> rawWcList = wcListOI.getList(HiveUtils.castLazyBinaryObject(wcListObj));
+            final List<String> wcList = new ArrayList<String>();
+            for (Object rawWc : rawWcList) {
+                wcList.add(PrimitiveObjectInspectorUtils.getString(rawWc, wcListElemOI));
+            }
+
+            // word -> per-topic probabilities, with String keys and lists of Float values
+            final Object probMapObj = internalMergeOI.getStructFieldData(partial, probMapField);
+            final Map<?, ?> rawProbMap = probMapOI.getMap(HiveUtils.castLazyBinaryObject(probMapObj));
+            final Map<String, List<Float>> probMap = new HashMap<String, List<Float>>();
+            for (Map.Entry<?, ?> entry : rawProbMap.entrySet()) {
+                final String word = PrimitiveObjectInspectorUtils.getString(entry.getKey(),
+                    probMapKeyOI);
+                final List<?> rawProbs = probMapValueOI.getList(HiveUtils.castLazyBinaryObject(entry.getValue()));
+                final List<Float> probs = new ArrayList<Float>();
+                for (Object rawProb : rawProbs) {
+                    probs.add(HiveUtils.getFloat(rawProb, probMapValueElemOI));
+                }
+                probMap.put(word, probs);
+            }
+
+            // options carried by the partial result replace this evaluator's settings
+            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(internalMergeOI.getStructFieldData(
+                partial, topicOptionField));
+            this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(internalMergeOI.getStructFieldData(
+                partial, alphaOptionField));
+            this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(internalMergeOI.getStructFieldData(
+                partial, deltaOptionField));
+
+            final PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) agg;
+            buffer.setOptions(topics, alpha, delta);
+            buffer.merge(wcList, probMap);
+        }
+
+        /**
+         * Computes the topic distribution of the aggregated document and returns it as a list of
+         * {@code <label, probability>} structs sorted by descending probability. Topics sharing
+         * the same probability are all emitted, in ascending label order.
+         */
+        @Override
+        public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg;
+            float[] topicDistr = myAggr.get();
+
+            // Group labels per probability: keying a map by probability alone would silently
+            // drop every topic whose probability ties with another one (e.g., a uniform
+            // distribution would yield a single row).
+            final SortedMap<Float, List<Integer>> sortedDistr = new TreeMap<Float, List<Integer>>(
+                Collections.reverseOrder());
+            for (int i = 0; i < topicDistr.length; i++) {
+                final Float prob = Float.valueOf(topicDistr[i]);
+                List<Integer> labels = sortedDistr.get(prob);
+                if (labels == null) {
+                    labels = new ArrayList<Integer>(1);
+                    sortedDistr.put(prob, labels);
+                }
+                labels.add(Integer.valueOf(i));
+            }
+
+            final List<Object[]> result = new ArrayList<Object[]>(topicDistr.length);
+            for (Map.Entry<Float, List<Integer>> e : sortedDistr.entrySet()) {
+                final float probability = e.getKey().floatValue();
+                for (Integer label : e.getValue()) {
+                    Object[] struct = new Object[2];
+                    struct[0] = new IntWritable(label.intValue()); // label
+                    struct[1] = new FloatWritable(probability); // probability
+                    result.add(struct);
+                }
+            }
+            return result;
+        }
+
+    }
+
+    /**
+     * Aggregation state of {@code plsa_predict}: the "word:value" entries of one document and,
+     * for every word, its per-topic probabilities as later passed to
+     * {@link IncrementalPLSAModel#setProbability}.
+     */
+    public static class PLSAPredictAggregationBuffer extends
+            GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+        /** Placeholder for a (word, topic) probability that has not been observed yet. */
+        private static final float UNOBSERVED = -1.f;
+
+        // one "word:value" entry per iterated row
+        private List<String> wcList;
+        // word -> probabilities indexed by topic label; UNOBSERVED where not yet known
+        private Map<String, List<Float>> probMap;
+
+        private int topics;
+        private float alpha;
+        private double delta;
+
+        PLSAPredictAggregationBuffer() {
+            super();
+        }
+
+        void setOptions(int topics, float alpha, double delta) {
+            this.topics = topics;
+            this.alpha = alpha;
+            this.delta = delta;
+        }
+
+        void reset() {
+            this.wcList = new ArrayList<String>();
+            this.probMap = new HashMap<String, List<Float>>();
+        }
+
+        /**
+         * Records one (word, value, label, prob) row.
+         *
+         * @throws IllegalArgumentException if {@code label} is outside {@code [0, topics)}
+         */
+        void iterate(@Nonnull final String word, final float value, final int label,
+                final float prob) {
+            if (label < 0 || label >= topics) {
+                throw new IllegalArgumentException("label must be in range [0, " + topics
+                        + "): " + label);
+            }
+
+            wcList.add(word + ":" + value);
+
+            // for a word not seen yet, initialize all of its probs as unobserved
+            List<Float> prob_word = probMap.get(word);
+            if (prob_word == null) {
+                prob_word = new ArrayList<Float>(Collections.nCopies(topics, UNOBSERVED));
+                probMap.put(word, prob_word);
+            }
+
+            // set the given prob value
+            prob_word.set(label, Float.valueOf(prob));
+        }
+
+        /**
+         * Merges another buffer's entries into this one. Observed probabilities of the other
+         * side overwrite this side's values for the same (word, topic).
+         */
+        void merge(@Nonnull final List<String> o_wcList,
+                @Nonnull final Map<String, List<Float>> o_probMap) {
+            wcList.addAll(o_wcList);
+
+            for (Map.Entry<String, List<Float>> e : o_probMap.entrySet()) {
+                final String o_word = e.getKey();
+                final List<Float> o_prob_word = e.getValue();
+
+                final List<Float> prob_word = probMap.get(o_word);
+                if (prob_word == null) { // word unseen by this buffer: adopt the partial's probs
+                    probMap.put(o_word, o_prob_word);
+                } else { // word observed on both sides: fill in the topics the partial knows
+                    for (int k = 0; k < topics; k++) {
+                        final float prob_k = o_prob_word.get(k).floatValue();
+                        if (prob_k != UNOBSERVED) {
+                            prob_word.set(k, prob_k); // updated in place; already in probMap
+                        }
+                    }
+                }
+            }
+        }
+
+        /**
+         * Builds an {@link IncrementalPLSAModel} from the observed probabilities and returns the
+         * topic distribution it computes for the collected "word:value" entries.
+         */
+        float[] get() {
+            final IncrementalPLSAModel model = new IncrementalPLSAModel(topics, alpha, delta);
+
+            for (Map.Entry<String, List<Float>> e : probMap.entrySet()) {
+                final String word = e.getKey();
+                final List<Float> prob_word = e.getValue();
+                for (int k = 0; k < topics; k++) {
+                    final float prob_k = prob_word.get(k).floatValue();
+                    if (prob_k != UNOBSERVED) {
+                        model.setProbability(word, k, prob_k);
+                    }
+                }
+            }
+
+            String[] wcArray = wcList.toArray(new String[wcList.size()]);
+            return model.getTopicDistribution(wcArray);
+        }
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
new file mode 100644
index 0000000..2616133
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -0,0 +1,535 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+/**
+ * Trains a pLSA (probabilistic latent semantic analysis) topic model from documents given as
+ * arrays of word feature strings and forwards {@code <topic, word, score>} rows on close.
+ * <p>
+ * Documents are fed to {@link IncrementalPLSAModel} in mini-batches. When more than one iteration
+ * is requested, every document is also serialized into a direct byte buffer (spilled to a
+ * temporary file once the buffer fills up) so that further passes can replay the training data
+ * until the mean perplexity changes by less than {@code eps} or the iteration limit is reached.
+ */
+@Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])"
+        + " - Returns a relation consists of <int topic, string word, float score>")
+public class PLSAUDTF extends UDTFWithOptions {
+    private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
+
+    public static final int DEFAULT_TOPICS = 10;
+    public static final float DEFAULT_ALPHA = 0.5f;
+    public static final double DEFAULT_DELTA = 1E-3d;
+
+    // Options
+    protected int topics;
+    protected float alpha;
+    protected int iterations;
+    protected double delta;
+    protected double eps;
+    protected int miniBatchSize;
+
+    // number of proceeded training samples
+    protected long count;
+
+    protected String[][] miniBatch;
+    protected int miniBatchCount;
+
+    protected transient IncrementalPLSAModel model;
+
+    protected ListObjectInspector wordCountsOI;
+
+    // for iterations
+    protected NioStatefullSegment fileIO;
+    protected ByteBuffer inputBuf;
+
+    public PLSAUDTF() {
+        this.topics = DEFAULT_TOPICS;
+        this.alpha = DEFAULT_ALPHA;
+        this.iterations = 10;
+        this.delta = DEFAULT_DELTA;
+        this.eps = 1E-1d;
+        this.miniBatchSize = 128;
+    }
+
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
+        opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]");
+        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
+        opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
+        opts.addOption("eps", "epsilon", true,
+            "Check convergence based on the difference of perplexity [default: 1E-1]");
+        opts.addOption("s", "mini_batch_size", true,
+            "Repeat model updating per mini-batch [default: 128]");
+        return opts;
+    }
+
+    /**
+     * Parses the optional constant option string (second argument).
+     *
+     * @throws UDFArgumentException if topics, iterations or mini-batch size is smaller than 1
+     */
+    @Override
+    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+        CommandLine cl = null;
+
+        if (argOIs.length >= 2) {
+            String rawArgs = HiveUtils.getConstString(argOIs[1]);
+            cl = parseOptions(rawArgs);
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
+            if (topics < 1) {
+                throw new UDFArgumentException(
+                    "'-topics' must be greater than or equals to 1: " + topics);
+            }
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), DEFAULT_ALPHA);
+            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
+            if (iterations < 1) {
+                throw new UDFArgumentException(
+                    "'-iterations' must be greater than or equals to 1: " + iterations);
+            }
+            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA);
+            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
+            this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
+            if (miniBatchSize < 1) {
+                // a zero-sized mini-batch array would overflow on the first document
+                throw new UDFArgumentException(
+                    "'-mini_batch_size' must be greater than or equals to 1: " + miniBatchSize);
+            }
+        }
+
+        return cl;
+    }
+
+    @Override
+    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length < 1) {
+            throw new UDFArgumentException(
+                "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
+        }
+
+        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
+        HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
+
+        processOptions(argOIs);
+
+        this.model = null;
+        this.count = 0L;
+        this.miniBatch = new String[miniBatchSize][];
+        this.miniBatchCount = 0;
+
+        ArrayList<String> fieldNames = new ArrayList<String>();
+        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+        fieldNames.add("topic");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+        fieldNames.add("word");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+        fieldNames.add("score");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+    }
+
+    protected void initModel() {
+        this.model = new IncrementalPLSAModel(topics, alpha, delta);
+    }
+
+    /**
+     * Adds one document to the current mini-batch, training the model whenever the mini-batch
+     * becomes full. NULL or empty documents are skipped.
+     *
+     * @throws HiveException if the document contains a NULL element
+     */
+    @Override
+    public void process(Object[] args) throws HiveException {
+        if (model == null) {
+            initModel();
+        }
+
+        // getListLength() reports a negative length for a NULL list
+        final int length = wordCountsOI.getListLength(args[0]);
+        if (length <= 0) {// avoid empty documents
+            return;
+        }
+
+        final String[] wordCounts = new String[length];
+        for (int i = 0; i < length; i++) {
+            Object o = wordCountsOI.getListElement(args[0], i);
+            if (o == null) {
+                throw new HiveException("Given feature vector contains invalid elements");
+            }
+            wordCounts[i] = o.toString();
+        }
+
+        count++;
+
+        recordTrainSampleToTempFile(wordCounts);
+
+        miniBatch[miniBatchCount] = wordCounts;
+        miniBatchCount++;
+
+        if (miniBatchCount == miniBatchSize) {
+            model.train(miniBatch);
+            Arrays.fill(miniBatch, null); // clear
+            miniBatchCount = 0;
+        }
+    }
+
+    /**
+     * Serializes a training document into {@link #inputBuf} so that it can be replayed by
+     * {@link #runIterativeTraining(int)}; the buffer is written to a temporary file whenever the
+     * next record does not fit. Does nothing when only a single iteration is requested.
+     * <p>
+     * Record layout: {@code [int recordBytes][int numWords]([int byteLength][UTF-8 bytes])*},
+     * where {@code recordBytes} counts every byte that follows it.
+     *
+     * @throws HiveException if the temporary file cannot be created or written, or if a single
+     *         record does not fit into the buffer
+     */
+    protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
+            throws HiveException {
+        if (iterations == 1) {
+            return;
+        }
+
+        ByteBuffer buf = inputBuf;
+        NioStatefullSegment dst = fileIO;
+
+        if (buf == null) {
+            final File file;
+            try {
+                file = File.createTempFile("hivemall_plsa", ".sgmt");
+                file.deleteOnExit();
+            } catch (IOException ioe) {
+                throw new UDFArgumentException(ioe);
+            } catch (Throwable e) {
+                throw new UDFArgumentException(e);
+            }
+            if (!file.canWrite()) {
+                throw new UDFArgumentException("Cannot write a temporary file: "
+                        + file.getAbsolutePath());
+            }
+            logger.info("Record training samples to a file: " + file.getAbsolutePath());
+            this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
+            this.fileIO = dst = new NioStatefullSegment(file, false);
+        }
+
+        // Encode every word once so that the length prefix always equals the number of bytes
+        // actually written (char count != byte count for non-ASCII words).
+        final byte[][] encoded = new byte[wordCounts.length][];
+        int numWords = 0;
+        int payloadBytes = 0;
+        for (String wc : wordCounts) {
+            if (wc == null) {
+                continue;
+            }
+            final byte[] bytes = wc.getBytes(StandardCharsets.UTF_8);
+            encoded[numWords++] = bytes;
+            payloadBytes += bytes.length;
+        }
+        // numWords int + one length int per word + the encoded words
+        final int recordBytes = SizeOf.INT * (1 + numWords) + payloadBytes;
+        final int requiredBytes = SizeOf.INT + recordBytes; // plus the recordBytes int itself
+        if (requiredBytes > buf.capacity()) {
+            throw new HiveException("Training example too large to record: " + requiredBytes
+                    + " bytes exceeds the buffer capacity of " + buf.capacity() + " bytes");
+        }
+        if (buf.remaining() < requiredBytes) {
+            writeBuffer(buf, dst);
+        }
+
+        buf.putInt(recordBytes);
+        buf.putInt(numWords);
+        for (int i = 0; i < numWords; i++) {
+            buf.putInt(encoded[i].length);
+            buf.put(encoded[i]);
+        }
+    }
+
+    /**
+     * Writes the content of {@code srcBuf} (from 0 to its current position) to {@code dst} and
+     * clears the buffer.
+     */
+    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
+            throws HiveException {
+        srcBuf.flip();
+        try {
+            dst.write(srcBuf);
+        } catch (IOException e) {
+            throw new HiveException("Exception causes while writing a buffer to file", e);
+        }
+        srcBuf.clear();
+    }
+
+    /**
+     * Trains on the remaining partial mini-batch, runs the extra iterations if requested, and
+     * forwards the learned topic words.
+     */
+    @Override
+    public void close() throws HiveException {
+        if (count == 0) {
+            this.model = null;
+            return;
+        }
+        if (miniBatchCount > 0) { // update for remaining samples
+            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+        }
+        if (iterations > 1) {
+            runIterativeTraining(iterations);
+        }
+        forwardModel();
+        this.model = null;
+    }
+
+    /**
+     * Replays the recorded training examples for iterations 2..{@code iterations}, stopping
+     * early once the mean mini-batch perplexity changes by less than {@code eps}. Examples are
+     * read from memory when nothing was spilled to the temporary file, otherwise from the file.
+     * The temporary file is always closed and deleted afterwards.
+     */
+    protected final void runIterativeTraining(@Nonnegative final int iterations)
+            throws HiveException {
+        final ByteBuffer buf = this.inputBuf;
+        final NioStatefullSegment dst = this.fileIO;
+        assert (buf != null);
+        assert (dst != null);
+        final long numTrainingExamples = count;
+
+        final Reporter reporter = getReporter();
+        final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
+            "hivemall.plsa.IncrementalPLSA$Counter", "iteration");
+
+        try {
+            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
+                if (buf.position() == 0) {
+                    return; // no training example
+                }
+                buf.flip();
+
+                int iter = 2;
+                float perplexityPrev = Float.MAX_VALUE;
+                float perplexity;
+                int numTrain;
+                for (; iter <= iterations; iter++) {
+                    perplexity = 0.f;
+                    numTrain = 0;
+
+                    reportProgress(reporter);
+                    setCounterValue(iterCounter, iter);
+
+                    Arrays.fill(miniBatch, null); // clear
+                    miniBatchCount = 0;
+
+                    while (buf.remaining() > 0) {
+                        int recordBytes = buf.getInt();
+                        assert (recordBytes > 0) : recordBytes;
+                        final String[] wordCounts = readWordCounts(buf);
+
+                        miniBatch[miniBatchCount] = wordCounts;
+                        miniBatchCount++;
+
+                        if (miniBatchCount == miniBatchSize) {
+                            model.train(miniBatch);
+                            perplexity += model.computePerplexity();
+                            numTrain++;
+
+                            Arrays.fill(miniBatch, null); // clear
+                            miniBatchCount = 0;
+                        }
+                    }
+                    buf.rewind();
+
+                    // update for remaining samples
+                    if (miniBatchCount > 0) {
+                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+                        perplexity += model.computePerplexity();
+                        numTrain++;
+                    }
+
+                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
+                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break;
+                    }
+                    perplexityPrev = perplexity;
+                }
+                logger.info("Performed "
+                        + Math.min(iter, iterations)
+                        + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on memory (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total) ");
+            } else {// read training examples in the temporary file and invoke train for each example
+
+                // write the training examples still held in the buffer to the temporary file;
+                // checking position (not remaining) also flushes an exactly full buffer
+                if (buf.position() > 0) {
+                    writeBuffer(buf, dst);
+                }
+                try {
+                    dst.flush();
+                } catch (IOException e) {
+                    throw new HiveException("Failed to flush a file: "
+                            + dst.getFile().getAbsolutePath(), e);
+                }
+                if (logger.isInfoEnabled()) {
+                    File tmpFile = dst.getFile();
+                    logger.info("Wrote " + numTrainingExamples
+                            + " records to a temporary file for iterative training: "
+                            + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
+                            + ")");
+                }
+
+                // run iterations
+                int iter = 2;
+                float perplexityPrev = Float.MAX_VALUE;
+                float perplexity;
+                int numTrain;
+                for (; iter <= iterations; iter++) {
+                    perplexity = 0.f;
+                    numTrain = 0;
+
+                    Arrays.fill(miniBatch, null); // clear
+                    miniBatchCount = 0;
+
+                    setCounterValue(iterCounter, iter);
+
+                    buf.clear();
+                    dst.resetPosition();
+                    while (true) {
+                        reportProgress(reporter);
+                        // TODO prefetch
+                        // reads training examples from the temporary file into the buffer
+                        final int bytesRead;
+                        try {
+                            bytesRead = dst.read(buf);
+                        } catch (IOException e) {
+                            throw new HiveException("Failed to read a file: "
+                                    + dst.getFile().getAbsolutePath(), e);
+                        }
+                        if (bytesRead == 0) { // reached file EOF
+                            break;
+                        }
+                        assert (bytesRead > 0) : bytesRead;
+
+                        // reads training examples from a buffer
+                        buf.flip();
+                        int remain = buf.remaining();
+                        if (remain < SizeOf.INT) {
+                            throw new HiveException("Illegal file format was detected");
+                        }
+                        while (remain >= SizeOf.INT) {
+                            int pos = buf.position();
+                            int recordBytes = buf.getInt();
+                            remain -= SizeOf.INT;
+                            if (remain < recordBytes) {
+                                // incomplete record: keep it for the next read
+                                buf.position(pos);
+                                break;
+                            }
+
+                            final String[] wordCounts = readWordCounts(buf);
+
+                            miniBatch[miniBatchCount] = wordCounts;
+                            miniBatchCount++;
+
+                            if (miniBatchCount == miniBatchSize) {
+                                model.train(miniBatch);
+                                perplexity += model.computePerplexity();
+                                numTrain++;
+
+                                Arrays.fill(miniBatch, null); // clear
+                                miniBatchCount = 0;
+                            }
+
+                            remain -= recordBytes;
+                        }
+                        buf.compact();
+                    }
+
+                    // update for remaining samples
+                    if (miniBatchCount > 0) {
+                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+                        perplexity += model.computePerplexity();
+                        numTrain++;
+                    }
+
+                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
+                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break;
+                    }
+                    perplexityPrev = perplexity;
+                }
+                logger.info("Performed "
+                        + Math.min(iter, iterations)
+                        + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on a secondary storage (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total)");
+            }
+        } finally {
+            // delete the temporary file and release resources
+            try {
+                dst.close(true);
+            } catch (IOException e) {
+                throw new HiveException("Failed to close a file: "
+                        + dst.getFile().getAbsolutePath(), e);
+            }
+            this.inputBuf = null;
+            this.fileIO = null;
+        }
+    }
+
+    /**
+     * Decodes the body of one record written by {@link #recordTrainSampleToTempFile}: the word
+     * count followed by length-prefixed UTF-8 words. The leading recordBytes int must already
+     * have been consumed.
+     */
+    @Nonnull
+    private static String[] readWordCounts(@Nonnull final ByteBuffer buf) {
+        final int numWords = buf.getInt();
+        final String[] wordCounts = new String[numWords];
+        for (int j = 0; j < numWords; j++) {
+            final byte[] bytes = new byte[buf.getInt()];
+            buf.get(bytes);
+            wordCounts[j] = new String(bytes, StandardCharsets.UTF_8);
+        }
+        return wordCounts;
+    }
+
+    /**
+     * Forwards one {@code <topic, word, score>} row for every word returned by
+     * {@code IncrementalPLSAModel#getTopicWords(int)} for each topic.
+     */
+    protected void forwardModel() throws HiveException {
+        final IntWritable topicIdx = new IntWritable();
+        final Text word = new Text();
+        final FloatWritable score = new FloatWritable();
+
+        final Object[] forwardObjs = new Object[3];
+        forwardObjs[0] = topicIdx;
+        forwardObjs[1] = word;
+        forwardObjs[2] = score;
+
+        for (int k = 0; k < topics; k++) {
+            topicIdx.set(k);
+
+            final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
+            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+                score.set(e.getKey());
+                List<String> words = e.getValue();
+                for (int i = 0; i < words.size(); i++) {
+                    word.set(words.get(i));
+                    forward(forwardObjs);
+                }
+            }
+        }
+
+        logger.info("Forwarded topic words each of " + topics + " topics");
+    }
+
+    /*
+     * For testing:
+     */
+
+    @VisibleForTesting
+    double getProbability(String label, int k) {
+        return model.getProbability(label, k);
+    }
+
+    @VisibleForTesting
+    SortedMap<Float, List<String>> getTopicWords(int k) {
+        return model.getTopicWords(k);
+    }
+
+    @VisibleForTesting
+    float[] getTopicDistribution(@Nonnull String[] doc) {
+        return model.getTopicDistribution(doc);
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index c20c363..4177d70 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -735,4 +735,14 @@ public final class ArrayUtils {
         return ret;
     }
 
+    /**
+     * Allocates an array of {@code size} floats and fills it, in index order, with successive
+     * values of {@link Random#nextFloat()} (uniform in [0, 1)).
+     *
+     * @param size number of elements; must be non-negative
+     * @param rnd random source, advanced exactly {@code size} times
+     * @return a newly allocated array of random floats
+     */
+    @Nonnull
+    public static float[] newRandomFloatArray(@Nonnegative final int size,
+            @Nonnull final Random rnd) {
+        final float[] values = new float[size];
+        int idx = 0;
+        while (idx < size) {
+            values[idx++] = rnd.nextFloat();
+        }
+        return values;
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 061b75d..8ffb89c 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -408,4 +408,16 @@ public final class MathUtils {
         return Math.log(logsumexp) + max;
     }
 
+    /**
+     * Scales {@code arr} in place so that the sum of the absolute values of its elements is 1.
+     *
+     * <p>An array whose L1 norm is zero (empty, or all zeros) cannot be normalized; it is
+     * returned unchanged rather than being filled with NaN (the result of 0/0).
+     *
+     * @param arr values to normalize; modified in place
+     * @return {@code arr} itself, for call chaining
+     */
+    @Nonnull
+    public static float[] l1normalize(@Nonnull final float[] arr) {
+        double sum = 0.d;
+        for (float v : arr) {
+            sum += Math.abs(v);
+        }
+        if (sum == 0.d) {
+            // nothing to divide by: keep the zero vector instead of producing NaNs
+            return arr;
+        }
+        for (int i = 0; i < arr.length; i++) {
+            arr[i] /= sum; // computed in double, then narrowed back to float
+        }
+        return arr;
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
new file mode 100644
index 0000000..db34a38
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Arrays;
+import java.util.StringTokenizer;
+import java.util.zip.GZIPInputStream;
+
+import hivemall.classifier.KernelExpansionPassiveAggressiveUDTFTest;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Unit tests for {@link IncrementalPLSAModel}: convergence on two toy documents with online and
+ * mini-batch training, and topic separation on one document from each of 20 news20 classes.
+ */
+public class IncrementalPLSAModelTest {
+    /** Set to {@code true} to print training progress and topic words to stdout. */
+    private static final boolean DEBUG = false;
+
+    @Test
+    public void testOnline() {
+        int K = 2;
+        int it = 0;
+        int maxIter = 1024;
+        float perplexityPrev;
+        float perplexity = Float.MAX_VALUE;
+
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d);
+
+        String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
+                "oranges:1"};
+
+        do {
+            perplexityPrev = perplexity;
+            perplexity = 0.f;
+
+            // online (i.e., one-by-one) updating
+            model.train(new String[][] {doc1});
+            perplexity += model.computePerplexity();
+
+            model.train(new String[][] {doc2});
+            perplexity += model.computePerplexity();
+
+            perplexity /= 2.f; // mean perplexity for the 2 docs
+
+            it++;
+            println("Iteration " + it + ": mean perplexity = " + perplexity);
+        } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-4f);
+
+        printTopicWords(model, 0);
+        printTopicWords(model, 1);
+
+        assertToyTopics(model, doc1);
+    }
+
+    @Test
+    public void testMiniBatch() {
+        int K = 2;
+        int it = 0;
+        int maxIter = 2048;
+        float perplexityPrev;
+        float perplexity = Float.MAX_VALUE;
+
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d);
+
+        String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+        String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
+                "oranges:1"};
+
+        do {
+            perplexityPrev = perplexity;
+
+            model.train(new String[][] {doc1, doc2});
+            perplexity = model.computePerplexity();
+
+            it++;
+            println("Iteration " + it + ": perplexity = " + perplexity);
+        } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-4f);
+
+        printTopicWords(model, 0);
+        printTopicWords(model, 1);
+
+        assertToyTopics(model, doc1);
+    }
+
+    @Test
+    public void testNews20() throws IOException {
+        int K = 20;
+        int miniBatchSize = 2;
+
+        int cnt, it;
+        int maxIter = 64;
+
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.8f, 1E-5d);
+
+        String[][] docs = new String[K][];
+
+        BufferedReader news20 = readFile("news20-multiclass.gz");
+        try {
+            String line = news20.readLine();
+            List<String> doc = new ArrayList<String>();
+
+            cnt = 0;
+            while (line != null) {
+                StringTokenizer tokens = new StringTokenizer(line, " ");
+
+                int k = Integer.parseInt(tokens.nextToken()) - 1;
+
+                while (tokens.hasMoreTokens()) {
+                    doc.add(tokens.nextToken());
+                }
+
+                // store first document in each of K classes
+                if (docs[k] == null) {
+                    docs[k] = doc.toArray(new String[doc.size()]);
+                    cnt++;
+                }
+
+                if (cnt == K) {
+                    break;
+                }
+
+                doc.clear();
+                line = news20.readLine();
+            }
+        } finally {
+            news20.close();
+        }
+        // training on a null document would fail with an opaque NPE; fail clearly instead
+        Assert.assertEquals("news20 SHOULD contain a document for each of " + K + " classes", K,
+            cnt);
+        println("Stored " + cnt + " docs. Start training w/ mini-batch size: " + miniBatchSize);
+
+        float perplexityPrev;
+        float perplexity = Float.MAX_VALUE;
+
+        it = 0;
+        do {
+            perplexityPrev = perplexity;
+            perplexity = 0.f;
+
+            int head = 0;
+            cnt = 0;
+            while (head < K) {
+                // clamp so the last batch is not padded with nulls when K % miniBatchSize != 0
+                int tail = Math.min(head + miniBatchSize, K);
+                model.train(Arrays.copyOfRange(docs, head, tail));
+                perplexity += model.computePerplexity();
+                head = tail;
+                cnt++;
+                println("Processed mini-batch#" + cnt);
+            }
+
+            perplexity /= cnt;
+
+            it++;
+            println("Iteration " + it + ": mean perplexity = " + perplexity);
+        } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-3f);
+
+        Set<Integer> topics = new HashSet<Integer>();
+        for (int k = 0; k < K; k++) {
+            topics.add(findMaxTopic(model.getTopicDistribution(docs[k])));
+        }
+
+        int n = topics.size();
+        println("# of unique topics: " + n);
+        Assert.assertTrue("At least 15 documents SHOULD be classified to different topics, "
+                + "but there are only " + n + " unique topics.", n >= 15);
+    }
+
+    /** Prints (only when {@link #DEBUG} is set) each word of topic {@code k} with its score. */
+    private static void printTopicWords(@Nonnull IncrementalPLSAModel model, int k) {
+        println("Topic " + k + ":");
+        println("========");
+        SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
+        for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+            for (String word : e.getValue()) {
+                println(e.getKey() + " " + word);
+            }
+        }
+        println("========");
+    }
+
+    /**
+     * Asserts that a 2-topic model trained on the toy documents separated them: the topic
+     * {@code doc1} leans towards prefers `vegetables` over `flu`, and the other topic prefers
+     * `avocados` over `healthy`.
+     */
+    private static void assertToyTopics(@Nonnull IncrementalPLSAModel model,
+            @Nonnull String[] doc1) {
+        float[] topicDistr = model.getTopicDistribution(doc1);
+        // the topic doc#1 leans towards MUST represent doc#1; the other one represents doc#2
+        final int k1 = (topicDistr[0] > topicDistr[1]) ? 0 : 1;
+        final int k2 = 1 - k1;
+        Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+                + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
+            model.getProbability("vegetables", k1) > model.getProbability("flu", k1));
+        Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+                + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
+            model.getProbability("avocados", k2) > model.getProbability("healthy", k2));
+    }
+
+    private static void println(String msg) {
+        if (DEBUG) {
+            System.out.println(msg);
+        }
+    }
+
+    /**
+     * Opens a test resource stored next to {@link KernelExpansionPassiveAggressiveUDTFTest},
+     * transparently gunzipping it when its name ends with {@code .gz}.
+     *
+     * @throws IOException if the resource does not exist or cannot be opened
+     */
+    @Nonnull
+    private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+        // use data stored for KPA UDTF test
+        InputStream is = KernelExpansionPassiveAggressiveUDTFTest.class.getResourceAsStream(fileName);
+        if (is == null) {
+            throw new IOException("Test resource not found: " + fileName);
+        }
+        if (fileName.endsWith(".gz")) {
+            is = new GZIPInputStream(is);
+        }
+        // explicit charset so the test does not depend on the platform default
+        return new BufferedReader(new InputStreamReader(is, "UTF-8"));
+    }
+
+    /** Returns the index of the largest value in {@code topicDistr} (first one on ties). */
+    private static int findMaxTopic(@Nonnull float[] topicDistr) {
+        int maxIdx = 0;
+        for (int i = 1; i < topicDistr.length; i++) {
+            if (topicDistr[maxIdx] < topicDistr[i]) {
+                maxIdx = i;
+            }
+        }
+        return maxIdx;
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
index a23d917..2c08560 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -100,7 +100,7 @@ public class LDAPredictUDAFTest {
                 PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
                         PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
                 ObjectInspectorUtils.getConstantObjectInspector(
-                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2")};
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
 
         evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
 
@@ -117,7 +117,7 @@ public class LDAPredictUDAFTest {
                 ObjectInspectorFactory.getStandardListObjectInspector(
                         PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
 
-        fieldNames.add("topic");
+        fieldNames.add("topics");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
 
         fieldNames.add("alpha");

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
index d1e3f81..a5881d4 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -42,7 +42,7 @@ public class LDAUDTFTest {
         ObjectInspector[] argOIs = new ObjectInspector[] {
             ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
             ObjectInspectorUtils.getConstantObjectInspector(
-                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2 -num_docs 2 -s 1")};
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2 -num_docs 2 -s 1")};
 
         udtf.initialize(argOIs);