You are viewing a plain-text version of this content. The canonical link to the original (rendered as a "here" hyperlink) is not preserved in this plain-text version.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/06/27 06:45:01 UTC

[3/3] incubator-hivemall git commit: Applied refactoring for topicmodel module

Applied refactoring for topicmodel module

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9f01ebf2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9f01ebf2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9f01ebf2

Branch: refs/heads/master
Commit: 9f01ebf20c74559be8a50d459103118a51c229bf
Parents: 0495ffa
Author: Makoto Yui <my...@apache.org>
Authored: Tue Jun 27 15:44:31 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Tue Jun 27 15:44:31 2017 +0900

----------------------------------------------------------------------
 .../AbstractProbabilisticTopicModel.java        | 26 +++++++++++++-------
 .../topicmodel/IncrementalPLSAModel.java        | 16 ++++++------
 .../main/java/hivemall/topicmodel/LDAUDTF.java  |  5 +---
 .../hivemall/topicmodel/OnlineLDAModel.java     | 18 +++++++-------
 .../main/java/hivemall/topicmodel/PLSAUDTF.java |  5 +---
 .../ProbabilisticTopicModelBaseUDTF.java        | 25 ++++++++++++-------
 6 files changed, 52 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
index 3c097e2..1b7f3e8 100644
--- a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
+++ b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
@@ -21,9 +21,14 @@ package hivemall.topicmodel;
 import hivemall.annotations.VisibleForTesting;
 import hivemall.model.FeatureValue;
 
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
-import java.util.*;
 
 public abstract class AbstractProbabilisticTopicModel {
 
@@ -31,6 +36,7 @@ public abstract class AbstractProbabilisticTopicModel {
     protected final int _K;
 
     // total number of documents
+    @Nonnegative
     protected long _D;
 
     // for mini-batch
@@ -38,7 +44,7 @@ public abstract class AbstractProbabilisticTopicModel {
     protected final List<Map<String, Float>> _miniBatchDocs;
     protected int _miniBatchSize;
 
-    public AbstractProbabilisticTopicModel(int K) {
+    public AbstractProbabilisticTopicModel(@Nonnegative int K) {
         this._K = K;
         this._D = 0L;
         this._miniBatchDocs = new ArrayList<Map<String, Float>>();
@@ -73,26 +79,28 @@ public abstract class AbstractProbabilisticTopicModel {
         }
     }
 
-    public void accumulateDocCount() {
+    protected void accumulateDocCount() {
         this._D += 1;
     }
 
-    public long getDocCount() {
+    @Nonnegative
+    protected long getDocCount() {
         return _D;
     }
 
-    public abstract void train(@Nonnull final String[][] miniBatch);
+    protected abstract void train(@Nonnull final String[][] miniBatch);
 
-    public abstract float computePerplexity();
+    protected abstract float computePerplexity();
 
     @Nonnull
-    public abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
+    protected abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
 
     @Nonnull
-    public abstract float[] getTopicDistribution(@Nonnull final String[] doc);
+    protected abstract float[] getTopicDistribution(@Nonnull final String[] doc);
 
     @VisibleForTesting
     abstract float getWordScore(@Nonnull final String word, @Nonnegative final int topic);
 
-    public abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, final float score);
+    protected abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic,
+            final float score);
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index b99e670..6419664 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -20,7 +20,6 @@ package hivemall.topicmodel;
 
 import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
 import static hivemall.utils.math.MathUtils.l1normalize;
-
 import hivemall.annotations.VisibleForTesting;
 import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
@@ -59,9 +58,10 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
     private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair
 
     // optimized in the M step
-    @Nonnull
     private List<float[]> _p_dz; // P(z|d) probability of topics for documents
-    private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
+
+    @Nonnull
+    private final Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
 
     public IncrementalPLSAModel(int K, float alpha, double delta) {
         super(K);
@@ -74,7 +74,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
         this._p_zw = new HashMap<String, float[]>();
     }
 
-    public void train(@Nonnull final String[][] miniBatch) {
+    protected void train(@Nonnull final String[][] miniBatch) {
         initMiniBatch(miniBatch, _miniBatchDocs);
 
         this._miniBatchSize = _miniBatchDocs.size();
@@ -211,7 +211,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
         return (diff / _K) < _delta;
     }
 
-    public float computePerplexity() {
+    protected float computePerplexity() {
         double numer = 0.d;
         double denom = 0.d;
 
@@ -241,7 +241,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
     }
 
     @Nonnull
-    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
+    protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
         final SortedMap<Float, List<String>> res = new TreeMap<Float, List<String>>(
             Collections.reverseOrder());
 
@@ -261,7 +261,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
     }
 
     @Nonnull
-    public float[] getTopicDistribution(@Nonnull final String[] doc) {
+    protected float[] getTopicDistribution(@Nonnull final String[] doc) {
         train(new String[][] {doc});
         return _p_dz.get(0);
     }
@@ -271,7 +271,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
         return _p_zw.get(w)[z];
     }
 
-    public void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
+    protected void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
         float[] prob_label = _p_zw.get(w);
         if (prob_label == null) {
             prob_label = newRandomFloatArray(_K, _rnd);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 41386a4..9bac908 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 
 @Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])"
         + " - Returns a relation consists of <int topic, string word, float score>")
-public class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
-    private static final Log logger = LogFactory.getLog(LDAUDTF.class);
+public final class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
 
     public static final double DEFAULT_DELTA = 1E-3d;
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 4a7531c..6a8d6db 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -38,6 +38,9 @@ import org.apache.commons.math3.special.Gamma;
 
 public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
 
+    private static final double SHAPE = 100.d;
+    private static final double SCALE = 1.d / SHAPE;
+
     // ---------------------------------
     // HyperParameters
 
@@ -72,7 +75,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     private final boolean _isAutoD;
 
     // parameters
-    @Nonnull
     private List<Map<String, float[]>> _phi;
     private float[][] _gamma;
     @Nonnull
@@ -81,8 +83,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     // random number generator
     @Nonnull
     private final GammaDistribution _gd;
-    private static final double SHAPE = 100.d;
-    private static final double SCALE = 1.d / SHAPE;
 
     // for computing perplexity
     private float _docRatio = 1.f;
@@ -121,7 +121,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     }
 
     @Override
-    public void accumulateDocCount() {
+    protected void accumulateDocCount() {
         /*
          * In a truly online setting, total number of documents equals to the number of documents that have ever seen.
          * In that case, users need to manually set the current max number of documents via this method.
@@ -133,7 +133,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
         }
     }
 
-    public void train(@Nonnull final String[][] miniBatch) {
+    protected void train(@Nonnull final String[][] miniBatch) {
         preprocessMiniBatch(miniBatch);
 
         initParams(true);
@@ -341,7 +341,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     /**
      * Calculate approximate perplexity for the current mini-batch.
      */
-    public float computePerplexity() {
+    protected float computePerplexity() {
         double bound = computeApproxBound();
         double perWordBound = bound / (_docRatio * _valueSum);
         return (float) Math.exp(-1.d * perWordBound);
@@ -449,7 +449,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
         return lambda_label[k];
     }
 
-    public void setWordScore(@Nonnull final String label, @Nonnegative final int k,
+    protected void setWordScore(@Nonnull final String label, @Nonnegative final int k,
             final float lambda_k) {
         float[] lambda_label = _lambda.get(label);
         if (lambda_label == null) {
@@ -460,7 +460,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     }
 
     @Nonnull
-    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
+    protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
         return getTopicWords(k, _lambda.keySet().size());
     }
 
@@ -501,7 +501,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
     }
 
     @Nonnull
-    public float[] getTopicDistribution(@Nonnull final String[] doc) {
+    protected float[] getTopicDistribution(@Nonnull final String[] doc) {
         preprocessMiniBatch(new String[][] {doc});
 
         initParams(false);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index e1d8797..9c5a0ea 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 
 @Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])"
         + " - Returns a relation consists of <int topic, string word, float score>")
-public class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
-    private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
+public final class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
 
     public static final float DEFAULT_ALPHA = 0.5f;
     public static final double DEFAULT_DELTA = 1E-3d;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
index cff076e..c3dab89 100644
--- a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
@@ -27,6 +27,19 @@ import hivemall.utils.io.NioStatefullSegment;
 import hivemall.utils.lang.NumberUtils;
 import hivemall.utils.lang.Primitives;
 import hivemall.utils.lang.SizeOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
 import org.apache.commons.logging.Log;
@@ -44,13 +57,6 @@ import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.Counters;
 import org.apache.hadoop.mapred.Reporter;
 
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.*;
-
 public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);
 
@@ -143,6 +149,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
         return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
     }
 
+    @Nonnull
     protected abstract AbstractProbabilisticTopicModel createModel();
 
     @Override
@@ -157,7 +164,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
         for (int i = 0; i < length; i++) {
             Object o = wordCountsOI.getListElement(args[0], i);
             if (o == null) {
-                throw new HiveException("Given feature vector contains invalid elements");
+                throw new HiveException("Given feature vector contains invalid null elements");
             }
             String s = o.toString();
             wordCounts[j] = s;
@@ -167,7 +174,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
             return;
         }
 
-        model.accumulateDocCount();;
+        model.accumulateDocCount();
 
         update(wordCounts);