You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/06/27 06:44:59 UTC

[1/3] incubator-hivemall git commit: Close #89: [HIVEMALL-120] Refactor on LDA/pLSA's mini-batch & buffered iteration logic

Repository: incubator-hivemall
Updated Branches:
  refs/heads/master bfc5b75b0 -> 9f01ebf20


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
index addacbc..e5045a5 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
@@ -54,7 +54,7 @@ public class PLSAUDTFTest {
         udtf.process(new Object[] {Arrays.asList(doc1)});
         udtf.process(new Object[] {Arrays.asList(doc2)});
 
-        udtf.closeWithoutModelReset();
+        udtf.finalizeTraining();
 
         SortedMap<Float, List<String>> topicWords;
 
@@ -93,10 +93,10 @@ public class PLSAUDTFTest {
 
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
                 + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
-            udtf.getProbability("vegetables", k1) > udtf.getProbability("flu", k1));
+            udtf.getWordScore("vegetables", k1) > udtf.getWordScore("flu", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
                 + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
-            udtf.getProbability("avocados", k2) > udtf.getProbability("healthy", k2));
+            udtf.getWordScore("avocados", k2) > udtf.getWordScore("healthy", k2));
     }
 
     @Test
@@ -107,7 +107,7 @@ public class PLSAUDTFTest {
                 ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                 ObjectInspectorUtils.getConstantObjectInspector(
                     PrimitiveObjectInspectorFactory.javaStringObjectInspector,
-                    "-topics 2 -alpha 0.1 -delta 0.00001 -iter 10000")};
+                    "-topics 2 -alpha 0.1 -delta 0.00001 -iter 10000 -mini_batch_size 1")};
 
         udtf.initialize(argOIs);
 
@@ -117,7 +117,7 @@ public class PLSAUDTFTest {
         udtf.process(new Object[] {Arrays.asList(doc1)});
         udtf.process(new Object[] {Arrays.asList(doc2)});
 
-        udtf.closeWithoutModelReset();
+        udtf.finalizeTraining();
 
         SortedMap<Float, List<String>> topicWords;
 
@@ -156,10 +156,10 @@ public class PLSAUDTFTest {
 
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
                 + "and `野菜` SHOULD be more suitable topic word than `インフルエンザ` in the topic",
-            udtf.getProbability("野菜", k1) > udtf.getProbability("インフルエンザ", k1));
+            udtf.getWordScore("野菜", k1) > udtf.getWordScore("インフルエンザ", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
                 + "and `アボカド` SHOULD be more suitable topic word than `健康` in the topic",
-            udtf.getProbability("アボカド", k2) > udtf.getProbability("健康", k2));
+            udtf.getWordScore("アボカド", k2) > udtf.getWordScore("健康", k2));
     }
 
     private static void println(String msg) {


[3/3] incubator-hivemall git commit: Applied refactoring for topicmodel module

Posted by my...@apache.org.
Applied refactoring for topicmodel module

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9f01ebf2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9f01ebf2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9f01ebf2

Branch: refs/heads/master
Commit: 9f01ebf20c74559be8a50d459103118a51c229bf
Parents: 0495ffa
Author: Makoto Yui <my...@apache.org>
Authored: Tue Jun 27 15:44:31 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Tue Jun 27 15:44:31 2017 +0900

----------------------------------------------------------------------
 .../AbstractProbabilisticTopicModel.java        | 26 +++++++++++++-------
 .../topicmodel/IncrementalPLSAModel.java        | 16 ++++++------
 .../main/java/hivemall/topicmodel/LDAUDTF.java  |  5 +---
 .../hivemall/topicmodel/OnlineLDAModel.java     | 18 +++++++-------
 .../main/java/hivemall/topicmodel/PLSAUDTF.java |  5 +---
 .../ProbabilisticTopicModelBaseUDTF.java        | 25 ++++++++++++-------
 6 files changed, 52 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
index 3c097e2..1b7f3e8 100644
--- a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
+++ b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
@@ -21,9 +21,14 @@ package hivemall.topicmodel;
 import hivemall.annotations.VisibleForTesting;
 import hivemall.model.FeatureValue;
 
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
-import java.util.*;
 
 public abstract class AbstractProbabilisticTopicModel {
 
@@ -31,6 +36,7 @@ public abstract class AbstractProbabilisticTopicModel {
     protected final int _K;
 
     // total number of documents
+    @Nonnegative
     protected long _D;
 
     // for mini-batch
@@ -38,7 +44,7 @@ public abstract class AbstractProbabilisticTopicModel {
     protected final List<Map<String, Float>> _miniBatchDocs;
     protected int _miniBatchSize;
 
-    public AbstractProbabilisticTopicModel(int K) {
+    public AbstractProbabilisticTopicModel(@Nonnegative int K) {
         this._K = K;
         this._D = 0L;
         this._miniBatchDocs = new ArrayList<Map<String, Float>>();
@@ -73,26 +79,28 @@ public abstract class AbstractProbabilisticTopicModel {
         }
     }
 
-    public void accumulateDocCount() {
+    protected void accumulateDocCount() {
         this._D += 1;
     }
 
-    public long getDocCount() {
+    @Nonnegative
+    protected long getDocCount() {
         return _D;
     }
 
-    public abstract void train(@Nonnull final String[][] miniBatch);
+    protected abstract void train(@Nonnull final String[][] miniBatch);
 
-    public abstract float computePerplexity();
+    protected abstract float computePerplexity();
 
     @Nonnull
-    public abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
+    protected abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
 
     @Nonnull
-    public abstract float[] getTopicDistribution(@Nonnull final String[] doc);
+    protected abstract float[] getTopicDistribution(@Nonnull final String[] doc);
 
     @VisibleForTesting
     abstract float getWordScore(@Nonnull final String word, @Nonnegative final int topic);
 
-    public abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, final float score);
+    protected abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic,
+            final float score);
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index b99e670..6419664 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -20,7 +20,6 @@ package hivemall.topicmodel;
 
 import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
 import static hivemall.utils.math.MathUtils.l1normalize;
-
 import hivemall.annotations.VisibleForTesting;
 import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
@@ -59,9 +58,10 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
     private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair
 
     // optimized in the M step
-    @Nonnull
     private List<float[]> _p_dz; // P(z|d) probability of topics for documents
-    private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
+
+    @Nonnull
+    private final Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
 
     public IncrementalPLSAModel(int K, float alpha, double delta) {
         super(K);
@@ -74,7 +74,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
         this._p_zw = new HashMap<String, float[]>();
     }
 
-    public void train(@Nonnull final String[][] miniBatch) {
+    protected void train(@Nonnull final String[][] miniBatch) {
         initMiniBatch(miniBatch, _miniBatchDocs);
 
         this._miniBatchSize = _miniBatchDocs.size();
@@ -211,7 +211,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
         return (diff / _K) < _delta;
     }
 
-    public float computePerplexity() {
+    protected float computePerplexity() {
         double numer = 0.d;
         double denom = 0.d;
 
@@ -241,7 +241,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
     }
 
     @Nonnull
-    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
+    protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
         final SortedMap<Float, List<String>> res = new TreeMap<Float, List<String>>(
             Collections.reverseOrder());
 
@@ -261,7 +261,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
     }
 
     @Nonnull
-    public float[] getTopicDistribution(@Nonnull final String[] doc) {
+    protected float[] getTopicDistribution(@Nonnull final String[] doc) {
         train(new String[][] {doc});
         return _p_dz.get(0);
     }
@@ -271,7 +271,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
         return _p_zw.get(w)[z];
     }
 
-    public void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
+    protected void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
         float[] prob_label = _p_zw.get(w);
         if (prob_label == null) {
             prob_label = newRandomFloatArray(_K, _rnd);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 41386a4..9bac908 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 
 @Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])"
         + " - Returns a relation consists of <int topic, string word, float score>")
-public class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
-    private static final Log logger = LogFactory.getLog(LDAUDTF.class);
+public final class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
 
     public static final double DEFAULT_DELTA = 1E-3d;
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 4a7531c..6a8d6db 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -38,6 +38,9 @@ import org.apache.commons.math3.special.Gamma;
 
 public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
 
+    private static final double SHAPE = 100.d;
+    private static final double SCALE = 1.d / SHAPE;
+
     // ---------------------------------
     // HyperParameters
 
@@ -72,7 +75,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     private final boolean _isAutoD;
 
     // parameters
-    @Nonnull
     private List<Map<String, float[]>> _phi;
     private float[][] _gamma;
     @Nonnull
@@ -81,8 +83,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     // random number generator
     @Nonnull
     private final GammaDistribution _gd;
-    private static final double SHAPE = 100.d;
-    private static final double SCALE = 1.d / SHAPE;
 
     // for computing perplexity
     private float _docRatio = 1.f;
@@ -121,7 +121,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     }
 
     @Override
-    public void accumulateDocCount() {
+    protected void accumulateDocCount() {
         /*
          * In a truly online setting, total number of documents equals to the number of documents that have ever seen.
          * In that case, users need to manually set the current max number of documents via this method.
@@ -133,7 +133,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
         }
     }
 
-    public void train(@Nonnull final String[][] miniBatch) {
+    protected void train(@Nonnull final String[][] miniBatch) {
         preprocessMiniBatch(miniBatch);
 
         initParams(true);
@@ -341,7 +341,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     /**
      * Calculate approximate perplexity for the current mini-batch.
      */
-    public float computePerplexity() {
+    protected float computePerplexity() {
         double bound = computeApproxBound();
         double perWordBound = bound / (_docRatio * _valueSum);
         return (float) Math.exp(-1.d * perWordBound);
@@ -449,7 +449,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
         return lambda_label[k];
     }
 
-    public void setWordScore(@Nonnull final String label, @Nonnegative final int k,
+    protected void setWordScore(@Nonnull final String label, @Nonnegative final int k,
             final float lambda_k) {
         float[] lambda_label = _lambda.get(label);
         if (lambda_label == null) {
@@ -460,7 +460,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     }
 
     @Nonnull
-    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
+    protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
         return getTopicWords(k, _lambda.keySet().size());
     }
 
@@ -501,7 +501,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     }
 
     @Nonnull
-    public float[] getTopicDistribution(@Nonnull final String[] doc) {
+    protected float[] getTopicDistribution(@Nonnull final String[] doc) {
         preprocessMiniBatch(new String[][] {doc});
 
         initParams(false);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index e1d8797..9c5a0ea 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 
 @Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])"
         + " - Returns a relation consists of <int topic, string word, float score>")
-public class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
-    private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
+public final class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
 
     public static final float DEFAULT_ALPHA = 0.5f;
     public static final double DEFAULT_DELTA = 1E-3d;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
index cff076e..c3dab89 100644
--- a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
@@ -27,6 +27,19 @@ import hivemall.utils.io.NioStatefullSegment;
 import hivemall.utils.lang.NumberUtils;
 import hivemall.utils.lang.Primitives;
 import hivemall.utils.lang.SizeOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
 import org.apache.commons.logging.Log;
@@ -44,13 +57,6 @@ import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.Counters;
 import org.apache.hadoop.mapred.Reporter;
 
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.*;
-
 public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);
 
@@ -143,6 +149,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
         return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
     }
 
+    @Nonnull
     protected abstract AbstractProbabilisticTopicModel createModel();
 
     @Override
@@ -157,7 +164,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
         for (int i = 0; i < length; i++) {
             Object o = wordCountsOI.getListElement(args[0], i);
             if (o == null) {
-                throw new HiveException("Given feature vector contains invalid elements");
+                throw new HiveException("Given feature vector contains invalid null elements");
             }
             String s = o.toString();
             wordCounts[j] = s;
@@ -167,7 +174,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
             return;
         }
 
-        model.accumulateDocCount();;
+        model.accumulateDocCount();
 
         update(wordCounts);
 


[2/3] incubator-hivemall git commit: Close #89: [HIVEMALL-120] Refactor on LDA/pLSA's mini-batch & buffered iteration logic

Posted by my...@apache.org.
Close #89: [HIVEMALL-120] Refactor on LDA/pLSA's mini-batch & buffered iteration logic


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/0495ffad
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/0495ffad
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/0495ffad

Branch: refs/heads/master
Commit: 0495ffadbc42bffa36cb583622708ae1fa65a44e
Parents: bfc5b75
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Tue Jun 27 13:53:45 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Tue Jun 27 13:53:45 2017 +0900

----------------------------------------------------------------------
 .../AbstractProbabilisticTopicModel.java        |  98 ++++
 .../topicmodel/IncrementalPLSAModel.java        |  51 +-
 .../hivemall/topicmodel/LDAPredictUDAF.java     |   2 +-
 .../main/java/hivemall/topicmodel/LDAUDTF.java  | 503 +------------------
 .../hivemall/topicmodel/OnlineLDAModel.java     |  82 +--
 .../hivemall/topicmodel/PLSAPredictUDAF.java    |   2 +-
 .../main/java/hivemall/topicmodel/PLSAUDTF.java | 490 +-----------------
 .../ProbabilisticTopicModelBaseUDTF.java        | 487 ++++++++++++++++++
 .../topicmodel/IncrementalPLSAModelTest.java    |   8 +-
 .../java/hivemall/topicmodel/LDAUDTFTest.java   |  14 +-
 .../hivemall/topicmodel/OnlineLDAModelTest.java |   4 +-
 .../java/hivemall/topicmodel/PLSAUDTFTest.java  |  14 +-
 12 files changed, 653 insertions(+), 1102 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
new file mode 100644
index 0000000..3c097e2
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.annotations.VisibleForTesting;
+import hivemall.model.FeatureValue;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.*;
+
+public abstract class AbstractProbabilisticTopicModel {
+
+    // number of topics
+    protected final int _K;
+
+    // total number of documents
+    protected long _D;
+
+    // for mini-batch
+    @Nonnull
+    protected final List<Map<String, Float>> _miniBatchDocs;
+    protected int _miniBatchSize;
+
+    public AbstractProbabilisticTopicModel(int K) {
+        this._K = K;
+        this._D = 0L;
+        this._miniBatchDocs = new ArrayList<Map<String, Float>>();
+    }
+
+    protected static void initMiniBatch(@Nonnull final String[][] miniBatch,
+            @Nonnull final List<Map<String, Float>> docs) {
+        docs.clear();
+
+        final FeatureValue probe = new FeatureValue();
+
+        // parse document
+        for (final String[] e : miniBatch) {
+            if (e == null || e.length == 0) {
+                continue;
+            }
+
+            final Map<String, Float> doc = new HashMap<String, Float>();
+
+            // parse features
+            for (String fv : e) {
+                if (fv == null) {
+                    continue;
+                }
+                FeatureValue.parseFeatureAsString(fv, probe);
+                String label = probe.getFeatureAsString();
+                float value = probe.getValueAsFloat();
+                doc.put(label, Float.valueOf(value));
+            }
+
+            docs.add(doc);
+        }
+    }
+
+    public void accumulateDocCount() {
+        this._D += 1;
+    }
+
+    public long getDocCount() {
+        return _D;
+    }
+
+    public abstract void train(@Nonnull final String[][] miniBatch);
+
+    public abstract float computePerplexity();
+
+    @Nonnull
+    public abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
+
+    @Nonnull
+    public abstract float[] getTopicDistribution(@Nonnull final String[] doc);
+
+    @VisibleForTesting
+    abstract float getWordScore(@Nonnull final String word, @Nonnegative final int topic);
+
+    public abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, final float score);
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index 6eef23e..b99e670 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -20,9 +20,10 @@ package hivemall.topicmodel;
 
 import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
 import static hivemall.utils.math.MathUtils.l1normalize;
+
+import hivemall.annotations.VisibleForTesting;
 import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
-import hivemall.model.FeatureValue;
 import hivemall.utils.math.MathUtils;
 
 import java.util.ArrayList;
@@ -37,14 +38,11 @@ import java.util.TreeMap;
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 
-public final class IncrementalPLSAModel {
+public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel {
 
     // ---------------------------------
     // HyperParameters
 
-    // number of topics
-    private final int _K;
-
     // control how much P(w|z) update is affected by the last value
     private final float _alpha;
 
@@ -65,20 +63,15 @@ public final class IncrementalPLSAModel {
     private List<float[]> _p_dz; // P(z|d) probability of topics for documents
     private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
 
-    @Nonnull
-    private final List<Map<String, Float>> _miniBatchDocs;
-    private int _miniBatchSize;
-
     public IncrementalPLSAModel(int K, float alpha, double delta) {
-        this._K = K;
+        super(K);
+
         this._alpha = alpha;
         this._delta = delta;
 
         this._rnd = RandomNumberGeneratorFactory.createPRNG(1001);
 
         this._p_zw = new HashMap<String, float[]>();
-
-        this._miniBatchDocs = new ArrayList<Map<String, Float>>();
     }
 
     public void train(@Nonnull final String[][] miniBatch) {
@@ -106,35 +99,6 @@ public final class IncrementalPLSAModel {
         }
     }
 
-    private static void initMiniBatch(@Nonnull final String[][] miniBatch,
-            @Nonnull final List<Map<String, Float>> docs) {
-        docs.clear();
-
-        final FeatureValue probe = new FeatureValue();
-
-        // parse document
-        for (final String[] e : miniBatch) {
-            if (e == null || e.length == 0) {
-                continue;
-            }
-
-            final Map<String, Float> doc = new HashMap<String, Float>();
-
-            // parse features
-            for (String fv : e) {
-                if (fv == null) {
-                    continue;
-                }
-                FeatureValue.parseFeatureAsString(fv, probe);
-                String word = probe.getFeatureAsString();
-                float value = probe.getValueAsFloat();
-                doc.put(word, Float.valueOf(value));
-            }
-
-            docs.add(doc);
-        }
-    }
-
     private void initParams() {
         final List<float[]> p_dz = new ArrayList<float[]>();
         final List<Map<String, float[]>> p_dwz = new ArrayList<Map<String, float[]>>();
@@ -302,11 +266,12 @@ public final class IncrementalPLSAModel {
         return _p_dz.get(0);
     }
 
-    public float getProbability(@Nonnull final String w, @Nonnegative final int z) {
+    @VisibleForTesting
+    float getWordScore(@Nonnull final String w, @Nonnegative final int z) {
         return _p_zw.get(w)[z];
     }
 
-    public void setProbability(@Nonnull final String w, @Nonnegative final int z, final float prob) {
+    public void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
         float[] prob_label = _p_zw.get(w);
         if (prob_label == null) {
             prob_label = newRandomFloatArray(_K, _rnd);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index 03779b0..94d510a 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -471,7 +471,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
                 for (int k = 0; k < topics; k++) {
                     final float lambda_k = lambda_word.get(k).floatValue();
                     if (lambda_k != -1.f) {
-                        model.setLambda(word, k, lambda_k);
+                        model.setWordScore(word, k, lambda_k);
                     }
                 }
             }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index de57518..41386a4 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -18,27 +18,7 @@
  */
 package hivemall.topicmodel;
 
-import hivemall.UDTFWithOptions;
-import hivemall.annotations.VisibleForTesting;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.io.FileUtils;
-import hivemall.utils.io.NIOUtils;
-import hivemall.utils.io.NioStatefullSegment;
-import hivemall.utils.lang.NumberUtils;
 import hivemall.utils.lang.Primitives;
-import hivemall.utils.lang.SizeOf;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
@@ -46,96 +26,52 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.FloatWritable;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.Counters;
-import org.apache.hadoop.mapred.Reporter;
 
 @Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])"
         + " - Returns a relation consists of <int topic, string word, float score>")
-public class LDAUDTF extends UDTFWithOptions {
+public class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
     private static final Log logger = LogFactory.getLog(LDAUDTF.class);
 
-    public static final int DEFAULT_TOPICS = 10;
     public static final double DEFAULT_DELTA = 1E-3d;
 
     // Options
-    protected int topics;
     protected float alpha;
     protected float eta;
     protected long numDocs;
     protected double tau0;
     protected double kappa;
-    protected int iterations;
     protected double delta;
-    protected double eps;
-    protected int miniBatchSize;
-
-    // if `num_docs` option is not given, this flag will be true
-    // in that case, UDTF automatically sets `count` value to the _D parameter in an online LDA model
-    protected boolean isAutoD;
-
-    // number of proceeded training samples
-    protected long count;
-
-    protected String[][] miniBatch;
-    protected int miniBatchCount;
-
-    protected transient OnlineLDAModel model;
-
-    protected ListObjectInspector wordCountsOI;
-
-    // for iterations
-    protected NioStatefullSegment fileIO;
-    protected ByteBuffer inputBuf;
 
     public LDAUDTF() {
-        this.topics = DEFAULT_TOPICS;
+        super();
+
         this.alpha = 1.f / topics;
         this.eta = 1.f / topics;
         this.numDocs = -1L;
         this.tau0 = 64.d;
         this.kappa = 0.7;
-        this.iterations = 10;
         this.delta = DEFAULT_DELTA;
-        this.eps = 1E-1d;
-        this.miniBatchSize = 128; // if 1, truly online setting
     }
 
     @Override
     protected Options getOptions() {
-        Options opts = new Options();
-        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
+        Options opts = super.getOptions();
         opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
         opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]");
         opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]");
         opts.addOption("tau", "tau0", true,
             "The parameter which downweights early iterations [default: 64.0]");
         opts.addOption("kappa", true, "Exponential decay rate (i.e., learning rate) [default: 0.7]");
-        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
         opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
-        opts.addOption("eps", "epsilon", true,
-            "Check convergence based on the difference of perplexity [default: 1E-1]");
-        opts.addOption("s", "mini_batch_size", true,
-            "Repeat model updating per mini-batch [default: 128]");
         return opts;
     }
 
     @Override
     protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
-        CommandLine cl = null;
+        CommandLine cl = super.processOptions(argOIs);
 
-        if (argOIs.length >= 2) {
-            String rawArgs = HiveUtils.getConstString(argOIs[1]);
-            cl = parseOptions(rawArgs);
-            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
+        if (cl != null) {
             this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
             this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topics);
             this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L);
@@ -147,436 +83,13 @@ public class LDAUDTF extends UDTFWithOptions {
             if (kappa <= 0.5 || kappa > 1.d) {
                 throw new UDFArgumentException("'-kappa' must be in (0.5, 1.0]: " + kappa);
             }
-            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
-            if (iterations < 1) {
-                throw new UDFArgumentException(
-                    "'-iterations' must be greater than or equals to 1: " + iterations);
-            }
             this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA);
-            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
-            this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
         }
 
         return cl;
     }
 
-    @Override
-    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
-        if (argOIs.length < 1) {
-            throw new UDFArgumentException(
-                "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
-        }
-
-        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
-        HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
-
-        processOptions(argOIs);
-
-        this.model = null;
-        this.count = 0L;
-        this.isAutoD = (numDocs < 0L);
-        this.miniBatch = new String[miniBatchSize][];
-        this.miniBatchCount = 0;
-
-        ArrayList<String> fieldNames = new ArrayList<String>();
-        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
-        fieldNames.add("topic");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
-        fieldNames.add("word");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
-        fieldNames.add("score");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
-
-        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
-    }
-
-    protected void initModel() {
-        this.model = new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta);
-    }
-
-    @Override
-    public void process(Object[] args) throws HiveException {
-        if (model == null) {
-            initModel();
-        }
-
-        final int length = wordCountsOI.getListLength(args[0]);
-        final String[] wordCounts = new String[length];
-        int j = 0;
-        for (int i = 0; i < length; i++) {
-            Object o = wordCountsOI.getListElement(args[0], i);
-            if (o == null) {
-                throw new HiveException("Given feature vector contains invalid elements");
-            }
-            String s = o.toString();
-            wordCounts[j] = s;
-            j++;
-        }
-        if (j == 0) {// avoid empty documents
-            return;
-        }
-
-        count++;
-        if (isAutoD) {
-            model.setNumTotalDocs(count);
-        }
-
-        recordTrainSampleToTempFile(wordCounts);
-
-        miniBatch[miniBatchCount] = wordCounts;
-        miniBatchCount++;
-
-        if (miniBatchCount == miniBatchSize) {
-            model.train(miniBatch);
-            Arrays.fill(miniBatch, null); // clear
-            miniBatchCount = 0;
-        }
-    }
-
-    protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
-            throws HiveException {
-        if (iterations == 1) {
-            return;
-        }
-
-        ByteBuffer buf = inputBuf;
-        NioStatefullSegment dst = fileIO;
-
-        if (buf == null) {
-            final File file;
-            try {
-                file = File.createTempFile("hivemall_lda", ".sgmt");
-                file.deleteOnExit();
-                if (!file.canWrite()) {
-                    throw new UDFArgumentException("Cannot write a temporary file: "
-                            + file.getAbsolutePath());
-                }
-                logger.info("Record training samples to a file: " + file.getAbsolutePath());
-            } catch (IOException ioe) {
-                throw new UDFArgumentException(ioe);
-            } catch (Throwable e) {
-                throw new UDFArgumentException(e);
-            }
-            this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
-            this.fileIO = dst = new NioStatefullSegment(file, false);
-        }
-
-        // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
-        int wcLengthTotal = 0;
-        for (String wc : wordCounts) {
-            if (wc == null) {
-                continue;
-            }
-            wcLengthTotal += wc.length();
-        }
-        int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * wordCounts.length + wcLengthTotal
-                * SizeOf.CHAR;
-
-        int remain = buf.remaining();
-        if (remain < requiredRecordBytes) {
-            writeBuffer(buf, dst);
-        }
-
-        buf.putInt(requiredRecordBytes);
-        buf.putInt(wordCounts.length);
-        for (String wc : wordCounts) {
-            NIOUtils.putString(wc, buf);
-        }
-    }
-
-    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
-            throws HiveException {
-        srcBuf.flip();
-        try {
-            dst.write(srcBuf);
-        } catch (IOException e) {
-            throw new HiveException("Exception causes while writing a buffer to file", e);
-        }
-        srcBuf.clear();
-    }
-
-    @Override
-    public void close() throws HiveException {
-        if (count == 0) {
-            this.model = null;
-            return;
-        }
-        if (miniBatchCount > 0) { // update for remaining samples
-            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
-        }
-        if (iterations > 1) {
-            runIterativeTraining(iterations);
-        }
-        forwardModel();
-        this.model = null;
-    }
-
-    protected final void runIterativeTraining(@Nonnegative final int iterations)
-            throws HiveException {
-        final ByteBuffer buf = this.inputBuf;
-        final NioStatefullSegment dst = this.fileIO;
-        assert (buf != null);
-        assert (dst != null);
-        final long numTrainingExamples = count;
-
-        final Reporter reporter = getReporter();
-        final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
-            "hivemall.lda.OnlineLDA$Counter", "iteration");
-
-        try {
-            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
-                if (buf.position() == 0) {
-                    return; // no training example
-                }
-                buf.flip();
-
-                int iter = 2;
-                float perplexityPrev = Float.MAX_VALUE;
-                float perplexity;
-                int numTrain;
-                for (; iter <= iterations; iter++) {
-                    perplexity = 0.f;
-                    numTrain = 0;
-
-                    reportProgress(reporter);
-                    setCounterValue(iterCounter, iter);
-
-                    Arrays.fill(miniBatch, null); // clear
-                    miniBatchCount = 0;
-
-                    while (buf.remaining() > 0) {
-                        int recordBytes = buf.getInt();
-                        assert (recordBytes > 0) : recordBytes;
-                        int wcLength = buf.getInt();
-                        final String[] wordCounts = new String[wcLength];
-                        for (int j = 0; j < wcLength; j++) {
-                            wordCounts[j] = NIOUtils.getString(buf);
-                        }
-
-                        miniBatch[miniBatchCount] = wordCounts;
-                        miniBatchCount++;
-
-                        if (miniBatchCount == miniBatchSize) {
-                            model.train(miniBatch);
-                            perplexity += model.computePerplexity();
-                            numTrain++;
-
-                            Arrays.fill(miniBatch, null); // clear
-                            miniBatchCount = 0;
-                        }
-                    }
-                    buf.rewind();
-
-                    // update for remaining samples
-                    if (miniBatchCount > 0) { // update for remaining samples
-                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
-                        perplexity += model.computePerplexity();
-                        numTrain++;
-                    }
-
-                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
-                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
-                    if (Math.abs(perplexityPrev - perplexity) < eps) {
-                        break;
-                    }
-                    perplexityPrev = perplexity;
-                }
-                logger.info("Performed "
-                        + Math.min(iter, iterations)
-                        + " iterations of "
-                        + NumberUtils.formatNumber(numTrainingExamples)
-                        + " training examples on memory (thus "
-                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
-                        + " training updates in total) ");
-            } else {// read training examples in the temporary file and invoke train for each example
-                // write training examples in buffer to a temporary file
-                if (buf.remaining() > 0) {
-                    writeBuffer(buf, dst);
-                }
-                try {
-                    dst.flush();
-                } catch (IOException e) {
-                    throw new HiveException("Failed to flush a file: "
-                            + dst.getFile().getAbsolutePath(), e);
-                }
-                if (logger.isInfoEnabled()) {
-                    File tmpFile = dst.getFile();
-                    logger.info("Wrote " + numTrainingExamples
-                            + " records to a temporary file for iterative training: "
-                            + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
-                            + ")");
-                }
-
-                // run iterations
-                int iter = 2;
-                float perplexityPrev = Float.MAX_VALUE;
-                float perplexity;
-                int numTrain;
-                for (; iter <= iterations; iter++) {
-                    perplexity = 0.f;
-                    numTrain = 0;
-
-                    Arrays.fill(miniBatch, null); // clear
-                    miniBatchCount = 0;
-
-                    setCounterValue(iterCounter, iter);
-
-                    buf.clear();
-                    dst.resetPosition();
-                    while (true) {
-                        reportProgress(reporter);
-                        // TODO prefetch
-                        // writes training examples to a buffer in the temporary file
-                        final int bytesRead;
-                        try {
-                            bytesRead = dst.read(buf);
-                        } catch (IOException e) {
-                            throw new HiveException("Failed to read a file: "
-                                    + dst.getFile().getAbsolutePath(), e);
-                        }
-                        if (bytesRead == 0) { // reached file EOF
-                            break;
-                        }
-                        assert (bytesRead > 0) : bytesRead;
-
-                        // reads training examples from a buffer
-                        buf.flip();
-                        int remain = buf.remaining();
-                        if (remain < SizeOf.INT) {
-                            throw new HiveException("Illegal file format was detected");
-                        }
-                        while (remain >= SizeOf.INT) {
-                            int pos = buf.position();
-                            int recordBytes = buf.getInt() - SizeOf.INT;
-                            remain -= SizeOf.INT;
-                            if (remain < recordBytes) {
-                                buf.position(pos);
-                                break;
-                            }
-
-                            int wcLength = buf.getInt();
-                            final String[] wordCounts = new String[wcLength];
-                            for (int j = 0; j < wcLength; j++) {
-                                wordCounts[j] = NIOUtils.getString(buf);
-                            }
-
-                            miniBatch[miniBatchCount] = wordCounts;
-                            miniBatchCount++;
-
-                            if (miniBatchCount == miniBatchSize) {
-                                model.train(miniBatch);
-                                perplexity += model.computePerplexity();
-                                numTrain++;
-
-                                Arrays.fill(miniBatch, null); // clear
-                                miniBatchCount = 0;
-                            }
-
-                            remain -= recordBytes;
-                        }
-                        buf.compact();
-                    }
-
-                    // update for remaining samples
-                    if (miniBatchCount > 0) { // update for remaining samples
-                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
-                        perplexity += model.computePerplexity();
-                        numTrain++;
-                    }
-
-                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
-                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
-                    if (Math.abs(perplexityPrev - perplexity) < eps) {
-                        break;
-                    }
-                    perplexityPrev = perplexity;
-                }
-                logger.info("Performed "
-                        + Math.min(iter, iterations)
-                        + " iterations of "
-                        + NumberUtils.formatNumber(numTrainingExamples)
-                        + " training examples on a secondary storage (thus "
-                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
-                        + " training updates in total)");
-            }
-        } catch (Throwable e) {
-            throw new HiveException("Exception caused in the iterative training", e);
-        } finally {
-            // delete the temporary file and release resources
-            try {
-                dst.close(true);
-            } catch (IOException e) {
-                throw new HiveException("Failed to close a file: "
-                        + dst.getFile().getAbsolutePath(), e);
-            }
-            this.inputBuf = null;
-            this.fileIO = null;
-        }
-    }
-
-    protected void forwardModel() throws HiveException {
-        final IntWritable topicIdx = new IntWritable();
-        final Text word = new Text();
-        final FloatWritable score = new FloatWritable();
-
-        final Object[] forwardObjs = new Object[3];
-        forwardObjs[0] = topicIdx;
-        forwardObjs[1] = word;
-        forwardObjs[2] = score;
-
-        for (int k = 0; k < topics; k++) {
-            topicIdx.set(k);
-
-            final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
-            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
-                score.set(e.getKey());
-                List<String> words = e.getValue();
-                for (int i = 0; i < words.size(); i++) {
-                    word.set(words.get(i));
-                    forward(forwardObjs);
-                }
-            }
-        }
-
-        logger.info("Forwarded topic words each of " + topics + " topics");
-    }
-
-    /*
-     * For testing:
-     */
-
-    @VisibleForTesting
-    void closeWithoutModelReset() throws HiveException {
-        // launch close(), but not forward & clear model
-        if (count == 0) {
-            this.model = null;
-            return;
-        }
-        if (miniBatchCount > 0) { // update for remaining samples
-            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
-        }
-        if (iterations > 1) {
-            runIterativeTraining(iterations);
-        }
-    }
-
-    @VisibleForTesting
-    double getLambda(String label, int k) {
-        return model.getLambda(label, k);
-    }
-
-    @VisibleForTesting
-    SortedMap<Float, List<String>> getTopicWords(int k) {
-        return model.getTopicWords(k);
-    }
-
-    @VisibleForTesting
-    SortedMap<Float, List<String>> getTopicWords(int k, int topN) {
-        return model.getTopicWords(k, topN);
-    }
-
-    @VisibleForTesting
-    float[] getTopicDistribution(@Nonnull String[] doc) {
-        return model.getTopicDistribution(doc);
+    protected AbstractProbabilisticTopicModel createModel() {
+        return new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta);
     }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 8fef10c..4a7531c 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -19,7 +19,6 @@
 package hivemall.topicmodel;
 
 import hivemall.annotations.VisibleForTesting;
-import hivemall.model.FeatureValue;
 import hivemall.utils.lang.ArrayUtils;
 import hivemall.utils.math.MathUtils;
 
@@ -37,24 +36,17 @@ import javax.annotation.Nonnull;
 import org.apache.commons.math3.distribution.GammaDistribution;
 import org.apache.commons.math3.special.Gamma;
 
-public final class OnlineLDAModel {
+public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
 
     // ---------------------------------
     // HyperParameters
 
-    // number of topics
-    private final int _K;
-
     // prior on weight vectors "theta ~ Dir(alpha_)"
     private final float _alpha;
 
     // prior on topics "beta"
     private final float _eta;
 
-    // total number of documents
-    // in the truly online setting, this can be an estimate of the maximum number of documents that could ever seen
-    private long _D = -1L;
-
     // positive value which downweights early iterations
     @Nonnegative
     private final double _tau0;
@@ -75,6 +67,10 @@ public final class OnlineLDAModel {
     // controls how much old lambda is forgotten
     private double _rhot;
 
+    // if `num_docs` option is not given, this flag will be true
+    // in that case, UDTF automatically sets `count` value to the _D parameter in an online LDA model
+    private final boolean _isAutoD;
+
     // parameters
     @Nonnull
     private List<Map<String, float[]>> _phi;
@@ -88,11 +84,6 @@ public final class OnlineLDAModel {
     private static final double SHAPE = 100.d;
     private static final double SCALE = 1.d / SHAPE;
 
-    // for mini-batch
-    @Nonnull
-    private final List<Map<String, Float>> _miniBatchDocs;
-    private int _miniBatchSize;
-
     // for computing perplexity
     private float _docRatio = 1.f;
     private double _valueSum = 0.d;
@@ -103,6 +94,8 @@ public final class OnlineLDAModel {
 
     public OnlineLDAModel(int K, float alpha, float eta, long D, double tau0, double kappa,
             double delta) {
+        super(K);
+
         if (tau0 < 0.d) {
             throw new IllegalArgumentException("tau0 MUST be positive: " + tau0);
         }
@@ -110,7 +103,6 @@ public final class OnlineLDAModel {
             throw new IllegalArgumentException("kappa MUST be in (0.5, 1.0]: " + kappa);
         }
 
-        this._K = K;
         this._alpha = alpha;
         this._eta = eta;
         this._D = D;
@@ -118,31 +110,30 @@ public final class OnlineLDAModel {
         this._kappa = kappa;
         this._delta = delta;
 
+        this._isAutoD = (_D < 0L);
+
         // initialize a random number generator
         this._gd = new GammaDistribution(SHAPE, SCALE);
         _gd.reseedRandomGenerator(1001);
 
         // initialize the parameters
         this._lambda = new HashMap<String, float[]>(100);
-
-        this._miniBatchDocs = new ArrayList<Map<String, Float>>();
     }
 
-    /**
-     * In a truly online setting, total number of documents corresponds to the number of documents that have ever seen. In that case, users need to
-     * manually set the current max number of documents via this method. Note that, since the same set of documents could be repeatedly passed to
-     * `train()`, simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
-     */
-    public void setNumTotalDocs(@Nonnegative long D) {
-        this._D = D;
+    @Override
+    public void accumulateDocCount() {
+        /*
+         * In a truly online setting, total number of documents equals to the number of documents that have ever seen.
+         * In that case, users need to manually set the current max number of documents via this method.
+         * Note that, since the same set of documents could be repeatedly passed to `train()`,
+         * simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
+         */
+        if (_isAutoD) {
+            this._D += 1;
+        }
     }
 
     public void train(@Nonnull final String[][] miniBatch) {
-        if (_D <= 0L) {
-            throw new IllegalStateException(
-                "Total number of documents MUST be set via `setNumTotalDocs()`");
-        }
-
         preprocessMiniBatch(miniBatch);
 
         initParams(true);
@@ -175,35 +166,6 @@ public final class OnlineLDAModel {
         this._docRatio = (float) ((double) _D / _miniBatchSize);
     }
 
-    private static void initMiniBatch(@Nonnull final String[][] miniBatch,
-            @Nonnull final List<Map<String, Float>> docs) {
-        docs.clear();
-
-        final FeatureValue probe = new FeatureValue();
-
-        // parse document
-        for (final String[] e : miniBatch) {
-            if (e == null || e.length == 0) {
-                continue;
-            }
-
-            final Map<String, Float> doc = new HashMap<String, Float>();
-
-            // parse features
-            for (String fv : e) {
-                if (fv == null) {
-                    continue;
-                }
-                FeatureValue.parseFeatureAsString(fv, probe);
-                String label = probe.getFeatureAsString();
-                float value = probe.getValueAsFloat();
-                doc.put(label, Float.valueOf(value));
-            }
-
-            docs.add(doc);
-        }
-    }
-
     private void initParams(final boolean gammaWithRandom) {
         final List<Map<String, float[]>> phi = new ArrayList<Map<String, float[]>>();
         final float[][] gamma = new float[_miniBatchSize][];
@@ -475,7 +437,7 @@ public final class OnlineLDAModel {
     }
 
     @VisibleForTesting
-    double getLambda(@Nonnull final String label, @Nonnegative final int k) {
+    float getWordScore(@Nonnull final String label, @Nonnegative final int k) {
         final float[] lambda_label = _lambda.get(label);
         if (lambda_label == null) {
             throw new IllegalArgumentException("Word `" + label + "` is not in the corpus.");
@@ -487,7 +449,7 @@ public final class OnlineLDAModel {
         return lambda_label[k];
     }
 
-    public void setLambda(@Nonnull final String label, @Nonnegative final int k,
+    public void setWordScore(@Nonnull final String label, @Nonnegative final int k,
             final float lambda_k) {
         float[] lambda_label = _lambda.get(label);
         if (lambda_label == null) {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
index ff29236..7702945 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -474,7 +474,7 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
                 for (int k = 0; k < topics; k++) {
                     final float prob_k = prob_word.get(k).floatValue();
                     if (prob_k != -1.f) {
-                        model.setProbability(word, k, prob_k);
+                        model.setWordScore(word, k, prob_k);
                     }
                 }
             }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index 46f731f..e1d8797 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -18,27 +18,7 @@
  */
 package hivemall.topicmodel;
 
-import hivemall.UDTFWithOptions;
-import hivemall.annotations.VisibleForTesting;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.io.FileUtils;
-import hivemall.utils.io.NIOUtils;
-import hivemall.utils.io.NioStatefullSegment;
-import hivemall.utils.lang.NumberUtils;
 import hivemall.utils.lang.Primitives;
-import hivemall.utils.lang.SizeOf;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
@@ -46,503 +26,49 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.FloatWritable;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.Counters;
-import org.apache.hadoop.mapred.Reporter;
 
 @Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])"
         + " - Returns a relation consists of <int topic, string word, float score>")
-public class PLSAUDTF extends UDTFWithOptions {
+public class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
     private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
 
-    public static final int DEFAULT_TOPICS = 10;
     public static final float DEFAULT_ALPHA = 0.5f;
     public static final double DEFAULT_DELTA = 1E-3d;
 
     // Options
-    protected int topics;
     protected float alpha;
-    protected int iterations;
     protected double delta;
-    protected double eps;
-    protected int miniBatchSize;
-
-    // number of proceeded training samples
-    protected long count;
-
-    protected String[][] miniBatch;
-    protected int miniBatchCount;
-
-    protected transient IncrementalPLSAModel model;
-
-    protected ListObjectInspector wordCountsOI;
-
-    // for iterations
-    protected NioStatefullSegment fileIO;
-    protected ByteBuffer inputBuf;
 
     public PLSAUDTF() {
-        this.topics = DEFAULT_TOPICS;
+        super();
+
         this.alpha = DEFAULT_ALPHA;
-        this.iterations = 10;
         this.delta = DEFAULT_DELTA;
-        this.eps = 1E-1d;
-        this.miniBatchSize = 128;
     }
 
     @Override
     protected Options getOptions() {
-        Options opts = new Options();
-        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
+        Options opts = super.getOptions();
         opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]");
-        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
         opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
-        opts.addOption("eps", "epsilon", true,
-            "Check convergence based on the difference of perplexity [default: 1E-1]");
-        opts.addOption("s", "mini_batch_size", true,
-            "Repeat model updating per mini-batch [default: 128]");
         return opts;
     }
 
     @Override
     protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
-        CommandLine cl = null;
+        CommandLine cl = super.processOptions(argOIs);
 
-        if (argOIs.length >= 2) {
-            String rawArgs = HiveUtils.getConstString(argOIs[1]);
-            cl = parseOptions(rawArgs);
-            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
+        if (cl != null) {
             this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), DEFAULT_ALPHA);
-            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
-            if (iterations < 1) {
-                throw new UDFArgumentException(
-                    "'-iterations' must be greater than or equals to 1: " + iterations);
-            }
             this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA);
-            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
-            this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
         }
 
         return cl;
     }
 
-    @Override
-    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
-        if (argOIs.length < 1) {
-            throw new UDFArgumentException(
-                "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
-        }
-
-        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
-        HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
-
-        processOptions(argOIs);
-
-        this.model = null;
-        this.count = 0L;
-        this.miniBatch = new String[miniBatchSize][];
-        this.miniBatchCount = 0;
-
-        ArrayList<String> fieldNames = new ArrayList<String>();
-        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
-        fieldNames.add("topic");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
-        fieldNames.add("word");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
-        fieldNames.add("score");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
-
-        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
-    }
-
-    protected void initModel() {
-        this.model = new IncrementalPLSAModel(topics, alpha, delta);
-    }
-
-    @Override
-    public void process(Object[] args) throws HiveException {
-        if (model == null) {
-            initModel();
-        }
-
-        int length = wordCountsOI.getListLength(args[0]);
-        String[] wordCounts = new String[length];
-        int j = 0;
-        for (int i = 0; i < length; i++) {
-            Object o = wordCountsOI.getListElement(args[0], i);
-            if (o == null) {
-                throw new HiveException("Given feature vector contains invalid elements");
-            }
-            String s = o.toString();
-            wordCounts[j] = s;
-            j++;
-        }
-        if (j == 0) {// avoid empty documents
-            return;
-        }
-
-        count++;
-
-        recordTrainSampleToTempFile(wordCounts);
-
-        miniBatch[miniBatchCount] = wordCounts;
-        miniBatchCount++;
-
-        if (miniBatchCount == miniBatchSize) {
-            model.train(miniBatch);
-            Arrays.fill(miniBatch, null); // clear
-            miniBatchCount = 0;
-        }
-    }
-
-    protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
-            throws HiveException {
-        if (iterations == 1) {
-            return;
-        }
-
-        ByteBuffer buf = inputBuf;
-        NioStatefullSegment dst = fileIO;
-
-        if (buf == null) {
-            final File file;
-            try {
-                file = File.createTempFile("hivemall_plsa", ".sgmt");
-                file.deleteOnExit();
-                if (!file.canWrite()) {
-                    throw new UDFArgumentException("Cannot write a temporary file: "
-                            + file.getAbsolutePath());
-                }
-                logger.info("Record training samples to a file: " + file.getAbsolutePath());
-            } catch (IOException ioe) {
-                throw new UDFArgumentException(ioe);
-            } catch (Throwable e) {
-                throw new UDFArgumentException(e);
-            }
-            this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
-            this.fileIO = dst = new NioStatefullSegment(file, false);
-        }
-
-        // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
-        int wcLengthTotal = 0;
-        for (String wc : wordCounts) {
-            if (wc == null) {
-                continue;
-            }
-            wcLengthTotal += wc.length();
-        }
-        int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * wordCounts.length + wcLengthTotal
-                * SizeOf.CHAR;
-
-        int remain = buf.remaining();
-        if (remain < requiredRecordBytes) {
-            writeBuffer(buf, dst);
-        }
-
-        buf.putInt(requiredRecordBytes);
-        buf.putInt(wordCounts.length);
-        for (String wc : wordCounts) {
-            NIOUtils.putString(wc, buf);
-        }
-    }
-
-    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
-            throws HiveException {
-        srcBuf.flip();
-        try {
-            dst.write(srcBuf);
-        } catch (IOException e) {
-            throw new HiveException("Exception causes while writing a buffer to file", e);
-        }
-        srcBuf.clear();
+    protected AbstractProbabilisticTopicModel createModel() {
+        return new IncrementalPLSAModel(topics, alpha, delta);
     }
 
-    @Override
-    public void close() throws HiveException {
-        if (count == 0) {
-            this.model = null;
-            return;
-        }
-        if (miniBatchCount > 0) { // update for remaining samples
-            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
-        }
-        if (iterations > 1) {
-            runIterativeTraining(iterations);
-        }
-        forwardModel();
-        this.model = null;
-    }
-
-    protected final void runIterativeTraining(@Nonnegative final int iterations)
-            throws HiveException {
-        final ByteBuffer buf = this.inputBuf;
-        final NioStatefullSegment dst = this.fileIO;
-        assert (buf != null);
-        assert (dst != null);
-        final long numTrainingExamples = count;
-
-        final Reporter reporter = getReporter();
-        final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
-            "hivemall.plsa.IncrementalPLSA$Counter", "iteration");
-
-        try {
-            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
-                if (buf.position() == 0) {
-                    return; // no training example
-                }
-                buf.flip();
-
-                int iter = 2;
-                float perplexityPrev = Float.MAX_VALUE;
-                float perplexity;
-                int numTrain;
-                for (; iter <= iterations; iter++) {
-                    perplexity = 0.f;
-                    numTrain = 0;
-
-                    reportProgress(reporter);
-                    setCounterValue(iterCounter, iter);
-
-                    Arrays.fill(miniBatch, null); // clear
-                    miniBatchCount = 0;
-
-                    while (buf.remaining() > 0) {
-                        int recordBytes = buf.getInt();
-                        assert (recordBytes > 0) : recordBytes;
-                        int wcLength = buf.getInt();
-                        final String[] wordCounts = new String[wcLength];
-                        for (int j = 0; j < wcLength; j++) {
-                            wordCounts[j] = NIOUtils.getString(buf);
-                        }
-
-                        miniBatch[miniBatchCount] = wordCounts;
-                        miniBatchCount++;
-
-                        if (miniBatchCount == miniBatchSize) {
-                            model.train(miniBatch);
-                            perplexity += model.computePerplexity();
-                            numTrain++;
-
-                            Arrays.fill(miniBatch, null); // clear
-                            miniBatchCount = 0;
-                        }
-                    }
-                    buf.rewind();
-
-                    // update for remaining samples
-                    if (miniBatchCount > 0) { // update for remaining samples
-                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
-                        perplexity += model.computePerplexity();
-                        numTrain++;
-                    }
-
-                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
-                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
-                    if (Math.abs(perplexityPrev - perplexity) < eps) {
-                        break;
-                    }
-                    perplexityPrev = perplexity;
-                }
-                logger.info("Performed "
-                        + Math.min(iter, iterations)
-                        + " iterations of "
-                        + NumberUtils.formatNumber(numTrainingExamples)
-                        + " training examples on memory (thus "
-                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
-                        + " training updates in total) ");
-            } else {// read training examples in the temporary file and invoke train for each example
-
-                // write training examples in buffer to a temporary file
-                if (buf.remaining() > 0) {
-                    writeBuffer(buf, dst);
-                }
-                try {
-                    dst.flush();
-                } catch (IOException e) {
-                    throw new HiveException("Failed to flush a file: "
-                            + dst.getFile().getAbsolutePath(), e);
-                }
-                if (logger.isInfoEnabled()) {
-                    File tmpFile = dst.getFile();
-                    logger.info("Wrote " + numTrainingExamples
-                            + " records to a temporary file for iterative training: "
-                            + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
-                            + ")");
-                }
-
-                // run iterations
-                int iter = 2;
-                float perplexityPrev = Float.MAX_VALUE;
-                float perplexity;
-                int numTrain;
-                for (; iter <= iterations; iter++) {
-                    perplexity = 0.f;
-                    numTrain = 0;
-
-                    Arrays.fill(miniBatch, null); // clear
-                    miniBatchCount = 0;
-
-                    setCounterValue(iterCounter, iter);
-
-                    buf.clear();
-                    dst.resetPosition();
-                    while (true) {
-                        reportProgress(reporter);
-                        // TODO prefetch
-                        // writes training examples to a buffer in the temporary file
-                        final int bytesRead;
-                        try {
-                            bytesRead = dst.read(buf);
-                        } catch (IOException e) {
-                            throw new HiveException("Failed to read a file: "
-                                    + dst.getFile().getAbsolutePath(), e);
-                        }
-                        if (bytesRead == 0) { // reached file EOF
-                            break;
-                        }
-                        assert (bytesRead > 0) : bytesRead;
-
-                        // reads training examples from a buffer
-                        buf.flip();
-                        int remain = buf.remaining();
-                        if (remain < SizeOf.INT) {
-                            throw new HiveException("Illegal file format was detected");
-                        }
-                        while (remain >= SizeOf.INT) {
-                            int pos = buf.position();
-                            int recordBytes = buf.getInt() - SizeOf.INT;
-                            remain -= SizeOf.INT;
-                            if (remain < recordBytes) {
-                                buf.position(pos);
-                                break;
-                            }
-
-                            int wcLength = buf.getInt();
-                            final String[] wordCounts = new String[wcLength];
-                            for (int j = 0; j < wcLength; j++) {
-                                wordCounts[j] = NIOUtils.getString(buf);
-                            }
-
-                            miniBatch[miniBatchCount] = wordCounts;
-                            miniBatchCount++;
-
-                            if (miniBatchCount == miniBatchSize) {
-                                model.train(miniBatch);
-                                perplexity += model.computePerplexity();
-                                numTrain++;
-
-                                Arrays.fill(miniBatch, null); // clear
-                                miniBatchCount = 0;
-                            }
-
-                            remain -= recordBytes;
-                        }
-                        buf.compact();
-                    }
-
-                    // update for remaining samples
-                    if (miniBatchCount > 0) { // update for remaining samples
-                        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
-                        perplexity += model.computePerplexity();
-                        numTrain++;
-                    }
-
-                    logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
-                    perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
-                    if (Math.abs(perplexityPrev - perplexity) < eps) {
-                        break;
-                    }
-                    perplexityPrev = perplexity;
-                }
-                logger.info("Performed "
-                        + Math.min(iter, iterations)
-                        + " iterations of "
-                        + NumberUtils.formatNumber(numTrainingExamples)
-                        + " training examples on a secondary storage (thus "
-                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
-                        + " training updates in total)");
-            }
-        } catch (Throwable e) {
-            throw new HiveException("Exception caused in the iterative training", e);
-        } finally {
-            // delete the temporary file and release resources
-            try {
-                dst.close(true);
-            } catch (IOException e) {
-                throw new HiveException("Failed to close a file: "
-                        + dst.getFile().getAbsolutePath(), e);
-            }
-            this.inputBuf = null;
-            this.fileIO = null;
-        }
-    }
-
-    protected void forwardModel() throws HiveException {
-        final IntWritable topicIdx = new IntWritable();
-        final Text word = new Text();
-        final FloatWritable score = new FloatWritable();
-
-        final Object[] forwardObjs = new Object[3];
-        forwardObjs[0] = topicIdx;
-        forwardObjs[1] = word;
-        forwardObjs[2] = score;
-
-        for (int k = 0; k < topics; k++) {
-            topicIdx.set(k);
-
-            final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
-            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
-                score.set(e.getKey());
-                List<String> words = e.getValue();
-                for (int i = 0; i < words.size(); i++) {
-                    word.set(words.get(i));
-                    forward(forwardObjs);
-                }
-            }
-        }
-
-        logger.info("Forwarded topic words each of " + topics + " topics");
-    }
-
-    /*
-     * For testing:
-     */
-
-    @VisibleForTesting
-    public void closeWithoutModelReset() throws HiveException {
-        // launch close(), but not forward & clear model
-        if (count == 0) {
-            this.model = null;
-            return;
-        }
-        if (miniBatchCount > 0) { // update for remaining samples
-            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
-        }
-        if (iterations > 1) {
-            runIterativeTraining(iterations);
-        }
-    }
-
-    @VisibleForTesting
-    double getProbability(String label, int k) {
-        return model.getProbability(label, k);
-    }
-
-    @VisibleForTesting
-    SortedMap<Float, List<String>> getTopicWords(int k) {
-        return model.getTopicWords(k);
-    }
-
-    @VisibleForTesting
-    float[] getTopicDistribution(@Nonnull String[] doc) {
-        return model.getTopicDistribution(doc);
-    }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
new file mode 100644
index 0000000..cff076e
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NIOUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
+    private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);
+
+    public static final int DEFAULT_TOPICS = 10;
+
+    // Options
+    protected int topics;
+    protected int iterations;
+    protected double eps;
+    protected int miniBatchSize;
+
+    protected String[][] miniBatch;
+    protected int miniBatchCount;
+
+    protected transient AbstractProbabilisticTopicModel model;
+
+    protected ListObjectInspector wordCountsOI;
+
+    // for iterations
+    protected NioStatefullSegment fileIO;
+    protected ByteBuffer inputBuf;
+
+    private float cumPerplexity;
+
+    public ProbabilisticTopicModelBaseUDTF() {
+        this.topics = DEFAULT_TOPICS;
+        this.iterations = 10;
+        this.eps = 1E-1d;
+        this.miniBatchSize = 128; // if 1, truly online setting
+    }
+
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
+        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
+        opts.addOption("eps", "epsilon", true,
+            "Check convergence based on the difference of perplexity [default: 1E-1]");
+        opts.addOption("s", "mini_batch_size", true,
+            "Repeat model updating per mini-batch [default: 128]");
+        return opts;
+    }
+
+    @Override
+    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+        CommandLine cl = null;
+
+        if (argOIs.length >= 2) {
+            String rawArgs = HiveUtils.getConstString(argOIs[1]);
+            cl = parseOptions(rawArgs);
+            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
+            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
+            if (iterations < 1) {
+                throw new UDFArgumentException(
+                    "'-iterations' must be greater than or equals to 1: " + iterations);
+            }
+            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
+            this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
+        }
+
+        return cl;
+    }
+
+    @Override
+    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length < 1) {
+            throw new UDFArgumentException(
+                "_FUNC_ takes at least 1 argument: array<string> words [, const string options]");
+        }
+
+        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
+        HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
+
+        processOptions(argOIs);
+
+        this.model = null;
+        this.miniBatch = new String[miniBatchSize][];
+        this.miniBatchCount = 0;
+        this.cumPerplexity = 0.f;
+
+        ArrayList<String> fieldNames = new ArrayList<String>();
+        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+        fieldNames.add("topic");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+        fieldNames.add("word");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+        fieldNames.add("score");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+    }
+
+    protected abstract AbstractProbabilisticTopicModel createModel();
+
+    @Override
+    public void process(Object[] args) throws HiveException {
+        if (model == null) {
+            this.model = createModel();
+        }
+
+        final int length = wordCountsOI.getListLength(args[0]);
+        final String[] wordCounts = new String[length];
+        int j = 0;
+        for (int i = 0; i < length; i++) {
+            Object o = wordCountsOI.getListElement(args[0], i);
+            if (o == null) {
+                throw new HiveException("Given feature vector contains invalid elements");
+            }
+            String s = o.toString();
+            wordCounts[j] = s;
+            j++;
+        }
+        if (j == 0) {// avoid empty documents
+            return;
+        }
+
+        model.accumulateDocCount();
+
+        update(wordCounts);
+
+        recordTrainSampleToTempFile(wordCounts);
+    }
+
+    protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
+            throws HiveException {
+        if (iterations == 1) {
+            return;
+        }
+
+        ByteBuffer buf = inputBuf;
+        NioStatefullSegment dst = fileIO;
+
+        if (buf == null) {
+            final File file;
+            try {
+                file = File.createTempFile("hivemall_topicmodel", ".sgmt");
+                file.deleteOnExit();
+                if (!file.canWrite()) {
+                    throw new UDFArgumentException("Cannot write a temporary file: "
+                            + file.getAbsolutePath());
+                }
+                logger.info("Record training samples to a file: " + file.getAbsolutePath());
+            } catch (IOException ioe) {
+                throw new UDFArgumentException(ioe);
+            } catch (Throwable e) {
+                throw new UDFArgumentException(e);
+            }
+            this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
+            this.fileIO = dst = new NioStatefullSegment(file, false);
+        }
+
+        // wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
+        int wcLengthTotal = 0;
+        for (String wc : wordCounts) {
+            if (wc == null) {
+                continue;
+            }
+            wcLengthTotal += wc.length();
+        }
+        int recordBytes = SizeOf.INT + SizeOf.INT * wordCounts.length + wcLengthTotal * SizeOf.CHAR;
+        int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself
+
+        int remain = buf.remaining();
+        if (remain < requiredBytes) {
+            writeBuffer(buf, dst);
+        }
+
+        buf.putInt(recordBytes);
+        buf.putInt(wordCounts.length);
+        for (String wc : wordCounts) {
+            NIOUtils.putString(wc, buf);
+        }
+    }
+
+    private void update(@Nonnull final String[] wordCounts) {
+        miniBatch[miniBatchCount] = wordCounts;
+        miniBatchCount++;
+
+        if (miniBatchCount == miniBatchSize) {
+            train();
+        }
+    }
+
+    protected void train() {
+        if (miniBatchCount == 0) {
+            return;
+        }
+
+        model.train(miniBatch);
+
+        this.cumPerplexity += model.computePerplexity();
+
+        Arrays.fill(miniBatch, null); // clear
+        miniBatchCount = 0;
+    }
+
+    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
+            throws HiveException {
+        srcBuf.flip();
+        try {
+            dst.write(srcBuf);
+        } catch (IOException e) {
+            throw new HiveException("Exception caused while writing a buffer to file", e);
+        }
+        srcBuf.clear();
+    }
+
+    @Override
+    public void close() throws HiveException {
+        finalizeTraining();
+        forwardModel();
+        this.model = null;
+    }
+
+    @VisibleForTesting
+    void finalizeTraining() throws HiveException {
+        if (model.getDocCount() == 0L) {
+            this.model = null;
+            return;
+        }
+        if (miniBatchCount > 0) { // update for remaining samples
+            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+        }
+        if (iterations > 1) {
+            runIterativeTraining(iterations);
+        }
+    }
+
+    /**
+     * Re-runs training for iterations 2..{@code iterations} over the examples recorded
+     * during the first pass. Examples are replayed from the in-memory buffer when
+     * everything fit in memory (file position still 0), otherwise from the temporary
+     * file. Iteration stops early once the change in mean perplexity over all
+     * mini-batches falls below {@code eps}. The buffer and the temporary file are
+     * released in all cases on exit.
+     *
+     * @param iterations total number of training iterations requested (>= 2 to do work)
+     * @throws HiveException on I/O failure or any error raised during training
+     */
+    protected final void runIterativeTraining(@Nonnegative final int iterations)
+            throws HiveException {
+        final ByteBuffer buf = this.inputBuf;
+        final NioStatefullSegment dst = this.fileIO;
+        assert (buf != null);
+        assert (dst != null);
+        final long numTrainingExamples = model.getDocCount();
+
+        // number of mini-batches per pass (ceiling division)
+        long numTrain = numTrainingExamples / miniBatchSize;
+        if (numTrainingExamples % miniBatchSize != 0L) {
+            numTrain++;
+        }
+
+        final Reporter reporter = getReporter();
+        final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
+            "hivemall.topicmodel.ProbabilisticTopicModel$Counter", "iteration");
+
+        try {
+            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
+                if (buf.position() == 0) {
+                    return; // no training example
+                }
+                buf.flip();
+
+                // iteration 1 already happened during process(); start from 2
+                int iter = 2;
+                float perplexity = cumPerplexity / numTrain;
+                float perplexityPrev;
+                for (; iter <= iterations; iter++) {
+                    perplexityPrev = perplexity;
+                    cumPerplexity = 0.f;
+
+                    reportProgress(reporter);
+                    setCounterValue(iterCounter, iter);
+
+                    // replay every record in the in-memory buffer:
+                    // [recordBytes:int][wcLength:int][wcLength strings]
+                    while (buf.remaining() > 0) {
+                        int recordBytes = buf.getInt();
+                        assert (recordBytes > 0) : recordBytes;
+                        int wcLength = buf.getInt();
+                        final String[] wordCounts = new String[wcLength];
+                        for (int j = 0; j < wcLength; j++) {
+                            wordCounts[j] = NIOUtils.getString(buf);
+                        }
+                        update(wordCounts);
+                    }
+                    buf.rewind();
+
+                    // mean perplexity over `numTrain` mini-batches
+                    perplexity = cumPerplexity / numTrain;
+                    logger.info("Mean perplexity over mini-batches: " + perplexity);
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break; // converged
+                    }
+                }
+                logger.info("Performed "
+                        + Math.min(iter, iterations)
+                        + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on memory (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total) ")
+            } else {// read training examples in the temporary file and invoke train for each example
+                // write training examples in buffer to a temporary file
+                if (buf.remaining() > 0) {
+                    writeBuffer(buf, dst);
+                }
+                try {
+                    dst.flush();
+                } catch (IOException e) {
+                    throw new HiveException("Failed to flush a file: "
+                            + dst.getFile().getAbsolutePath(), e);
+                }
+                if (logger.isInfoEnabled()) {
+                    File tmpFile = dst.getFile();
+                    logger.info("Wrote " + numTrainingExamples
+                            + " records to a temporary file for iterative training: "
+                            + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
+                            + ")");
+                }
+
+                // run iterations
+                int iter = 2;
+                float perplexity = cumPerplexity / numTrain;
+                float perplexityPrev;
+                for (; iter <= iterations; iter++) {
+                    perplexityPrev = perplexity;
+                    cumPerplexity = 0.f;
+
+                    setCounterValue(iterCounter, iter);
+
+                    buf.clear();
+                    dst.resetPosition();
+                    while (true) {
+                        reportProgress(reporter);
+                        // TODO prefetch
+                        // reads the next chunk of training examples from the temporary file into the buffer
+                        final int bytesRead;
+                        try {
+                            bytesRead = dst.read(buf);
+                        } catch (IOException e) {
+                            throw new HiveException("Failed to read a file: "
+                                    + dst.getFile().getAbsolutePath(), e);
+                        }
+                        if (bytesRead == 0) { // reached file EOF
+                            break;
+                        }
+                        assert (bytesRead > 0) : bytesRead;
+
+                        // reads training examples from a buffer
+                        buf.flip();
+                        int remain = buf.remaining();
+                        if (remain < SizeOf.INT) {
+                            throw new HiveException("Illegal file format was detected");
+                        }
+                        while (remain >= SizeOf.INT) {
+                            int pos = buf.position();
+                            int recordBytes = buf.getInt();
+                            remain -= SizeOf.INT;
+                            if (remain < recordBytes) {
+                                // record is split across reads: rewind to its header and
+                                // let compact() carry the partial bytes into the next read
+                                buf.position(pos);
+                                break;
+                            }
+
+                            int wcLength = buf.getInt();
+                            final String[] wordCounts = new String[wcLength];
+                            for (int j = 0; j < wcLength; j++) {
+                                wordCounts[j] = NIOUtils.getString(buf);
+                            }
+                            update(wordCounts);
+
+                            remain -= recordBytes;
+                        }
+                        buf.compact();
+                    }
+
+                    // mean perplexity over `numTrain` mini-batches
+                    perplexity = cumPerplexity / numTrain;
+                    logger.info("Mean perplexity over mini-batches: " + perplexity);
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break; // converged
+                    }
+                }
+                logger.info("Performed "
+                        + Math.min(iter, iterations)
+                        + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on a secondary storage (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total)");
+            }
+        } catch (Throwable e) {
+            throw new HiveException("Exception caused in the iterative training", e);
+        } finally {
+            // delete the temporary file and release resources
+            try {
+                dst.close(true);
+            } catch (IOException e) {
+                throw new HiveException("Failed to close a file: "
+                        + dst.getFile().getAbsolutePath(), e);
+            }
+            this.inputBuf = null;
+            this.fileIO = null;
+        }
+    }
+
+    /**
+     * Forwards one (topicIndex, word, score) row for every topic word of every topic.
+     * The Writable objects are allocated once and reused across forward() calls to
+     * avoid per-row allocation.
+     */
+    protected void forwardModel() throws HiveException {
+        final IntWritable topicIdx = new IntWritable();
+        final Text word = new Text();
+        final FloatWritable score = new FloatWritable();
+
+        final Object[] forwardObjs = new Object[3];
+        forwardObjs[0] = topicIdx;
+        forwardObjs[1] = word;
+        forwardObjs[2] = score;
+
+        for (int k = 0; k < topics; k++) {
+            topicIdx.set(k);
+
+            // words grouped by score; multiple words may share the same score
+            final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
+            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
+                score.set(e.getKey());
+                List<String> words = e.getValue();
+                for (int i = 0; i < words.size(); i++) {
+                    word.set(words.get(i));
+                    forward(forwardObjs);
+                }
+            }
+        }
+
+        logger.info("Forwarded topic words each of " + topics + " topics");
+    }
+
+    /** Test helper: delegates to the model's score of {@code label} in topic {@code k}. */
+    @VisibleForTesting
+    float getWordScore(String label, int k) {
+        return model.getWordScore(label, k);
+    }
+
+    /** Test helper: delegates to the model's score-sorted word lists for topic {@code k}. */
+    @VisibleForTesting
+    SortedMap<Float, List<String>> getTopicWords(int k) {
+        return model.getTopicWords(k);
+    }
+
+    /** Test helper: delegates to the model's per-topic distribution for the given document. */
+    @VisibleForTesting
+    float[] getTopicDistribution(@Nonnull String[] doc) {
+        return model.getTopicDistribution(doc);
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
index 79be3a7..96bbe64 100644
--- a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
+++ b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
@@ -110,10 +110,10 @@ public class IncrementalPLSAModelTest {
         }
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
                 + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
-            model.getProbability("vegetables", k1) > model.getProbability("flu", k1));
+            model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
                 + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
-            model.getProbability("avocados", k2) > model.getProbability("healthy", k2));
+            model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2));
     }
 
     @Test
@@ -177,10 +177,10 @@ public class IncrementalPLSAModelTest {
         }
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
                 + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
-            model.getProbability("vegetables", k1) > model.getProbability("flu", k1));
+            model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
                 + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
-            model.getProbability("avocados", k2) > model.getProbability("healthy", k2));
+            model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2));
     }
 
     @Test

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
index a934ba3..4cbb668 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -53,7 +53,7 @@ public class LDAUDTFTest {
         udtf.process(new Object[] {Arrays.asList(doc1)});
         udtf.process(new Object[] {Arrays.asList(doc2)});
 
-        udtf.closeWithoutModelReset();
+        udtf.finalizeTraining();
 
         SortedMap<Float, List<String>> topicWords;
 
@@ -92,10 +92,10 @@ public class LDAUDTFTest {
 
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
                 + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
-            udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1));
+            udtf.getWordScore("vegetables", k1) > udtf.getWordScore("flu", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
                 + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
-            udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2));
+            udtf.getWordScore("avocados", k2) > udtf.getWordScore("healthy", k2));
     }
 
     @Test
@@ -106,7 +106,7 @@ public class LDAUDTFTest {
                 ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                 ObjectInspectorUtils.getConstantObjectInspector(
                     PrimitiveObjectInspectorFactory.javaStringObjectInspector,
-                    "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3")};
+                    "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3  -mini_batch_size 1")};
 
         udtf.initialize(argOIs);
 
@@ -116,7 +116,7 @@ public class LDAUDTFTest {
         udtf.process(new Object[] {Arrays.asList(doc1)});
         udtf.process(new Object[] {Arrays.asList(doc2)});
 
-        udtf.closeWithoutModelReset();
+        udtf.finalizeTraining();
 
         SortedMap<Float, List<String>> topicWords;
 
@@ -155,10 +155,10 @@ public class LDAUDTFTest {
 
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
                 + "and `野菜` SHOULD be more suitable topic word than `インフルエンザ` in the topic",
-            udtf.getLambda("野菜", k1) > udtf.getLambda("インフルエンザ", k1));
+            udtf.getWordScore("野菜", k1) > udtf.getWordScore("インフルエンザ", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
                 + "and `アボカド` SHOULD be more suitable topic word than `健康` in the topic",
-            udtf.getLambda("アボカド", k2) > udtf.getLambda("健康", k2));
+            udtf.getWordScore("アボカド", k2) > udtf.getWordScore("健康", k2));
     }
 
     private static void println(String msg) {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
index 5b0a8c2..68f251a 100644
--- a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
+++ b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
@@ -108,10 +108,10 @@ public class OnlineLDAModelTest {
         }
         Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
                 + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
-            model.getLambda("vegetables", k1) > model.getLambda("flu", k1));
+            model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1));
         Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
                 + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
-            model.getLambda("avocados", k2) > model.getLambda("healthy", k2));
+            model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2));
     }
 
     @Test