You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/27 13:45:44 UTC

[2/2] incubator-hivemall git commit: Close #71: [HIVEMALL-74] Implement pLSA

Close #71: [HIVEMALL-74] Implement pLSA


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f2bf3a72
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f2bf3a72
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f2bf3a72

Branch: refs/heads/master
Commit: f2bf3a72b2f8deb0835feed649369c885a23053c
Parents: bffd2c7
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Thu Apr 27 22:44:44 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Thu Apr 27 22:44:44 2017 +0900

----------------------------------------------------------------------
 .../topicmodel/IncrementalPLSAModel.java        | 316 +++++++++++
 .../hivemall/topicmodel/LDAPredictUDAF.java     |  24 +-
 .../main/java/hivemall/topicmodel/LDAUDTF.java  |  22 +-
 .../hivemall/topicmodel/PLSAPredictUDAF.java    | 480 +++++++++++++++++
 .../main/java/hivemall/topicmodel/PLSAUDTF.java | 535 +++++++++++++++++++
 .../java/hivemall/utils/lang/ArrayUtils.java    |  10 +
 .../java/hivemall/utils/math/MathUtils.java     |  12 +
 .../topicmodel/IncrementalPLSAModelTest.java    | 291 ++++++++++
 .../hivemall/topicmodel/LDAPredictUDAFTest.java |   4 +-
 .../java/hivemall/topicmodel/LDAUDTFTest.java   |   2 +-
 .../topicmodel/PLSAPredictUDAFTest.java         | 217 ++++++++
 .../java/hivemall/topicmodel/PLSAUDTFTest.java  | 106 ++++
 docs/gitbook/SUMMARY.md                         |   1 +
 docs/gitbook/clustering/plsa.md                 | 154 ++++++
 resources/ddl/define-all-as-permanent.hive      |   6 +
 resources/ddl/define-all.hive                   |   6 +
 resources/ddl/define-all.spark                  |   6 +
 resources/ddl/define-udfs.td.hql                |   2 +
 18 files changed, 2168 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
new file mode 100644
index 0000000..745e510
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
+import static hivemall.utils.math.MathUtils.l1normalize;
+import hivemall.model.FeatureValue;
+import hivemall.utils.math.MathUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Incremental pLSA topic model fitted by per-document EM over a mini-batch.
+ *
+ * <p>
+ * The word distribution P(w|z) is kept across calls to {@link #train(String[][])}, while the
+ * per-document quantities P(z|d) and P(z|d,w) are re-initialized for every mini-batch.
+ */
+public final class IncrementalPLSAModel {
+
+    // ---------------------------------
+    // HyperParameters
+
+    // number of topics
+    private final int _K;
+
+    // control how much P(w|z) update is affected by the last value
+    private final float _alpha;
+
+    // threshold on the mean absolute change of P(z|d) between two EM iterations
+    private final double _delta;
+
+    // ---------------------------------
+
+    // random number generator (fixed seed, so parameter initialization is reproducible)
+    @Nonnull
+    private final Random _rnd;
+
+    // optimized in the E step
+    @Nonnull
+    private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair
+
+    // optimized in the M step
+    @Nonnull
+    private List<float[]> _p_dz; // P(z|d) probability of topics for documents
+    @Nonnull
+    private final Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
+
+    @Nonnull
+    private final List<Map<String, Float>> _miniBatchDocs;
+    private int _miniBatchSize;
+
+    /**
+     * @param K number of topics
+     * @param alpha weight of the previous P(w|z) value in the M-step update
+     * @param delta convergence threshold on the mean absolute change of P(z|d)
+     */
+    public IncrementalPLSAModel(int K, float alpha, double delta) {
+        this._K = K;
+        this._alpha = alpha;
+        this._delta = delta;
+
+        this._rnd = new Random(1001);
+
+        this._p_zw = new HashMap<String, float[]>();
+        // empty until the first mini-batch arrives; keeps the @Nonnull contract of the fields
+        this._p_dz = new ArrayList<float[]>();
+        this._p_dwz = new ArrayList<Map<String, float[]>>();
+
+        this._miniBatchDocs = new ArrayList<Map<String, Float>>();
+        this._miniBatchSize = 0;
+    }
+
+    /**
+     * Runs EM on each document of the mini-batch until its P(z|d) stops changing.
+     *
+     * @param miniBatch documents, each an array of feature strings parsed by
+     *        {@code FeatureValue.parseFeatureAsString}; null/empty documents and null features
+     *        are ignored
+     */
+    public void train(@Nonnull final String[][] miniBatch) {
+        initMiniBatch(miniBatch, _miniBatchDocs);
+
+        this._miniBatchSize = _miniBatchDocs.size();
+
+        initParams();
+
+        // Value snapshot of P(z|d) taken before each EM iteration. It MUST be a copy: mStep()
+        // rewrites the array held in _p_dz in place, so comparing against the same reference
+        // would always report a zero difference and stop after a single iteration.
+        final float[] pPrev_dz_d = new float[_K];
+
+        for (int d = 0; d < _miniBatchSize; d++) {
+            final float[] p_dz_d = _p_dz.get(d);
+            do {
+                System.arraycopy(p_dz_d, 0, pPrev_dz_d, 0, _K);
+
+                // Expectation
+                eStep(d);
+
+                // Maximization
+                mStep(d);
+            } while (!isPdzConverged(pPrev_dz_d, p_dz_d)); // until get stable value of P(z|d)
+        }
+    }
+
+    /**
+     * Parses the raw mini-batch into word-to-value maps, replacing the previous content of
+     * {@code docs}. A word that occurs twice in a document keeps its last value.
+     */
+    private static void initMiniBatch(@Nonnull final String[][] miniBatch,
+            @Nonnull final List<Map<String, Float>> docs) {
+        docs.clear();
+
+        final FeatureValue probe = new FeatureValue();
+
+        // parse document
+        for (final String[] e : miniBatch) {
+            if (e == null || e.length == 0) {
+                continue;
+            }
+
+            final Map<String, Float> doc = new HashMap<String, Float>();
+
+            // parse features
+            for (String fv : e) {
+                if (fv == null) {
+                    continue;
+                }
+                FeatureValue.parseFeatureAsString(fv, probe);
+                String word = probe.getFeatureAsString();
+                float value = probe.getValueAsFloat();
+                doc.put(word, Float.valueOf(value));
+            }
+
+            docs.add(doc);
+        }
+    }
+
+    /**
+     * Randomly initializes P(z|d) and P(z|d,w) for the current mini-batch, registers unseen
+     * words in P(w|z), and re-normalizes P(w|z).
+     */
+    private void initParams() {
+        final List<float[]> p_dz = new ArrayList<float[]>();
+        final List<Map<String, float[]>> p_dwz = new ArrayList<Map<String, float[]>>();
+
+        for (int d = 0; d < _miniBatchSize; d++) {
+            // init P(z|d)
+            float[] p_dz_d = l1normalize(newRandomFloatArray(_K, _rnd));
+            p_dz.add(p_dz_d);
+
+            final Map<String, float[]> p_dwz_d = new HashMap<String, float[]>();
+            p_dwz.add(p_dwz_d);
+
+            for (final String w : _miniBatchDocs.get(d).keySet()) {
+                // init P(z|d,w)
+                float[] p_dwz_dw = l1normalize(newRandomFloatArray(_K, _rnd));
+                p_dwz_d.put(w, p_dwz_dw);
+
+                // insert new labels to P(w|z)
+                if (!_p_zw.containsKey(w)) {
+                    _p_zw.put(w, newRandomFloatArray(_K, _rnd));
+                }
+            }
+        }
+
+        normalizePzw();
+
+        this._p_dz = p_dz;
+        this._p_dwz = p_dwz;
+    }
+
+    /**
+     * Divides every P(w|z) entry by the per-topic total over all known words, to ensure
+     * \sum_w P(w|z) = 1.
+     */
+    private void normalizePzw() {
+        final double[] sums = new double[_K];
+        for (float[] p_zw_w : _p_zw.values()) {
+            MathUtils.add(p_zw_w, sums, _K); // presumably accumulates p_zw_w[0.._K) into sums
+        }
+        for (float[] p_zw_w : _p_zw.values()) {
+            for (int z = 0; z < _K; z++) {
+                p_zw_w[z] /= sums[z];
+            }
+        }
+    }
+
+    /** E step for document {@code d}: recomputes P(z|d,w) for every word of the document. */
+    private void eStep(@Nonnegative final int d) {
+        final Map<String, float[]> p_dwz_d = _p_dwz.get(d);
+        final float[] p_dz_d = _p_dz.get(d);
+
+        // update P(z|d,w) = P(z|d) * P(w|z)
+        for (final String w : _miniBatchDocs.get(d).keySet()) {
+            final float[] p_dwz_dw = p_dwz_d.get(w);
+            final float[] p_zw_w = _p_zw.get(w);
+            for (int z = 0; z < _K; z++) {
+                p_dwz_dw[z] = p_dz_d[z] * p_zw_w[z];
+            }
+            l1normalize(p_dwz_dw);
+        }
+    }
+
+    /**
+     * M step for document {@code d}: recomputes P(z|d) in place and folds the document into
+     * P(w|z).
+     */
+    private void mStep(@Nonnegative final int d) {
+        final Map<String, Float> doc = _miniBatchDocs.get(d);
+        final Map<String, float[]> p_dwz_d = _p_dwz.get(d);
+
+        // update P(z|d) = n(d,w) * P(z|d,w)
+        final float[] p_dz_d = _p_dz.get(d);
+        Arrays.fill(p_dz_d, 0.f); // zero-fill w/ keeping pointer to _p_dz.get(d)
+        for (Map.Entry<String, Float> e : doc.entrySet()) {
+            final float[] p_dwz_dw = p_dwz_d.get(e.getKey());
+            final float n = e.getValue().floatValue();
+            for (int z = 0; z < _K; z++) {
+                p_dz_d[z] += n * p_dwz_dw[z];
+            }
+        }
+        l1normalize(p_dz_d);
+
+        // update P(w|z) = n(d,w) * P(z|d,w) + alpha * P(w|z)^(n-1)
+        // (only words of this document change; the others are just re-normalized below)
+        for (Map.Entry<String, Float> e : doc.entrySet()) {
+            final float n = e.getValue().floatValue();
+            final float[] p_dwz_dw = p_dwz_d.get(e.getKey());
+            final float[] p_zw_w = _p_zw.get(e.getKey());
+
+            for (int z = 0; z < _K; z++) {
+                p_zw_w[z] = n * p_dwz_dw[z] + _alpha * p_zw_w[z];
+            }
+        }
+
+        normalizePzw();
+    }
+
+    /**
+     * @return true once the mean absolute difference between the two P(z|d) vectors is below
+     *         delta, or when that difference is NaN (iterating further could never satisfy the
+     *         threshold, so the caller must not keep looping)
+     */
+    private boolean isPdzConverged(@Nonnull final float[] pPrev_dz_d,
+            @Nonnull final float[] p_dz_d) {
+        double diff = 0.d;
+        for (int z = 0; z < _K; z++) {
+            diff += Math.abs(pPrev_dz_d[z] - p_dz_d[z]);
+        }
+        final double meanDiff = diff / _K;
+        return Double.isNaN(meanDiff) || meanDiff < _delta;
+    }
+
+    /**
+     * Computes the perplexity of the current mini-batch under the current parameters.
+     *
+     * @return exp(-sum_{d,w} n(d,w) * log P(w|d) / sum_{d,w} n(d,w)); NaN when the mini-batch
+     *         holds no words
+     */
+    public float computePerplexity() {
+        double numer = 0.d;
+        double denom = 0.d;
+
+        for (int d = 0; d < _miniBatchSize; d++) {
+            final float[] p_dz_d = _p_dz.get(d);
+            for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) {
+                String w = e.getKey();
+                float w_value = e.getValue().floatValue();
+
+                final float[] p_zw_w = _p_zw.get(w);
+                double p_dw = 0.d;
+                for (int z = 0; z < _K; z++) {
+                    p_dw += (double) p_zw_w[z] * p_dz_d[z];
+                }
+
+                numer += w_value * Math.log(p_dw);
+                denom += w_value;
+            }
+        }
+
+        return (float) Math.exp(-1.d * (numer / denom));
+    }
+
+    /**
+     * Groups all known words by their P(w|z) for topic {@code z}.
+     *
+     * @param z topic index
+     * @return map from probability to the words having exactly that probability, iterated from
+     *         the highest probability to the lowest
+     */
+    @Nonnull
+    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
+        final SortedMap<Float, List<String>> res = new TreeMap<Float, List<String>>(
+            Collections.reverseOrder());
+
+        for (Map.Entry<String, float[]> e : _p_zw.entrySet()) {
+            final String w = e.getKey();
+            final Float prob = Float.valueOf(e.getValue()[z]);
+
+            List<String> words = res.get(prob);
+            if (words == null) {
+                words = new ArrayList<String>();
+                res.put(prob, words);
+            }
+            words.add(w);
+        }
+
+        return res;
+    }
+
+    /**
+     * Trains on {@code doc} as a single-document mini-batch and returns its P(z|d).
+     *
+     * <p>
+     * Note that this updates P(w|z) as a side effect (it calls {@link #train(String[][])}), and
+     * that the returned array is the model's internal one, not a copy.
+     */
+    @Nonnull
+    public float[] getTopicDistribution(@Nonnull final String[] doc) {
+        train(new String[][] {doc});
+        return _p_dz.get(0);
+    }
+
+    /**
+     * @return P(w|z) for word {@code w} and topic {@code z}
+     * @throws NullPointerException if {@code w} is not registered in the model
+     */
+    public float getProbability(@Nonnull final String w, @Nonnegative final int z) {
+        return _p_zw.get(w)[z];
+    }
+
+    /**
+     * Sets P(w|z) for word {@code w} and topic {@code z}, registering the word with random
+     * values for the other topics if it is unknown, and then re-normalizes P(w|z) over all
+     * words; the stored value is therefore {@code prob} scaled by the per-topic total.
+     */
+    public void setProbability(@Nonnull final String w, @Nonnegative final int z, final float prob) {
+        float[] prob_label = _p_zw.get(w);
+        if (prob_label == null) {
+            prob_label = newRandomFloatArray(_K, _rnd);
+            _p_zw.put(w, prob_label);
+        }
+        prob_label[z] = prob;
+
+        normalizePzw();
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index 811af2e..a4076b6 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -112,7 +112,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
         private PrimitiveObjectInspector lambdaOI;
 
         // Hyperparameters
-        private int topic;
+        private int topics;
         private float alpha;
         private double delta;
 
@@ -134,7 +134,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
 
         protected Options getOptions() {
             Options opts = new Options();
-            opts.addOption("k", "topic", true, "The number of topics [required]");
+            opts.addOption("k", "topics", true, "The number of topics [required]");
             opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
             opts.addOption("delta", true,
                 "Check convergence in the expectation step [default: 1E-5]");
@@ -176,19 +176,19 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
             CommandLine cl = null;
 
             if (argOIs.length != 5) {
-                throw new UDFArgumentException("At least 1 option `-topic` MUST be specified");
+                throw new UDFArgumentException("At least 1 option `-topics` MUST be specified");
             }
 
             String rawArgs = HiveUtils.getConstString(argOIs[4]);
             cl = parseOptions(rawArgs);
 
-            this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 0);
-            if (topic < 1) {
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 0);
+            if (topics < 1) {
                 throw new UDFArgumentException(
-                    "A positive integer MUST be set to an option `-topic`: " + topic);
+                    "A positive integer MUST be set to an option `-topics`: " + topics);
             }
 
-            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
             this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d);
 
             return cl;
@@ -211,7 +211,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 this.internalMergeOI = soi;
                 this.wcListField = soi.getStructFieldRef("wcList");
                 this.lambdaMapField = soi.getStructFieldRef("lambdaMap");
-                this.topicOptionField = soi.getStructFieldRef("topic");
+                this.topicOptionField = soi.getStructFieldRef("topics");
                 this.alphaOptionField = soi.getStructFieldRef("alpha");
                 this.deltaOptionField = soi.getStructFieldRef("delta");
                 this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
@@ -253,7 +253,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                 ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
 
-            fieldNames.add("topic");
+            fieldNames.add("topics");
             fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
 
             fieldNames.add("alpha");
@@ -278,7 +278,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 throws HiveException {
             OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
             myAggr.reset();
-            myAggr.setOptions(topic, alpha, delta);
+            myAggr.setOptions(topics, alpha, delta);
         }
 
         @Override
@@ -359,7 +359,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
 
             // restore options from partial result
             Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField);
-            this.topic = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
+            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
 
             Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField);
             this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);
@@ -368,7 +368,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
             this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj);
 
             OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
-            myAggr.setOptions(topic, alpha, delta);
+            myAggr.setOptions(topics, alpha, delta);
             myAggr.merge(wcList, lambdaMap);
         }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 9aa15e2..1e28a30 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -63,7 +63,7 @@ public class LDAUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(LDAUDTF.class);
 
     // Options
-    protected int topic;
+    protected int topics;
     protected float alpha;
     protected float eta;
     protected long numDocs;
@@ -93,9 +93,9 @@ public class LDAUDTF extends UDTFWithOptions {
     protected ByteBuffer inputBuf;
 
     public LDAUDTF() {
-        this.topic = 10;
-        this.alpha = 1.f / topic;
-        this.eta = 1.f / topic;
+        this.topics = 10;
+        this.alpha = 1.f / topics;
+        this.eta = 1.f / topics;
         this.numDocs = -1L;
         this.tau0 = 64.d;
         this.kappa = 0.7;
@@ -108,7 +108,7 @@ public class LDAUDTF extends UDTFWithOptions {
     @Override
     protected Options getOptions() {
         Options opts = new Options();
-        opts.addOption("k", "topic", true, "The number of topics [default: 10]");
+        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
         opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
         opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]");
         opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]");
@@ -131,9 +131,9 @@ public class LDAUDTF extends UDTFWithOptions {
         if (argOIs.length >= 2) {
             String rawArgs = HiveUtils.getConstString(argOIs[1]);
             cl = parseOptions(rawArgs);
-            this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 10);
-            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
-            this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topic);
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 10);
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
+            this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topics);
             this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L);
             this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d);
             if (tau0 <= 0.d) {
@@ -187,7 +187,7 @@ public class LDAUDTF extends UDTFWithOptions {
     }
 
     protected void initModel() {
-        this.model = new OnlineLDAModel(topic, alpha, eta, numDocs, tau0, kappa, delta);
+        this.model = new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta);
     }
 
     @Override
@@ -527,7 +527,7 @@ public class LDAUDTF extends UDTFWithOptions {
         forwardObjs[1] = word;
         forwardObjs[2] = score;
 
-        for (int k = 0; k < topic; k++) {
+        for (int k = 0; k < topics; k++) {
             topicIdx.set(k);
 
             final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
@@ -541,7 +541,7 @@ public class LDAUDTF extends UDTFWithOptions {
             }
         }
 
-        logger.info("Forwarded topic words each of " + topic + " topics");
+        logger.info("Forwarded topic words each of " + topics + " topics");
     }
 
     /*

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
new file mode 100644
index 0000000..c0b60fc
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.CommandLineUtils;
+import hivemall.utils.lang.Primitives;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+
+@Description(name = "plsa_predict",
+        value = "_FUNC_(string word, float value, int label, float prob[, const string options])"
+                + " - Returns a list which consists of <int label, float prob>")
+public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
+
+    /**
+     * Checks the number and types of the arguments and hands back a fresh evaluator.
+     *
+     * @param typeInfo argument types: (string word, number value, integer label, number prob
+     *        [, string options])
+     * @throws UDFArgumentLengthException if neither 4 nor 5 arguments are given
+     * @throws UDFArgumentTypeException if an argument has an unexpected type
+     */
+    @Override
+    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
+        if (typeInfo.length != 4 && typeInfo.length != 5) {
+            throw new UDFArgumentLengthException(
+                "Expected argument length is 4 or 5 but given argument length was "
+                        + typeInfo.length);
+        }
+
+        if (!HiveUtils.isStringTypeInfo(typeInfo[0])) {
+            throw new UDFArgumentTypeException(0,
+                "String type is expected for the first argument word: " + typeInfo[0].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Number type is expected for the second argument value: "
+                        + typeInfo[1].getTypeName());
+        }
+        if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) {
+            throw new UDFArgumentTypeException(2,
+                "Integer type is expected for the third argument label: "
+                        + typeInfo[2].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
+            throw new UDFArgumentTypeException(3,
+                "Number type is expected for the fourth argument prob: "
+                        + typeInfo[3].getTypeName());
+        }
+
+        // the optional fifth argument carries the option string, not a probability
+        if (typeInfo.length == 5) {
+            if (!HiveUtils.isStringTypeInfo(typeInfo[4])) {
+                throw new UDFArgumentTypeException(4,
+                    "String type is expected for the fifth argument options: "
+                            + typeInfo[4].getTypeName());
+            }
+        }
+
+        return new Evaluator();
+    }
+
+    public static class Evaluator extends GenericUDAFEvaluator {
+
+        // input OI
+        private PrimitiveObjectInspector wordOI;
+        private PrimitiveObjectInspector valueOI;
+        private PrimitiveObjectInspector labelOI;
+        private PrimitiveObjectInspector probOI;
+
+        // Hyperparameters
+        private int topics;
+        private float alpha;
+        private double delta;
+
+        // merge OI
+        private StructObjectInspector internalMergeOI;
+        private StructField wcListField;
+        private StructField probMapField;
+        private StructField topicOptionField;
+        private StructField alphaOptionField;
+        private StructField deltaOptionField;
+        private PrimitiveObjectInspector wcListElemOI;
+        private StandardListObjectInspector wcListOI;
+        private StandardMapObjectInspector probMapOI;
+        private PrimitiveObjectInspector probMapKeyOI;
+        private StandardListObjectInspector probMapValueOI;
+        private PrimitiveObjectInspector probMapValueElemOI;
+
+        public Evaluator() {}
+
+        /**
+         * Declares the options accepted in the optional fifth argument: the number of topics,
+         * alpha and delta.
+         */
+        protected Options getOptions() {
+            // addOption() returns the Options instance itself, so the declarations can be chained
+            return new Options().addOption("k", "topics", true,
+                "The number of topics [default: 10]")
+                                .addOption("alpha", true,
+                                    "The hyperparameter for P(w|z) update [default: 0.5]")
+                                .addOption("delta", true,
+                                    "Check convergence in the expectation step [default: 1E-5]");
+        }
+
+        /**
+         * Parses a whitespace-separated option string against {@link #getOptions()} plus an
+         * additional {@code -help} flag.
+         *
+         * @param optionValue raw option string, e.g. {@code "-topics 5 -alpha 0.1"}
+         * @return the parsed command line
+         * @throws UDFArgumentException carrying the usage text when {@code -help} is given
+         */
+        @Nonnull
+        protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException {
+            String[] args = optionValue.split("\\s+");
+            Options opts = getOptions();
+            opts.addOption("help", false, "Show function help");
+            CommandLine cl = CommandLineUtils.parseOptions(args, opts);
+
+            if (cl.hasOption("help")) {
+                StringWriter sw = new StringWriter();
+                sw.write('\n');
+                PrintWriter pw = new PrintWriter(sw);
+                HelpFormatter formatter = new HelpFormatter();
+                formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, helpSyntax(), null, opts,
+                    HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true);
+                pw.flush();
+                String helpMsg = sw.toString();
+                throw new UDFArgumentException(helpMsg);
+            }
+
+            return cl;
+        }
+
+        /** Builds the syntax line shown at the top of the {@code -help} output. */
+        @Nonnull
+        private String helpSyntax() {
+            Description funcDesc = getClass().getAnnotation(Description.class);
+            if (funcDesc == null) {
+                // The evaluator itself is not annotated; @Description sits on the enclosing
+                // resolver class, so look there before falling back to the bare class name.
+                funcDesc = PLSAPredictUDAF.class.getAnnotation(Description.class);
+            }
+            if (funcDesc == null) {
+                return getClass().getSimpleName();
+            }
+            return funcDesc.value().replace("_FUNC_", funcDesc.name());
+        }
+
+        /**
+         * Reads the hyperparameters from the optional fifth (constant string) argument.
+         *
+         * @param argOIs object inspectors of the UDAF arguments
+         * @return the parsed command line, or null when no option argument was passed
+         * @throws UDFArgumentException if the options cannot be parsed or {@code -topics} is
+         *         not a positive integer
+         */
+        @Nullable
+        protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+            // Start from the same defaults that apply to an omitted option; otherwise a call
+            // without the option argument would leave topics/alpha/delta at 0.
+            this.topics = PLSAUDTF.DEFAULT_TOPICS;
+            this.alpha = PLSAUDTF.DEFAULT_ALPHA;
+            this.delta = PLSAUDTF.DEFAULT_DELTA;
+
+            if (argOIs.length != 5) {
+                return null;
+            }
+
+            String rawArgs = HiveUtils.getConstString(argOIs[4]);
+            CommandLine cl = parseOptions(rawArgs);
+
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), PLSAUDTF.DEFAULT_TOPICS);
+            if (topics < 1) {
+                throw new UDFArgumentException(
+                    "A positive integer MUST be set to an option `-topics`: " + topics);
+            }
+
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), PLSAUDTF.DEFAULT_ALPHA);
+            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), PLSAUDTF.DEFAULT_DELTA);
+
+            return cl;
+        }
+
+        /**
+         * Sets up the input inspectors for the given aggregation mode and returns the
+         * inspector describing this evaluator's output in that mode.
+         *
+         * <p>
+         * For {@code PARTIAL1}/{@code COMPLETE} the parameters are the original arguments
+         * (word, value, label, prob[, options]). For the other modes {@code parameters[0]}
+         * is read as the struct of a partial aggregation.
+         *
+         * @param mode aggregation mode
+         * @param parameters object inspectors of the inputs for {@code mode}
+         * @return the partial-result struct inspector for {@code PARTIAL1}/{@code PARTIAL2},
+         *         otherwise a list of {@code <label, probability>} structs
+         * @throws HiveException on invalid arguments
+         */
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+            super.init(mode, parameters);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+                // The arity check is only meaningful for the original arguments. The branch
+                // below reads a single partial-result struct, so asserting 4 or 5 parameters
+                // for every mode would fail whenever assertions are enabled.
+                assert (parameters.length == 4 || parameters.length == 5) : parameters.length;
+                processOptions(parameters);
+                this.wordOI = HiveUtils.asStringOI(parameters[0]);
+                this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]);
+                this.labelOI = HiveUtils.asIntegerOI(parameters[2]);
+                this.probOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
+            } else {// from partial aggregation
+                StructObjectInspector soi = (StructObjectInspector) parameters[0];
+                this.internalMergeOI = soi;
+                this.wcListField = soi.getStructFieldRef("wcList");
+                this.probMapField = soi.getStructFieldRef("probMap");
+                this.topicOptionField = soi.getStructFieldRef("topics");
+                this.alphaOptionField = soi.getStructFieldRef("alpha");
+                this.deltaOptionField = soi.getStructFieldRef("delta");
+                this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+                this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(wcListElemOI);
+                this.probMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+                // NOTE(review): a String inspector is used for the float list elements although
+                // internalMergeOI() declares them as Float; presumably intentional so that
+                // deserialized values are converted via their string form -- confirm.
+                this.probMapValueElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+                this.probMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(probMapValueElemOI);
+                this.probMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(probMapKeyOI,
+                    probMapValueOI);
+            }
+
+            // initialize output
+            final ObjectInspector outputOI;
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+                outputOI = internalMergeOI();
+            } else {
+                final ArrayList<String> fieldNames = new ArrayList<String>();
+                final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+                fieldNames.add("label");
+                fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+                fieldNames.add("probability");
+                fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+                outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector(
+                    fieldNames, fieldOIs));
+            }
+            return outputOI;
+        }
+
+        private static StructObjectInspector internalMergeOI() {
+            ArrayList<String> fieldNames = new ArrayList<String>();
+            ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+            fieldNames.add("wcList");
+            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
+
+            fieldNames.add("probMap");
+            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
+
+            fieldNames.add("topics");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+            fieldNames.add("alpha");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+            fieldNames.add("delta");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+        }
+
+        /** Creates a buffer and initializes it through {@link #reset(AggregationBuffer)}. */
+        @SuppressWarnings("deprecation")
+        @Override
+        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+            final AggregationBuffer buffer = new PLSAPredictAggregationBuffer();
+            this.reset(buffer);
+            return buffer;
+        }
+
+        /** Clears the buffer's collected data and re-applies this evaluator's options to it. */
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) agg;
+            buffer.reset();
+            buffer.setOptions(this.topics, this.alpha, this.delta);
+        }
+
+        /**
+         * Feeds one input row (word, value, label, prob) into the buffer. Rows in which any
+         * of the first four arguments is {@code null} are ignored.
+         */
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            final PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) agg;
+
+            for (int i = 0; i < 4; i++) {
+                if (parameters[i] == null) {
+                    return;
+                }
+            }
+
+            final String word = PrimitiveObjectInspectorUtils.getString(parameters[0], wordOI);
+            final float value = PrimitiveObjectInspectorUtils.getFloat(parameters[1], valueOI);
+            final int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], labelOI);
+            final float prob = PrimitiveObjectInspectorUtils.getFloat(parameters[3], probOI);
+
+            buffer.iterate(word, value, label, prob);
+        }
+
+        /**
+         * Packs the buffer into the partial-result struct
+         * (wcList, probMap, topics, alpha, delta).
+         *
+         * @return the struct fields as an {@code Object[]}, or {@code null} when the buffer
+         *         has not collected any word
+         */
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) agg;
+            if (buffer.wcList.isEmpty()) {
+                return null;
+            }
+
+            return new Object[] {buffer.wcList, buffer.probMap, new IntWritable(buffer.topics),
+                    new FloatWritable(buffer.alpha), new DoubleWritable(buffer.delta)};
+        }
+
+        /**
+         * Merges a partial aggregation result into the given buffer: the word-count list and
+         * the per-word probability map are converted to plain Java collections, the options
+         * stored in the partial result are restored, and everything is handed to the buffer.
+         * A {@code null} partial is ignored.
+         */
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
+                throws HiveException {
+            if (partial == null) {
+                return;
+            }
+
+            // word-count entries -> Java Strings
+            final Object rawWcField = internalMergeOI.getStructFieldData(partial, wcListField);
+            final List<?> rawWcList = wcListOI.getList(HiveUtils.castLazyBinaryObject(rawWcField));
+            final int numWc = rawWcList.size();
+            final List<String> wcList = new ArrayList<String>();
+            for (int idx = 0; idx < numWc; idx++) {
+                final Object rawWc = rawWcList.get(idx);
+                wcList.add(PrimitiveObjectInspectorUtils.getString(rawWc, wcListElemOI));
+            }
+
+            // probability map -> Java String keys and lists of Java Floats
+            final Object rawMapField = internalMergeOI.getStructFieldData(partial, probMapField);
+            final Map<?, ?> rawProbMap = probMapOI.getMap(HiveUtils.castLazyBinaryObject(rawMapField));
+            final Map<String, List<Float>> probMap = new HashMap<String, List<Float>>();
+            for (Map.Entry<?, ?> entry : rawProbMap.entrySet()) {
+                final String word = PrimitiveObjectInspectorUtils.getString(entry.getKey(),
+                    probMapKeyOI);
+
+                final List<?> rawProbs = probMapValueOI.getList(HiveUtils.castLazyBinaryObject(entry.getValue()));
+                final int numProbs = rawProbs.size();
+                final List<Float> probs = new ArrayList<Float>();
+                for (int idx = 0; idx < numProbs; idx++) {
+                    probs.add(HiveUtils.getFloat(rawProbs.get(idx), probMapValueElemOI));
+                }
+
+                probMap.put(word, probs);
+            }
+
+            // restore options from partial result
+            final Object rawTopics = internalMergeOI.getStructFieldData(partial, topicOptionField);
+            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(rawTopics);
+
+            final Object rawAlpha = internalMergeOI.getStructFieldData(partial, alphaOptionField);
+            this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(rawAlpha);
+
+            final Object rawDelta = internalMergeOI.getStructFieldData(partial, deltaOptionField);
+            this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(rawDelta);
+
+            final PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) agg;
+            buffer.setOptions(topics, alpha, delta);
+            buffer.merge(wcList, probMap);
+        }
+
+        @Override
+        public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg;
+            float[] topicDistr = myAggr.get();
+
+            SortedMap<Float, Integer> sortedDistr = new TreeMap<Float, Integer>(
+                Collections.reverseOrder());
+            for (int i = 0; i < topicDistr.length; i++) {
+                sortedDistr.put(topicDistr[i], i);
+            }
+
+            List<Object[]> result = new ArrayList<Object[]>();
+            for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) {
+                Object[] struct = new Object[2];
+                struct[0] = new IntWritable(e.getValue().intValue()); // label
+                struct[1] = new FloatWritable(e.getKey().floatValue()); // probability
+                result.add(struct);
+            }
+            return result;
+        }
+
+    }
+
+    /**
+     * Aggregation state for pLSA prediction: the {@code "word:value"} entries of the document
+     * and, per word, one probability slot per topic. A slot holding {@code -1} means that no
+     * probability has been supplied for that (word, topic) pair yet.
+     */
+    public static class PLSAPredictAggregationBuffer extends
+            GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+        // collected "word:value" entries
+        private List<String> wcList;
+        // word -> per-topic probabilities (-1 marks a slot that has not been set)
+        private Map<String, List<Float>> probMap;
+
+        private int topics;
+        private float alpha;
+        private double delta;
+
+        PLSAPredictAggregationBuffer() {
+            super();
+        }
+
+        /** Stores the model options used when the topic distribution is computed. */
+        void setOptions(int topics, float alpha, double delta) {
+            this.topics = topics;
+            this.alpha = alpha;
+            this.delta = delta;
+        }
+
+        /** Drops all collected data; the options are left as they are. */
+        void reset() {
+            this.wcList = new ArrayList<String>();
+            this.probMap = new HashMap<String, List<Float>>();
+        }
+
+        /**
+         * Records one observation: appends {@code word:value} to the word-count list and
+         * stores {@code prob} in the slot of topic {@code label} for {@code word}.
+         */
+        void iterate(@Nonnull final String word, final float value, final int label,
+                final float prob) {
+            wcList.add(word + ":" + value);
+
+            List<Float> slots = probMap.get(word);
+            if (slots == null) {
+                // first time this word shows up: start with every topic slot unset (-1)
+                slots = new ArrayList<Float>(Collections.nCopies(topics, -1.f));
+                probMap.put(word, slots);
+            }
+
+            slots.set(label, Float.valueOf(prob));
+        }
+
+        /**
+         * Folds another partial state into this one. Word-count entries are appended; for the
+         * probabilities, slots set in the other state override the ones held here.
+         */
+        void merge(@Nonnull final List<String> o_wcList,
+                @Nonnull final Map<String, List<Float>> o_probMap) {
+            wcList.addAll(o_wcList);
+
+            for (Map.Entry<String, List<Float>> other : o_probMap.entrySet()) {
+                final String word = other.getKey();
+                final List<Float> otherSlots = other.getValue();
+
+                final List<Float> slots = probMap.get(word);
+                if (slots == null) {
+                    // word not seen on this side yet: adopt the other side's slots as they are
+                    probMap.put(word, otherSlots);
+                    continue;
+                }
+
+                // word known on both sides: copy over every slot the other side has set
+                for (int k = 0; k < topics; k++) {
+                    final float otherProb = otherSlots.get(k).floatValue();
+                    if (otherProb != -1.f) {
+                        slots.set(k, otherProb);
+                    }
+                }
+            }
+        }
+
+        /**
+         * Loads every known probability into a fresh {@link IncrementalPLSAModel} and returns
+         * the topic distribution it computes for the collected word counts.
+         */
+        float[] get() {
+            final IncrementalPLSAModel model = new IncrementalPLSAModel(topics, alpha, delta);
+
+            for (Map.Entry<String, List<Float>> entry : probMap.entrySet()) {
+                final String word = entry.getKey();
+                final List<Float> slots = entry.getValue();
+                for (int k = 0; k < topics; k++) {
+                    final float p = slots.get(k).floatValue();
+                    if (p == -1.f) {
+                        continue; // slot never set
+                    }
+                    model.setProbability(word, k, p);
+                }
+            }
+
+            final String[] doc = wcList.toArray(new String[wcList.size()]);
+            return model.getTopicDistribution(doc);
+        }
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
new file mode 100644
index 0000000..2616133
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -0,0 +1,535 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+@Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])"
+        + " - Returns a relation consists of <int topic, string word, float score>")
+public class PLSAUDTF extends UDTFWithOptions {
+    private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
+    
+    public static final int DEFAULT_TOPICS = 10;
+    public static final float DEFAULT_ALPHA = 0.5f;
+    public static final double DEFAULT_DELTA = 1E-3d;
+    
+    // Options
+    protected int topics;
+    protected float alpha;
+    protected int iterations;
+    protected double delta;
+    protected double eps;
+    protected int miniBatchSize;
+
+    // number of proceeded training samples
+    protected long count;
+
+    protected String[][] miniBatch;
+    protected int miniBatchCount;
+
+    protected transient IncrementalPLSAModel model;
+
+    protected ListObjectInspector wordCountsOI;
+
+    // for iterations
+    protected NioStatefullSegment fileIO;
+    protected ByteBuffer inputBuf;
+
+    /** Creates the UDTF with every option set to its default value. */
+    public PLSAUDTF() {
+        // model hyperparameters
+        topics = DEFAULT_TOPICS;
+        alpha = DEFAULT_ALPHA;
+        delta = DEFAULT_DELTA;
+        // training-loop settings
+        iterations = 10;
+        eps = 1E-1d;
+        miniBatchSize = 128;
+    }
+
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
+        opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]");
+        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
+        opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
+        opts.addOption("eps", "epsilon", true,
+            "Check convergence based on the difference of perplexity [default: 1E-1]");
+        opts.addOption("s", "mini_batch_size", true,
+            "Repeat model updating per mini-batch [default: 128]");
+        return opts;
+    }
+
+    /**
+     * Parses the optional second argument (a constant option string) and stores the
+     * resulting settings on this instance.
+     *
+     * @param argOIs object inspectors of the UDTF arguments
+     * @return the parsed command line, or {@code null} when no option string was given
+     * @throws UDFArgumentException if {@code -topics}, {@code -iterations} or
+     *         {@code -mini_batch_size} is smaller than 1
+     */
+    @Override
+    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+        CommandLine cl = null;
+
+        if (argOIs.length >= 2) {
+            String rawArgs = HiveUtils.getConstString(argOIs[1]);
+            cl = parseOptions(rawArgs);
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
+            if (topics < 1) {
+                throw new UDFArgumentException(
+                    "A positive integer MUST be set to an option `-topics`: " + topics);
+            }
+            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), DEFAULT_ALPHA);
+            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
+            if (iterations < 1) {
+                throw new UDFArgumentException(
+                    "'-iterations' must be greater than or equals to 1: " + iterations);
+            }
+            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA);
+            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
+            this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
+            // the mini-batch array is allocated with this size and filled up to it, so a
+            // non-positive value would only fail later with an obscure array exception
+            if (miniBatchSize < 1) {
+                throw new UDFArgumentException(
+                    "'-mini_batch_size' must be greater than or equals to 1: " + miniBatchSize);
+            }
+        }
+
+        return cl;
+    }
+
+    /**
+     * Validates the arguments, applies the options, resets the training state and returns
+     * the inspector of the output rows: {@code <int topic, string word, float score>}.
+     *
+     * @throws UDFArgumentException if no argument is given or an argument is invalid
+     */
+    @Override
+    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length == 0) {
+            throw new UDFArgumentException(
+                "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
+        }
+
+        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
+        HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
+
+        processOptions(argOIs);
+
+        // start from a clean training state; the model itself is created lazily
+        this.model = null;
+        this.count = 0L;
+        this.miniBatch = new String[miniBatchSize][];
+        this.miniBatchCount = 0;
+
+        final ArrayList<String> columns = new ArrayList<String>();
+        columns.add("topic");
+        columns.add("word");
+        columns.add("score");
+
+        final ArrayList<ObjectInspector> columnOIs = new ArrayList<ObjectInspector>();
+        columnOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+        columnOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+        columnOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        return ObjectInspectorFactory.getStandardStructObjectInspector(columns, columnOIs);
+    }
+
+    /** Creates a fresh model from the current {@code topics}, {@code alpha} and {@code delta}. */
+    protected void initModel() {
+        model = new IncrementalPLSAModel(topics, alpha, delta);
+    }
+
+    /**
+     * Consumes one document given as a list of word-count features. The document is recorded
+     * for later iterations and added to the current mini-batch; the model is trained each
+     * time the mini-batch becomes full.
+     *
+     * @param args {@code args[0]} is the document
+     * @throws HiveException if the document contains a {@code null} element
+     */
+    @Override
+    public void process(Object[] args) throws HiveException {
+        if (model == null) {
+            initModel();
+        }
+
+        final int length = wordCountsOI.getListLength(args[0]);
+        // Skip empty documents. A NULL document is expected to be reported with a negative
+        // length, which previously led to a NegativeArraySizeException on allocation.
+        if (length <= 0) {
+            return;
+        }
+
+        final String[] wordCounts = new String[length];
+        for (int i = 0; i < length; i++) {
+            Object o = wordCountsOI.getListElement(args[0], i);
+            if (o == null) {
+                throw new HiveException("Given feature vector contains invalid elements");
+            }
+            wordCounts[i] = o.toString();
+        }
+
+        count++;
+
+        recordTrainSampleToTempFile(wordCounts);
+
+        miniBatch[miniBatchCount] = wordCounts;
+        miniBatchCount++;
+
+        if (miniBatchCount == miniBatchSize) {
+            model.train(miniBatch);
+            Arrays.fill(miniBatch, null); // clear
+            miniBatchCount = 0;
+        }
+    }
+
+    /**
+     * Appends one document to the in-memory buffer that backs iterative training, spilling
+     * the buffer to a temporary file when the record does not fit. Does nothing when only a
+     * single iteration is requested.
+     *
+     * <p>
+     * Record layout: {@code recordBytes, #entries, (len1, bytes1), (len2, bytes2), ...} where
+     * {@code recordBytes} is the number of bytes that follow the {@code recordBytes} field
+     * itself and each {@code len} is the byte length of the entry that follows it.
+     *
+     * @param wordCounts the document; {@code null} entries are skipped
+     * @throws HiveException if the temporary file cannot be created or written, or if a
+     *         single record is larger than the buffer
+     */
+    protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
+            throws HiveException {
+        if (iterations == 1) {
+            return;
+        }
+
+        ByteBuffer buf = inputBuf;
+        NioStatefullSegment dst = fileIO;
+
+        if (buf == null) {
+            final File file;
+            try {
+                file = File.createTempFile("hivemall_plsa", ".sgmt");
+                file.deleteOnExit();
+                if (!file.canWrite()) {
+                    throw new UDFArgumentException("Cannot write a temporary file: "
+                            + file.getAbsolutePath());
+                }
+                logger.info("Record training samples to a file: " + file.getAbsolutePath());
+            } catch (IOException ioe) {
+                throw new UDFArgumentException(ioe);
+            } catch (Throwable e) {
+                throw new UDFArgumentException(e);
+            }
+            this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
+            this.fileIO = dst = new NioStatefullSegment(file, false);
+        }
+
+        // Encode every entry exactly once so that the length prefix always matches the bytes
+        // actually written (String#length() counts chars, not encoded bytes). The default
+        // charset is kept on purpose: the reading side decodes with `new String(bytes)`.
+        final byte[][] encoded = new byte[wordCounts.length][];
+        int numEntries = 0;
+        int payloadBytes = 0;
+        for (String wc : wordCounts) {
+            if (wc == null) {
+                continue;
+            }
+            final byte[] bytes = wc.getBytes();
+            encoded[numEntries++] = bytes;
+            payloadBytes += bytes.length;
+        }
+
+        // #entries + one length prefix per entry + the encoded entries
+        final int recordBytes = SizeOf.INT + SizeOf.INT * numEntries + payloadBytes;
+        // plus the leading `recordBytes` field itself
+        final int requiredBytes = SizeOf.INT + recordBytes;
+        if (requiredBytes > buf.capacity()) {
+            throw new HiveException("A training sample of " + requiredBytes
+                    + " bytes does not fit in the record buffer of " + buf.capacity() + " bytes");
+        }
+        if (buf.remaining() < requiredBytes) {
+            writeBuffer(buf, dst);
+        }
+
+        buf.putInt(recordBytes);
+        buf.putInt(numEntries); // skipped null entries must not be counted
+        for (int i = 0; i < numEntries; i++) {
+            buf.putInt(encoded[i].length);
+            buf.put(encoded[i]);
+        }
+    }
+
+    /**
+     * Writes the content of {@code srcBuf} to {@code dst} and leaves the buffer cleared.
+     *
+     * @throws HiveException if the underlying write fails
+     */
+    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
+            throws HiveException {
+        srcBuf.flip(); // switch from filling to draining
+        try {
+            dst.write(srcBuf);
+        } catch (IOException ioe) {
+            throw new HiveException("Exception causes while writing a buffer to file", ioe);
+        }
+        srcBuf.clear(); // ready to be filled again
+    }
+
+    /**
+     * Finishes training: trains on the samples left in the mini-batch, runs the remaining
+     * iterations when more than one is requested, forwards the model, and finally drops it.
+     * Nothing is trained or forwarded when no sample has been processed.
+     */
+    @Override
+    public void close() throws HiveException {
+        final boolean hasSamples = (count != 0);
+        if (hasSamples) {
+            if (miniBatchCount > 0) { // update for remaining samples
+                final String[][] rest = Arrays.copyOfRange(miniBatch, 0, miniBatchCount);
+                model.train(rest);
+            }
+            if (iterations > 1) {
+                runIterativeTraining(iterations);
+            }
+            forwardModel();
+        }
+        this.model = null;
+    }
+
+    /**
+     * Runs the 2nd and later training iterations over the recorded samples.
+     *
+     * <p>
+     * When nothing has been spilled to the temporary file, the samples are replayed straight
+     * from the in-memory buffer; otherwise the rest of the buffer is flushed and the samples
+     * are replayed from the file. Each iteration trains per mini-batch and stops early once
+     * the mean perplexity changes by less than {@code eps} between two iterations. The
+     * temporary file is closed and the buffer references are released in any case.
+     *
+     * @param iterations the maximum number of iterations (the first one has already been done
+     *        while the samples were processed)
+     * @throws HiveException if the temporary file cannot be flushed, read or closed
+     */
+    protected final void runIterativeTraining(@Nonnegative final int iterations)
+            throws HiveException {
+        final ByteBuffer buf = this.inputBuf;
+        final NioStatefullSegment dst = this.fileIO;
+        assert (buf != null);
+        assert (dst != null);
+        final long numTrainingExamples = count;
+
+        final Reporter reporter = getReporter();
+        final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
+            "hivemall.plsa.IncrementalPLSA$Counter", "iteration");
+
+        try {
+            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
+                if (buf.position() == 0) {
+                    return; // no training example
+                }
+                buf.flip();
+
+                int iter = 2;
+                float perplexityPrev = Float.MAX_VALUE;
+                float perplexity;
+                int numTrain;
+                for (; iter <= iterations; iter++) {
+                    perplexity = 0.f;
+                    numTrain = 0;
+
+                    reportProgress(reporter);
+                    setCounterValue(iterCounter, iter);
+
+                    Arrays.fill(miniBatch, null); // clear
+                    miniBatchCount = 0;
+
+                    while (buf.remaining() > 0) {
+                        int recordBytes = buf.getInt();
+                        assert (recordBytes > 0) : recordBytes;
+                        int wcLength = buf.getInt();
+                        final String[] wordCounts = new String[wcLength];
+                        for (int j = 0; j < wcLength; j++) {
+                            int len = buf.getInt();
+                            byte[] bytes = new byte[len];
+                            buf.get(bytes);
+                            wordCounts[j] = new String(bytes);
+                        }
+
+                        miniBatch[miniBatchCount] = wordCounts;
+                        miniBatchCount++;
+
+                        if (miniBatchCount == miniBatchSize) {
+                            model.train(miniBatch);
+                            perplexity += model.computePerplexity();
+                            numTrain++;
+
+                            Arrays.fill(miniBatch, null); // clear
+                            miniBatchCount = 0;
+                        }
+                    }
+                    buf.rewind(); // replay the same samples in the next iteration
+
+                    // update for remaining samples
+                    if (miniBatchCount > 0) { // update for remaining samples
+                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+                        perplexity += model.computePerplexity();
+                        numTrain++;
+                    }
+
+                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
+                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break;
+                    }
+                    perplexityPrev = perplexity;
+                }
+                logger.info("Performed "
+                        + Math.min(iter, iterations)
+                        + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on memory (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total) ");
+            } else {// read training examples in the temporary file and invoke train for each example
+
+                // write training examples in buffer to a temporary file
+                // The buffer is still in write mode here, so pending data is indicated by its
+                // position. remaining() is the free space and is zero for an exactly full
+                // buffer, whose content would then never reach the file.
+                if (buf.position() > 0) {
+                    writeBuffer(buf, dst);
+                }
+                try {
+                    dst.flush();
+                } catch (IOException e) {
+                    throw new HiveException("Failed to flush a file: "
+                            + dst.getFile().getAbsolutePath(), e);
+                }
+                if (logger.isInfoEnabled()) {
+                    File tmpFile = dst.getFile();
+                    logger.info("Wrote " + numTrainingExamples
+                            + " records to a temporary file for iterative training: "
+                            + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
+                            + ")");
+                }
+
+                // run iterations
+                int iter = 2;
+                float perplexityPrev = Float.MAX_VALUE;
+                float perplexity;
+                int numTrain;
+                for (; iter <= iterations; iter++) {
+                    perplexity = 0.f;
+                    numTrain = 0;
+
+                    Arrays.fill(miniBatch, null); // clear
+                    miniBatchCount = 0;
+
+                    setCounterValue(iterCounter, iter);
+
+                    buf.clear();
+                    dst.resetPosition();
+                    while (true) {
+                        reportProgress(reporter);
+                        // TODO prefetch
+                        // loads the next chunk of the temporary file into the buffer
+                        final int bytesRead;
+                        try {
+                            bytesRead = dst.read(buf);
+                        } catch (IOException e) {
+                            throw new HiveException("Failed to read a file: "
+                                    + dst.getFile().getAbsolutePath(), e);
+                        }
+                        if (bytesRead == 0) { // reached file EOF
+                            break;
+                        }
+                        assert (bytesRead > 0) : bytesRead;
+
+                        // reads training examples from a buffer
+                        buf.flip();
+                        int remain = buf.remaining();
+                        if (remain < SizeOf.INT) {
+                            throw new HiveException("Illegal file format was detected");
+                        }
+                        while (remain >= SizeOf.INT) {
+                            int pos = buf.position();
+                            int recordBytes = buf.getInt();
+                            remain -= SizeOf.INT;
+                            if (remain < recordBytes) {
+                                // the record is only partially loaded: step back to its start
+                                // so that compact() keeps it for the next read
+                                buf.position(pos);
+                                break;
+                            }
+
+                            int wcLength = buf.getInt();
+                            final String[] wordCounts = new String[wcLength];
+                            for (int j = 0; j < wcLength; j++) {
+                                int len = buf.getInt();
+                                byte[] bytes = new byte[len];
+                                buf.get(bytes);
+                                wordCounts[j] = new String(bytes);
+                            }
+
+                            miniBatch[miniBatchCount] = wordCounts;
+                            miniBatchCount++;
+
+                            if (miniBatchCount == miniBatchSize) {
+                                model.train(miniBatch);
+                                perplexity += model.computePerplexity();
+                                numTrain++;
+
+                                Arrays.fill(miniBatch, null); // clear
+                                miniBatchCount = 0;
+                            }
+
+                            remain -= recordBytes;
+                        }
+                        buf.compact();
+                    }
+
+                    // update for remaining samples
+                    if (miniBatchCount > 0) { // update for remaining samples
+                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+                        perplexity += model.computePerplexity();
+                        numTrain++;
+                    }
+
+                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
+                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break;
+                    }
+                    perplexityPrev = perplexity;
+                }
+                logger.info("Performed "
+                        + Math.min(iter, iterations)
+                        + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on a secondary storage (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total)");
+            }
+        } finally {
+            // delete the temporary file and release resources
+            try {
+                dst.close(true);
+            } catch (IOException e) {
+                throw new HiveException("Failed to close a file: "
+                        + dst.getFile().getAbsolutePath(), e);
+            }
+            this.inputBuf = null;
+            this.fileIO = null;
+        }
+    }
+
+    /**
+     * Emits the trained model as rows of {@code <topic, word, score>}: for every topic, each
+     * word returned by the model together with its score. The same writable objects are
+     * reused for every forwarded row.
+     */
+    protected void forwardModel() throws HiveException {
+        final IntWritable topicCol = new IntWritable();
+        final Text wordCol = new Text();
+        final FloatWritable scoreCol = new FloatWritable();
+        final Object[] row = new Object[] {topicCol, wordCol, scoreCol};
+
+        for (int k = 0; k < topics; k++) {
+            topicCol.set(k);
+
+            final SortedMap<Float, List<String>> wordsByScore = model.getTopicWords(k);
+            for (Map.Entry<Float, List<String>> entry : wordsByScore.entrySet()) {
+                scoreCol.set(entry.getKey());
+                final List<String> sameScoreWords = entry.getValue();
+                final int numWords = sameScoreWords.size();
+                for (int i = 0; i < numWords; i++) {
+                    wordCol.set(sameScoreWords.get(i));
+                    forward(row);
+                }
+            }
+        }
+
+        logger.info("Forwarded topic words each of " + topics + " topics");
+    }
+
+    /*
+     * For testing:
+     */
+
+    /** Test hook: delegates to the model's {@code getProbability(label, k)}. */
+    @VisibleForTesting
+    double getProbability(String label, int k) {
+        return model.getProbability(label, k);
+    }
+
+    /** Test hook: delegates to the model's {@code getTopicWords(k)}. */
+    @VisibleForTesting
+    SortedMap<Float, List<String>> getTopicWords(int k) {
+        return model.getTopicWords(k);
+    }
+
+    /** Test hook: delegates to the model's {@code getTopicDistribution(doc)}. */
+    @VisibleForTesting
+    float[] getTopicDistribution(@Nonnull String[] doc) {
+        return model.getTopicDistribution(doc);
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index c20c363..4177d70 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -735,4 +735,14 @@ public final class ArrayUtils {
         return ret;
     }
 
+    /**
+     * Creates a float array of the requested length and fills it, from index 0 upward,
+     * with successive {@code rnd.nextFloat()} values.
+     *
+     * @param size length of the array to create
+     * @param rnd source of the random values
+     * @return a newly allocated array of {@code size} random floats
+     */
+    @Nonnull
+    public static float[] newRandomFloatArray(@Nonnegative final int size,
+            @Nonnull final Random rnd) {
+        final float[] values = new float[size];
+        int idx = 0;
+        while (idx < values.length) {
+            values[idx] = rnd.nextFloat();
+            idx++;
+        }
+        return values;
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 061b75d..8ffb89c 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -408,4 +408,16 @@ public final class MathUtils {
         return Math.log(logsumexp) + max;
     }
 
+    /**
+     * Scales the given array in place so that the absolute values of its elements sum to 1.
+     *
+     * When the absolute sum is zero (an empty or all-zero array) the array is returned
+     * untouched; dividing by the zero sum would otherwise turn every element into NaN.
+     *
+     * @param arr array to normalize; its contents are overwritten
+     * @return the same array instance that was passed in
+     */
+    @Nonnull
+    public static float[] l1normalize(@Nonnull final float[] arr) {
+        double sum = 0.d;
+        for (int i = 0; i < arr.length; i++) {
+            sum += Math.abs(arr[i]);
+        }
+        if (sum == 0.d) {
+            // nothing to scale: every element is zero, and 0/0 would yield NaN
+            return arr;
+        }
+        for (int i = 0; i < arr.length; i++) {
+            arr[i] /= sum;
+        }
+        return arr;
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
new file mode 100644
index 0000000..db34a38
--- /dev/null
+++ b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Arrays;
+import java.util.StringTokenizer;
+import java.util.zip.GZIPInputStream;
+
+import hivemall.classifier.KernelExpansionPassiveAggressiveUDTFTest;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.annotation.Nonnull;
+
+public class IncrementalPLSAModelTest {
+    private static final boolean DEBUG = false;
+
+    /**
+     * Trains a two-topic model on two small documents, one document per train() call,
+     * until the mean perplexity of an iteration differs from the previous one by less than
+     * 1E-4 or maxIter iterations have run. Then checks that each document's dominant topic
+     * ranks one of its own words above a word of the other document.
+     */
+    @Test
+    public void testOnline() {
+        final int K = 2;
+        final int maxIter = 1024;
+
+        final String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+        final String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
+                "oranges:1"};
+
+        final IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d);
+
+        int it = 0;
+        float perplexityPrev;
+        float perplexity = Float.MAX_VALUE;
+        do {
+            perplexityPrev = perplexity;
+            perplexity = 0.f;
+
+            // online (i.e., one-by-one) updating
+            model.train(new String[][] {doc1});
+            perplexity += model.computePerplexity();
+
+            model.train(new String[][] {doc2});
+            perplexity += model.computePerplexity();
+
+            perplexity /= 2.f; // mean perplexity for the 2 docs
+
+            it++;
+            println("Iteration " + it + ": mean perplexity = " + perplexity);
+        } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-4f);
+
+        // dump the words of every topic (only printed when DEBUG is enabled)
+        for (int k = 0; k < K; k++) {
+            println("Topic " + k + ":");
+            println("========");
+            final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
+            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+                for (String w : e.getValue()) {
+                    println(e.getKey() + " " + w);
+                }
+            }
+            println("========");
+        }
+
+        // the topic with the larger weight for doc#1 MUST represent doc#1;
+        // the remaining topic is attributed to doc#2
+        final float[] topicDistr = model.getTopicDistribution(doc1);
+        final int k1 = (topicDistr[0] > topicDistr[1]) ? 0 : 1;
+        final int k2 = 1 - k1;
+
+        // NOTE(review): the percentage shown in the doc2 message below is read from doc1's
+        // distribution (topicDistr), not from a distribution computed for doc2.
+        Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+                + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
+            model.getProbability("vegetables", k1) > model.getProbability("flu", k1));
+        Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+                + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
+            model.getProbability("avocados", k2) > model.getProbability("healthy", k2));
+    }
+
+    /**
+     * Trains a two-topic model on two small documents passed together as one mini-batch,
+     * until the perplexity differs from the previous iteration by less than 1E-4 or maxIter
+     * iterations have run. Then checks that each document's dominant topic ranks one of its
+     * own words above a word of the other document.
+     */
+    @Test
+    public void testMiniBatch() {
+        final int K = 2;
+        final int maxIter = 2048;
+
+        final String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
+        final String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
+                "oranges:1"};
+
+        final IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d);
+
+        int it = 0;
+        float perplexityPrev;
+        float perplexity = Float.MAX_VALUE;
+        do {
+            perplexityPrev = perplexity;
+
+            // both documents are fed in a single train() call
+            model.train(new String[][] {doc1, doc2});
+            perplexity = model.computePerplexity();
+
+            it++;
+            println("Iteration " + it + ": perplexity = " + perplexity);
+        } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-4f);
+
+        // dump the words of every topic (only printed when DEBUG is enabled)
+        for (int k = 0; k < K; k++) {
+            println("Topic " + k + ":");
+            println("========");
+            final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
+            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+                for (String w : e.getValue()) {
+                    println(e.getKey() + " " + w);
+                }
+            }
+            println("========");
+        }
+
+        // the topic with the larger weight for doc#1 MUST represent doc#1;
+        // the remaining topic is attributed to doc#2
+        final float[] topicDistr = model.getTopicDistribution(doc1);
+        final int k1 = (topicDistr[0] > topicDistr[1]) ? 0 : 1;
+        final int k2 = 1 - k1;
+
+        // NOTE(review): the percentage shown in the doc2 message below is read from doc1's
+        // distribution (topicDistr), not from a distribution computed for doc2.
+        Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+                + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
+            model.getProbability("vegetables", k1) > model.getProbability("flu", k1));
+        Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+                + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
+            model.getProbability("avocados", k2) > model.getProbability("healthy", k2));
+    }
+
+    /**
+     * Trains a 20-topic model on the first document of each of the 20 classes found in
+     * the news20 data file, in mini-batches, until the mean perplexity differs from the
+     * previous iteration by less than 1E-3 or maxIter iterations have run. Then requires
+     * that the 20 documents are spread over at least 15 distinct dominant topics.
+     *
+     * @throws IOException if the data file cannot be opened or read
+     */
+    @Test
+    public void testNews20() throws IOException {
+        int K = 20;
+        int miniBatchSize = 2;
+
+        int cnt, it;
+        int maxIter = 64;
+
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.8f, 1E-5d);
+
+        String[][] docs = new String[K][];
+
+        cnt = 0;
+        BufferedReader news20 = readFile("news20-multiclass.gz");
+        try {
+            List<String> doc = new ArrayList<String>();
+            String line;
+            // stop reading as soon as one document per class has been collected
+            while (cnt < K && (line = news20.readLine()) != null) {
+                doc.clear();
+
+                StringTokenizer tokens = new StringTokenizer(line, " ");
+
+                // first token is the 1-based class label
+                int k = Integer.parseInt(tokens.nextToken()) - 1;
+
+                while (tokens.hasMoreTokens()) {
+                    doc.add(tokens.nextToken());
+                }
+
+                // store first document in each of K classes
+                if (docs[k] == null) {
+                    docs[k] = doc.toArray(new String[doc.size()]);
+                    cnt++;
+                }
+            }
+        } finally {
+            // the reader was previously left open for the lifetime of the JVM
+            news20.close();
+        }
+        // fail with a clear message instead of passing null documents to train()
+        Assert.assertEquals("Data file SHOULD contain a document for each of the classes", K, cnt);
+        println("Stored " + cnt + " docs. Start training w/ mini-batch size: " + miniBatchSize);
+
+        float perplexityPrev;
+        float perplexity = Float.MAX_VALUE;
+
+        it = 0;
+        do {
+            perplexityPrev = perplexity;
+            perplexity = 0.f;
+
+            int head = 0;
+            cnt = 0;
+            while (head < K) {
+                // clamp so the last batch is never padded with null documents
+                int tail = Math.min(head + miniBatchSize, K);
+                model.train(Arrays.copyOfRange(docs, head, tail));
+                perplexity += model.computePerplexity();
+                head = tail;
+                cnt++;
+                println("Processed mini-batch#" + cnt);
+            }
+
+            perplexity /= cnt;
+
+            it++;
+            println("Iteration " + it + ": mean perplexity = " + perplexity);
+        } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-3f);
+
+        Set<Integer> topics = new HashSet<Integer>();
+        for (int k = 0; k < K; k++) {
+            topics.add(findMaxTopic(model.getTopicDistribution(docs[k])));
+        }
+
+        int n = topics.size();
+        println("# of unique topics: " + n);
+        Assert.assertTrue("At least 15 documents SHOULD be classified to different topics, "
+                + "but there are only " + n + " unique topics.", n >= 15);
+    }
+
+    /** Writes the message to standard output, but only when DEBUG is switched on. */
+    private static void println(String msg) {
+        if (!DEBUG) {
+            return;
+        }
+        System.out.println(msg);
+    }
+
+    /**
+     * Opens a test resource as a buffered UTF-8 text reader. The resource is resolved
+     * relative to KernelExpansionPassiveAggressiveUDTFTest, and is transparently
+     * decompressed when its name ends with ".gz".
+     *
+     * @param fileName resource name
+     * @return a reader over the (decompressed) resource; the caller is responsible for closing it
+     * @throws IOException if the resource does not exist or cannot be opened
+     */
+    @Nonnull
+    private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+        // use data stored for KPA UDTF test
+        InputStream is = KernelExpansionPassiveAggressiveUDTFTest.class.getResourceAsStream(fileName);
+        if (is == null) {
+            // getResourceAsStream signals a missing resource with null, not an exception
+            throw new IOException("Test resource not found: " + fileName);
+        }
+        if (fileName.endsWith(".gz")) {
+            is = new GZIPInputStream(is);
+        }
+        // explicit charset so that parsing does not depend on the platform default
+        return new BufferedReader(new InputStreamReader(is, "UTF-8"));
+    }
+
+    /**
+     * Returns the index of the largest value in the given topic distribution.
+     * Ties are resolved in favour of the lowest index, and an empty array yields 0.
+     *
+     * @param topicDistr per-topic weights
+     * @return index of the first maximal element
+     */
+    private static int findMaxTopic(@Nonnull float[] topicDistr) {
+        int maxIdx = 0;
+        for (int i = 1; i < topicDistr.length; i++) {
+            if (topicDistr[maxIdx] < topicDistr[i]) {
+                maxIdx = i;
+            }
+        }
+        return maxIdx;
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
index a23d917..2c08560 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -100,7 +100,7 @@ public class LDAPredictUDAFTest {
                 PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
                         PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
                 ObjectInspectorUtils.getConstantObjectInspector(
-                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2")};
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
 
         evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
 
@@ -117,7 +117,7 @@ public class LDAPredictUDAFTest {
                 ObjectInspectorFactory.getStandardListObjectInspector(
                         PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
 
-        fieldNames.add("topic");
+        fieldNames.add("topics");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
 
         fieldNames.add("alpha");

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
index d1e3f81..a5881d4 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -42,7 +42,7 @@ public class LDAUDTFTest {
         ObjectInspector[] argOIs = new ObjectInspector[] {
             ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
             ObjectInspectorUtils.getConstantObjectInspector(
-                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2 -num_docs 2 -s 1")};
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2 -num_docs 2 -s 1")};
 
         udtf.initialize(argOIs);