You are viewing a plain text version of this content. (The hyperlink to the canonical version was lost in the plain-text conversion.)
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/05/09 08:05:36 UTC

incubator-hivemall git commit: Close #76: [HIVEMALL-74-2][HIVEMALL-91-2] Revise topic model UDFs

Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 211c28036 -> e27307898


Close #76: [HIVEMALL-74-2][HIVEMALL-91-2] Revise topic model UDFs


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e2730789
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e2730789
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e2730789

Branch: refs/heads/master
Commit: e273078982261d7e9eb5cd93cebe1ca8a3c1c5e9
Parents: 211c280
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Tue May 9 17:05:14 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Tue May 9 17:05:14 2017 +0900

----------------------------------------------------------------------
 .../topicmodel/IncrementalPLSAModel.java        |  21 ++-
 .../hivemall/topicmodel/LDAPredictUDAF.java     |  82 +++++-----
 .../main/java/hivemall/topicmodel/LDAUDTF.java  |  13 +-
 .../hivemall/topicmodel/PLSAPredictUDAF.java    |  36 +++--
 .../main/java/hivemall/topicmodel/PLSAUDTF.java |  11 +-
 .../java/hivemall/utils/lang/ArrayUtils.java    |   6 +-
 .../java/hivemall/utils/math/MathUtils.java     |   8 +-
 .../topicmodel/IncrementalPLSAModelTest.java    |   6 +-
 .../hivemall/topicmodel/LDAPredictUDAFTest.java | 151 +++++++++++--------
 .../topicmodel/PLSAPredictUDAFTest.java         | 129 ++++++++++------
 docs/gitbook/clustering/lda.md                  |  10 +-
 docs/gitbook/clustering/plsa.md                 |  16 +-
 12 files changed, 304 insertions(+), 185 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index 745e510..a75febb 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -20,6 +20,8 @@ package hivemall.topicmodel;
 
 import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
 import static hivemall.utils.math.MathUtils.l1normalize;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
 import hivemall.model.FeatureValue;
 import hivemall.utils.math.MathUtils;
 
@@ -29,7 +31,6 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Random;
 import java.util.SortedMap;
 import java.util.TreeMap;
 
@@ -54,7 +55,7 @@ public final class IncrementalPLSAModel {
 
     // random number generator
     @Nonnull
-    private final Random _rnd;
+    private final PRNG _rnd;
 
     // optimized in the E step
     private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair
@@ -73,7 +74,7 @@ public final class IncrementalPLSAModel {
         this._alpha = alpha;
         this._delta = delta;
 
-        this._rnd = new Random(1001);
+        this._rnd = RandomNumberGeneratorFactory.createPRNG(1001);
 
         this._p_zw = new HashMap<String, float[]>();
 
@@ -92,7 +93,9 @@ public final class IncrementalPLSAModel {
         for (int d = 0; d < _miniBatchSize; d++) {
             do {
                 pPrev_dz.clear();
-                pPrev_dz.addAll(_p_dz);
+                for (float[] p_dz_d : _p_dz) { // deep copy
+                    pPrev_dz.add(p_dz_d.clone());
+                }
 
                 // Expectation
                 eStep(d);
@@ -216,6 +219,10 @@ public final class IncrementalPLSAModel {
                 for (int z = 0; z < _K; z++) {
                     p_zw_w[z] = n * p_dwz_dw[z] + _alpha * p_zw_w[z];
                 }
+            } else { // others
+                for (int z = 0; z < _K; z++) {
+                    p_zw_w[z] = _alpha * p_zw_w[z];
+                }
             }
 
             MathUtils.add(p_zw_w, sums, _K);
@@ -223,7 +230,7 @@ public final class IncrementalPLSAModel {
         // normalize to ensure \sum_w P(w|z) = 1
         for (float[] p_zw_w : _p_zw.values()) {
             for (int z = 0; z < _K; z++) {
-                p_zw_w[z] /= sums[z];
+                p_zw_w[z] = (float) (p_zw_w[z] / sums[z]);
             }
         }
     }
@@ -256,6 +263,10 @@ public final class IncrementalPLSAModel {
                     p_dw += (double) p_zw_w[z] * p_dz_d[z];
                 }
 
+                if (p_dw == 0.d) {
+                    throw new IllegalStateException("Perplexity would be Infinity. "
+                            + "Try different mini-batch size `-s`, larger `-delta` and/or larger `-alpha`.");
+                }
                 numer += w_value * Math.log(p_dw);
                 denom += w_value;
             }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index a4076b6..8d1edd8 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -120,7 +120,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
         private StructObjectInspector internalMergeOI;
         private StructField wcListField;
         private StructField lambdaMapField;
-        private StructField topicOptionField;
+        private StructField topicsOptionField;
         private StructField alphaOptionField;
         private StructField deltaOptionField;
         private PrimitiveObjectInspector wcListElemOI;
@@ -134,7 +134,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
 
         protected Options getOptions() {
             Options opts = new Options();
-            opts.addOption("k", "topics", true, "The number of topics [required]");
+            opts.addOption("k", "topics", true, "The number of topics [default: 10]");
             opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
             opts.addOption("delta", true,
                 "Check convergence in the expectation step [default: 1E-5]");
@@ -175,22 +175,24 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
         protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
             CommandLine cl = null;
 
-            if (argOIs.length != 5) {
-                throw new UDFArgumentException("At least 1 option `-topics` MUST be specified");
-            }
+            if (argOIs.length >= 5) {
+                String rawArgs = HiveUtils.getConstString(argOIs[4]);
+                cl = parseOptions(rawArgs);
 
-            String rawArgs = HiveUtils.getConstString(argOIs[4]);
-            cl = parseOptions(rawArgs);
+                this.topics = Primitives.parseInt(cl.getOptionValue("topics"), LDAUDTF.DEFAULT_TOPICS);
+                if (topics < 1) {
+                    throw new UDFArgumentException(
+                            "A positive integer MUST be set to an option `-topics`: " + topics);
+                }
 
-            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 0);
-            if (topics < 1) {
-                throw new UDFArgumentException(
-                    "A positive integer MUST be set to an option `-topics`: " + topics);
+                this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
+                this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), LDAUDTF.DEFAULT_DELTA);
+            } else {
+                this.topics = LDAUDTF.DEFAULT_TOPICS;
+                this.alpha = 1.f / topics;
+                this.delta = LDAUDTF.DEFAULT_DELTA;
             }
 
-            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
-            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d);
-
             return cl;
         }
 
@@ -211,7 +213,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 this.internalMergeOI = soi;
                 this.wcListField = soi.getStructFieldRef("wcList");
                 this.lambdaMapField = soi.getStructFieldRef("lambdaMap");
-                this.topicOptionField = soi.getStructFieldRef("topics");
+                this.topicsOptionField = soi.getStructFieldRef("topics");
                 this.alphaOptionField = soi.getStructFieldRef("alpha");
                 this.deltaOptionField = soi.getStructFieldRef("delta");
                 this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
@@ -310,7 +312,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
             Object[] partialResult = new Object[5];
             partialResult[0] = myAggr.wcList;
             partialResult[1] = myAggr.lambdaMap;
-            partialResult[2] = new IntWritable(myAggr.topic);
+            partialResult[2] = new IntWritable(myAggr.topics);
             partialResult[3] = new FloatWritable(myAggr.alpha);
             partialResult[4] = new DoubleWritable(myAggr.delta);
 
@@ -358,8 +360,8 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
             }
 
             // restore options from partial result
-            Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField);
-            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
+            Object topicsObj = internalMergeOI.getStructFieldData(partial, topicsOptionField);
+            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicsObj);
 
             Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField);
             this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);
@@ -402,7 +404,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
         private List<String> wcList;
         private Map<String, List<Float>> lambdaMap;
 
-        private int topic;
+        private int topics;
         private float alpha;
         private double delta;
 
@@ -410,8 +412,8 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
             super();
         }
 
-        void setOptions(int topic, float alpha, double delta) {
-            this.topic = topic;
+        void setOptions(int topics, float alpha, double delta) {
+            this.topics = topics;
             this.alpha = alpha;
             this.delta = delta;
         }
@@ -424,17 +426,16 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
         void iterate(String word, float value, int label, float lambda) {
             wcList.add(word + ":" + value);
 
+            List<Float> lambda_word = lambdaMap.get(word);
+
             // for an unforeseen word, initialize its lambdas w/ -1s
-            if (!lambdaMap.containsKey(word)) {
-                List<Float> lambdaEmpty_word = new ArrayList<Float>(
-                    Collections.nCopies(topic, -1.f));
-                lambdaMap.put(word, lambdaEmpty_word);
+            if (lambda_word == null) {
+                lambda_word = new ArrayList<Float>(Collections.nCopies(topics, -1.f));
+                lambdaMap.put(word, lambda_word);
             }
 
             // set the given lambda value
-            List<Float> lambda_word = lambdaMap.get(word);
             lambda_word.set(label, lambda);
-            lambdaMap.put(word, lambda_word);
         }
 
         void merge(List<String> o_wcList, Map<String, List<Float>> o_lambdaMap) {
@@ -444,13 +445,14 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 String o_word = e.getKey();
                 List<Float> o_lambda_word = e.getValue();
 
-                if (!lambdaMap.containsKey(o_word)) { // for an unforeseen word
+                final List<Float> lambda_word = lambdaMap.get(o_word);
+                if (lambda_word == null) { // for an unforeseen word
                     lambdaMap.put(o_word, o_lambda_word);
                 } else { // for a partially observed word
-                    List<Float> lambda_word = lambdaMap.get(o_word);
-                    for (int k = 0; k < topic; k++) {
-                        if (o_lambda_word.get(k) != -1.f) { // not default value
-                            lambda_word.set(k, o_lambda_word.get(k)); // set the partial lambda value
+                    for (int k = 0; k < topics; k++) {
+                        final float lambda_k = o_lambda_word.get(k).floatValue();
+                        if (lambda_k != -1.f) { // not default value
+                            lambda_word.set(k, lambda_k); // set the partial lambda value
                         }
                     }
                     lambdaMap.put(o_word, lambda_word);
@@ -459,12 +461,16 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
         }
 
         float[] get() {
-            OnlineLDAModel model = new OnlineLDAModel(topic, alpha, delta);
-
-            for (String word : lambdaMap.keySet()) {
-                List<Float> lambda_word = lambdaMap.get(word);
-                for (int k = 0; k < topic; k++) {
-                    model.setLambda(word, k, lambda_word.get(k));
+            OnlineLDAModel model = new OnlineLDAModel(topics, alpha, delta);
+
+            for (Map.Entry<String, List<Float>> e : lambdaMap.entrySet()) {
+                final String word = e.getKey();
+                final List<Float> lambda_word = e.getValue();
+                for (int k = 0; k < topics; k++) {
+                    final float lambda_k = lambda_word.get(k).floatValue();
+                    if (lambda_k != -1.f) {
+                        model.setLambda(word, k, lambda_k);
+                    }
                 }
             }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 1e28a30..1cec875 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -62,6 +62,9 @@ import org.apache.hadoop.mapred.Reporter;
 public class LDAUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(LDAUDTF.class);
 
+    public static final int DEFAULT_TOPICS = 10;
+    public static final double DEFAULT_DELTA = 1E-3d;
+
     // Options
     protected int topics;
     protected float alpha;
@@ -93,14 +96,14 @@ public class LDAUDTF extends UDTFWithOptions {
     protected ByteBuffer inputBuf;
 
     public LDAUDTF() {
-        this.topics = 10;
+        this.topics = DEFAULT_TOPICS;
         this.alpha = 1.f / topics;
         this.eta = 1.f / topics;
         this.numDocs = -1L;
         this.tau0 = 64.d;
         this.kappa = 0.7;
         this.iterations = 10;
-        this.delta = 1E-3d;
+        this.delta = DEFAULT_DELTA;
         this.eps = 1E-1d;
         this.miniBatchSize = 128; // if 1, truly online setting
     }
@@ -131,7 +134,7 @@ public class LDAUDTF extends UDTFWithOptions {
         if (argOIs.length >= 2) {
             String rawArgs = HiveUtils.getConstString(argOIs[1]);
             cl = parseOptions(rawArgs);
-            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 10);
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
             this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
             this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topics);
             this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L);
@@ -148,7 +151,7 @@ public class LDAUDTF extends UDTFWithOptions {
                 throw new UDFArgumentException(
                     "'-iterations' must be greater than or equals to 1: " + iterations);
             }
-            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-3d);
+            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA);
             this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
             this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
         }
@@ -504,6 +507,8 @@ public class LDAUDTF extends UDTFWithOptions {
                         + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
                         + " training updates in total)");
             }
+        } catch (Throwable e) {
+            throw new HiveException("Exception caused in the iterative training", e);
         } finally {
             // delete the temporary file and release resources
             try {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
index c0b60fc..08febb4 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -120,7 +120,7 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
         private StructObjectInspector internalMergeOI;
         private StructField wcListField;
         private StructField probMapField;
-        private StructField topicOptionField;
+        private StructField topicsOptionField;
         private StructField alphaOptionField;
         private StructField deltaOptionField;
         private PrimitiveObjectInspector wcListElemOI;
@@ -174,21 +174,25 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
 
         @Nullable
         protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
-            if (argOIs.length != 5) {
-                return null;
-            }
+            CommandLine cl = null;
 
-            String rawArgs = HiveUtils.getConstString(argOIs[4]);
-            CommandLine cl = parseOptions(rawArgs);
+            if (argOIs.length >= 5) {
+                String rawArgs = HiveUtils.getConstString(argOIs[4]);
+                cl = parseOptions(rawArgs);
 
-            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), PLSAUDTF.DEFAULT_TOPICS);
-            if (topics < 1) {
-                throw new UDFArgumentException(
-                    "A positive integer MUST be set to an option `-topics`: " + topics);
-            }
+                this.topics = Primitives.parseInt(cl.getOptionValue("topics"), PLSAUDTF.DEFAULT_TOPICS);
+                if (topics < 1) {
+                    throw new UDFArgumentException(
+                            "A positive integer MUST be set to an option `-topics`: " + topics);
+                }
 
-            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), PLSAUDTF.DEFAULT_ALPHA);
-            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), PLSAUDTF.DEFAULT_DELTA);
+                this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), PLSAUDTF.DEFAULT_ALPHA);
+                this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), PLSAUDTF.DEFAULT_DELTA);
+            } else {
+                this.topics = PLSAUDTF.DEFAULT_TOPICS;
+                this.alpha = PLSAUDTF.DEFAULT_ALPHA;
+                this.delta = PLSAUDTF.DEFAULT_DELTA;
+            }
 
             return cl;
         }
@@ -210,7 +214,7 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
                 this.internalMergeOI = soi;
                 this.wcListField = soi.getStructFieldRef("wcList");
                 this.probMapField = soi.getStructFieldRef("probMap");
-                this.topicOptionField = soi.getStructFieldRef("topics");
+                this.topicsOptionField = soi.getStructFieldRef("topics");
                 this.alphaOptionField = soi.getStructFieldRef("alpha");
                 this.deltaOptionField = soi.getStructFieldRef("delta");
                 this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
@@ -356,8 +360,8 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
             }
 
             // restore options from partial result
-            Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField);
-            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj);
+            Object topicsObj = internalMergeOI.getStructFieldData(partial, topicsOptionField);
+            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicsObj);
 
             Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField);
             this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index 2616133..014356e 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -46,7 +46,10 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.io.FloatWritable;
 import org.apache.hadoop.io.IntWritable;
@@ -58,11 +61,11 @@ import org.apache.hadoop.mapred.Reporter;
         + " - Returns a relation consists of <int topic, string word, float score>")
 public class PLSAUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
-    
+
     public static final int DEFAULT_TOPICS = 10;
     public static final float DEFAULT_ALPHA = 0.5f;
     public static final double DEFAULT_DELTA = 1E-3d;
-    
+
     // Options
     protected int topics;
     protected float alpha;
@@ -474,6 +477,8 @@ public class PLSAUDTF extends UDTFWithOptions {
                         + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
                         + " training updates in total)");
             }
+        } catch (Throwable e) {
+            throw new HiveException("Exception caused in the iterative training", e);
         } finally {
             // delete the temporary file and release resources
             try {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 4177d70..540f1c6 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -18,6 +18,8 @@
  */
 package hivemall.utils.lang;
 
+import hivemall.math.random.PRNG;
+
 import java.lang.reflect.Array;
 import java.util.Arrays;
 import java.util.List;
@@ -737,10 +739,10 @@ public final class ArrayUtils {
 
     @Nonnull
     public static float[] newRandomFloatArray(@Nonnegative final int size,
-            @Nonnull final Random rnd) {
+            @Nonnull final PRNG rnd) {
         final float[] ret = new float[size];
         for (int i = 0; i < size; i++) {
-            ret[i] = rnd.nextFloat();
+            ret[i] = (float) rnd.nextDouble();
         }
         return ret;
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 8ffb89c..56c4f89 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -411,10 +411,14 @@ public final class MathUtils {
     @Nonnull
     public static float[] l1normalize(@Nonnull final float[] arr) {
         double sum = 0.d;
-        for (int i = 0; i < arr.length; i++) {
+        int size = arr.length;
+        for (int i = 0; i < size; i++) {
             sum += Math.abs(arr[i]);
         }
-        for (int i = 0; i < arr.length; i++) {
+        if (sum == 0.d) {
+            return new float[size];
+        }
+        for (int i = 0; i < size; i++) {
             arr[i] /= sum;
         }
         return arr;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
index db34a38..79be3a7 100644
--- a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
+++ b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
@@ -50,7 +50,7 @@ public class IncrementalPLSAModelTest {
         float perplexityPrev;
         float perplexity = Float.MAX_VALUE;
 
-        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d);
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.5f, 1E-5d);
 
         String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
         String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
@@ -124,7 +124,7 @@ public class IncrementalPLSAModelTest {
         float perplexityPrev;
         float perplexity = Float.MAX_VALUE;
 
-        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d);
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.5f, 1E-5d);
 
         String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"};
         String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2",
@@ -191,7 +191,7 @@ public class IncrementalPLSAModelTest {
         int cnt, it;
         int maxIter = 64;
 
-        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.8f, 1E-5d);
+        IncrementalPLSAModel model = new IncrementalPLSAModel(K, 100.f, 1E-3d);
 
         BufferedReader news20 = readFile("news20-multiclass.gz");
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
index 2c08560..e09e57e 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -18,7 +18,8 @@
  */
 package hivemall.topicmodel;
 
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import hivemall.utils.math.MathUtils;
+
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -46,64 +47,8 @@ public class LDAPredictUDAFTest {
     int[] labels;
     float[] lambdas;
 
-    @Test(expected=UDFArgumentException.class)
-    public void testWithoutOption() throws Exception {
-        udaf = new LDAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
-
-        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
-        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
-    }
-
-    @Test(expected=UDFArgumentException.class)
-    public void testWithoutTopicOption() throws Exception {
-        udaf = new LDAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                ObjectInspectorUtils.getConstantObjectInspector(
-                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-alpha 0.1")};
-
-        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
-        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
-    }
-
     @Before
     public void setUp() throws Exception {
-        udaf = new LDAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.INT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                ObjectInspectorUtils.getConstantObjectInspector(
-                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
-
-        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
         ArrayList<String> fieldNames = new ArrayList<String>();
         ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
 
@@ -129,8 +74,6 @@ public class LDAPredictUDAFTest {
         partialOI = new ObjectInspector[4];
         partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
 
-        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
-
         words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges", "like", "avocados", "colds",
             "colds", "avocados", "oranges", "like", "apples", "flu", "healthy", "vegetables", "fruits"};
         labels = new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1};
@@ -140,6 +83,24 @@ public class LDAPredictUDAFTest {
 
     @Test
     public void test() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
         final Map<String, Float> doc1 = new HashMap<String, Float>();
         doc1.put("fruits", 1.f);
         doc1.put("healthy", 1.f);
@@ -176,6 +137,24 @@ public class LDAPredictUDAFTest {
 
     @Test
     public void testMerge() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
         final Map<String, Float> doc = new HashMap<String, Float>();
         doc.put("apples", 1.f);
         doc.put("avocados", 1.f);
@@ -225,4 +204,58 @@ public class LDAPredictUDAFTest {
             Assert.assertTrue(distr[0] < distr[1]);
         }
     }
+
+    @Test
+    public void testUnmatchedTopics() throws Exception {
+        udaf = new LDAPredictUDAF();
+        // the model below only has two topics (labels 0 and 1), but no option string is passed, so prediction uses LDAUDTF.DEFAULT_TOPICS
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
+        final Map<String, Float> doc1 = new HashMap<String, Float>();
+        doc1.put("fruits", 1.f);
+        doc1.put("healthy", 1.f);
+        doc1.put("vegetables", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc1.get(word), labels[i], lambdas[i]});
+        }
+        float[] doc1Distr = agg.get();
+
+        final Map<String, Float> doc2 = new HashMap<String, Float>();
+        doc2.put("apples", 1.f);
+        doc2.put("avocados", 1.f);
+        doc2.put("colds", 1.f);
+        doc2.put("flu", 1.f);
+        doc2.put("like", 2.f);
+        doc2.put("oranges", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc2.get(word), labels[i], lambdas[i]});
+        }
+        float[] doc2Distr = agg.get();
+        // each distribution must be sized by DEFAULT_TOPICS (not by the model's label count) and still sum to 1
+        Assert.assertEquals(LDAUDTF.DEFAULT_TOPICS, doc1Distr.length);
+        Assert.assertEquals(1.d, MathUtils.sum(doc1Distr), 1E-5d);
+
+        Assert.assertEquals(LDAUDTF.DEFAULT_TOPICS, doc2Distr.length);
+        Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d);
+    }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
index 456dd1d..2be48e1 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
@@ -18,7 +18,8 @@
  */
 package hivemall.topicmodel;
 
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import hivemall.utils.math.MathUtils;
+
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -46,52 +47,8 @@ public class PLSAPredictUDAFTest {
     int[] labels;
     float[] probs;
 
-    @Test(expected = UDFArgumentException.class)
-    public void testWithoutOption() throws Exception {
-        udaf = new PLSAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
-
-        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
-        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
-    }
-
-    @Test(expected = UDFArgumentException.class)
-    public void testWithoutTopicOption() throws Exception {
-        udaf = new PLSAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                ObjectInspectorUtils.getConstantObjectInspector(
-                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-alpha 0.1")};
-
-        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
-        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
-    }
-
     @Before
     public void setUp() throws Exception {
-        udaf = new PLSAPredictUDAF();
-
-        inputOIs = new ObjectInspector[] {
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
-                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
-                ObjectInspectorUtils.getConstantObjectInspector(
-                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
-
-        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
-
         ArrayList<String> fieldNames = new ArrayList<String>();
         ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
 
@@ -115,8 +72,6 @@ public class PLSAPredictUDAFTest {
         partialOI = new ObjectInspector[4];
         partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
 
-        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
-
         words = new String[] {"fruits", "vegetables", "healthy", "flu", "apples", "oranges",
                 "like", "avocados", "colds", "colds", "avocados", "oranges", "like", "apples",
                 "flu", "healthy", "vegetables", "fruits"};
@@ -129,6 +84,20 @@ public class PLSAPredictUDAFTest {
 
     @Test
     public void test() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
         final Map<String, Float> doc1 = new HashMap<String, Float>();
         doc1.put("fruits", 1.f);
         doc1.put("healthy", 1.f);
@@ -165,6 +134,20 @@ public class PLSAPredictUDAFTest {
 
     @Test
     public void testMerge() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
         final Map<String, Float> doc = new HashMap<String, Float>();
         doc.put("apples", 1.f);
         doc.put("avocados", 1.f);
@@ -214,4 +197,56 @@ public class PLSAPredictUDAFTest {
             Assert.assertTrue(distr[0] < distr[1]);
         }
     }
+
+    @Test
+    public void testUnmatchedTopics() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        // pre-defined topic model only has two topics (labels 0 and 1), but prediction is launched without options, i.e., -topics=10 (default value)
+        inputOIs = new ObjectInspector[] {
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
+                PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT)};
+
+        evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
+        final Map<String, Float> doc1 = new HashMap<String, Float>();
+        doc1.put("fruits", 1.f);
+        doc1.put("healthy", 1.f);
+        doc1.put("vegetables", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc1.get(word), labels[i], probs[i]});
+        }
+        float[] doc1Distr = agg.get();
+
+        final Map<String, Float> doc2 = new HashMap<String, Float>();
+        doc2.put("apples", 1.f);
+        doc2.put("avocados", 1.f);
+        doc2.put("colds", 1.f);
+        doc2.put("flu", 1.f);
+        doc2.put("like", 2.f);
+        doc2.put("oranges", 1.f);
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, doc2.get(word), labels[i], probs[i]});
+        }
+        float[] doc2Distr = agg.get();
+        // each distribution must be sized by DEFAULT_TOPICS (not by the model's label count) and still sum to 1
+        Assert.assertEquals(PLSAUDTF.DEFAULT_TOPICS, doc1Distr.length);
+        Assert.assertEquals(1.d, MathUtils.sum(doc1Distr), 1E-5d);
+
+        Assert.assertEquals(PLSAUDTF.DEFAULT_TOPICS, doc2Distr.length);
+        Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d);
+    }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/docs/gitbook/clustering/lda.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/clustering/lda.md b/docs/gitbook/clustering/lda.md
index cc477da..8b8e5f5 100644
--- a/docs/gitbook/clustering/lda.md
+++ b/docs/gitbook/clustering/lda.md
@@ -82,7 +82,7 @@ with word_counts as (
     docid, word
 )
 select
-  train_lda(feature, "-topic 2 -iter 20") as (label, word, lambda)
+  train_lda(feature, "-topics 2 -iter 20") as (label, word, lambda)
 from (
   select docid, collect_set(word_count) as feature
   from word_counts
@@ -92,7 +92,7 @@ from (
 ;
 ```
 
-Here, an option `-topic 2` specifies the number of topics we assume in the set of documents.
+Here, the option `-topics 2` specifies the number of topics assumed in the set of documents.
 
 Notice that `order by docid` ensures building a LDA model precisely in a single node. In case that you like to launch `train_lda` in parallel, following query hopefully returns similar (but might be slightly approximated) result:
 
@@ -104,7 +104,7 @@ select
   label, word, avg(lambda) as lambda
 from (
   select
-    train_lda(feature, "-topic 2 -iter 20") as (label, word, lambda)
+    train_lda(feature, "-topics 2 -iter 20") as (label, word, lambda)
   from (
     select docid, collect_set(f) as feature
     from word_counts
@@ -163,7 +163,7 @@ with test as (
 )
 select
   t.docid,
-  lda_predict(t.word, t.value, m.label, m.lambda, "-topic 2") as probabilities
+  lda_predict(t.word, t.value, m.label, m.lambda, "-topics 2") as probabilities
 from
   test t
   JOIN lda_model m ON (t.word = m.word)
@@ -177,7 +177,7 @@ group by
 |1  | [{"label":0,"probability":0.875},{"label":1,"probability":0.125}]|
 |2  | [{"label":1,"probability":0.9375},{"label":0,"probability":0.0625}]|
 
-Importantly, an option `-topic` should be set to the same value as you set for training.
+Importantly, the `-topics` option is expected to be set to the same value that you used for training.
 
 Since the probabilities are sorted in descending order, a label of the most promising topic is easily obtained as:
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e2730789/docs/gitbook/clustering/plsa.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/clustering/plsa.md b/docs/gitbook/clustering/plsa.md
index 456dfe7..7cd3a9d 100644
--- a/docs/gitbook/clustering/plsa.md
+++ b/docs/gitbook/clustering/plsa.md
@@ -151,4 +151,18 @@ This value controls **how much iterative model update is affected by the old res
 
 From an algorithmic point of view, training pLSA (and LDA) iteratively repeats certain operations and updates the target value (i.e., probability obtained as a result of `train_plsa()`). This iterative procedure gradually makes the probabilities more accurate. What `alpha` does is to control the degree of the change of probabilities in each step.
 
-Normally, `alpha` is set to a small value from 0.0 to 0.5 (default is 0.5).
\ No newline at end of file
+Importantly, pLSA is likely to overfit a single mini-batch. As a result, $$P(w|z)$$ can take degenerate values (e.g., $$P(w|z) = 0$$), and `train_plsa()` sometimes fails with an exception like:
+
+```
+Perplexity would be Infinity. Try different mini-batch size `-s`, larger `-delta` and/or larger `-alpha`.
+```
+
+In that case, try different hyper-parameters to avoid overfitting, as the exception message suggests.
+
+For instance, the [20 Newsgroups dataset](http://qwone.com/~jason/20Newsgroups/), which consists of 10906 real-world documents, empirically requires the following options:
+
+```sql
+SELECT train_plsa(features, "-topics 20 -iter 10 -s 128 -delta 0.01 -alpha 512 -eps 0.1")
+```
+
+Clearly, `alpha` here is much larger than the `0.01` used for the dummy data above. Keep in mind that an appropriate value of `alpha` highly depends on the number of documents and the mini-batch size.
\ No newline at end of file