You are viewing a plain-text version of this content. (The canonical link to the original message was not preserved in this plain-text rendering.)
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/06/27 06:45:01 UTC
[3/3] incubator-hivemall git commit: Applied refactoring for topicmodel module
Applied refactoring for topicmodel module
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9f01ebf2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9f01ebf2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9f01ebf2
Branch: refs/heads/master
Commit: 9f01ebf20c74559be8a50d459103118a51c229bf
Parents: 0495ffa
Author: Makoto Yui <my...@apache.org>
Authored: Tue Jun 27 15:44:31 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Tue Jun 27 15:44:31 2017 +0900
----------------------------------------------------------------------
.../AbstractProbabilisticTopicModel.java | 26 +++++++++++++-------
.../topicmodel/IncrementalPLSAModel.java | 16 ++++++------
.../main/java/hivemall/topicmodel/LDAUDTF.java | 5 +---
.../hivemall/topicmodel/OnlineLDAModel.java | 18 +++++++-------
.../main/java/hivemall/topicmodel/PLSAUDTF.java | 5 +---
.../ProbabilisticTopicModelBaseUDTF.java | 25 ++++++++++++-------
6 files changed, 52 insertions(+), 43 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
index 3c097e2..1b7f3e8 100644
--- a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
+++ b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
@@ -21,9 +21,14 @@ package hivemall.topicmodel;
import hivemall.annotations.VisibleForTesting;
import hivemall.model.FeatureValue;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
-import java.util.*;
public abstract class AbstractProbabilisticTopicModel {
@@ -31,6 +36,7 @@ public abstract class AbstractProbabilisticTopicModel {
protected final int _K;
// total number of documents
+ @Nonnegative
protected long _D;
// for mini-batch
@@ -38,7 +44,7 @@ public abstract class AbstractProbabilisticTopicModel {
protected final List<Map<String, Float>> _miniBatchDocs;
protected int _miniBatchSize;
- public AbstractProbabilisticTopicModel(int K) {
+ public AbstractProbabilisticTopicModel(@Nonnegative int K) {
this._K = K;
this._D = 0L;
this._miniBatchDocs = new ArrayList<Map<String, Float>>();
@@ -73,26 +79,28 @@ public abstract class AbstractProbabilisticTopicModel {
}
}
- public void accumulateDocCount() {
+ protected void accumulateDocCount() {
this._D += 1;
}
- public long getDocCount() {
+ @Nonnegative
+ protected long getDocCount() {
return _D;
}
- public abstract void train(@Nonnull final String[][] miniBatch);
+ protected abstract void train(@Nonnull final String[][] miniBatch);
- public abstract float computePerplexity();
+ protected abstract float computePerplexity();
@Nonnull
- public abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
+ protected abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
@Nonnull
- public abstract float[] getTopicDistribution(@Nonnull final String[] doc);
+ protected abstract float[] getTopicDistribution(@Nonnull final String[] doc);
@VisibleForTesting
abstract float getWordScore(@Nonnull final String word, @Nonnegative final int topic);
- public abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, final float score);
+ protected abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic,
+ final float score);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index b99e670..6419664 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -20,7 +20,6 @@ package hivemall.topicmodel;
import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
import static hivemall.utils.math.MathUtils.l1normalize;
-
import hivemall.annotations.VisibleForTesting;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
@@ -59,9 +58,10 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair
// optimized in the M step
- @Nonnull
private List<float[]> _p_dz; // P(z|d) probability of topics for documents
- private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
+
+ @Nonnull
+ private final Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
public IncrementalPLSAModel(int K, float alpha, double delta) {
super(K);
@@ -74,7 +74,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
this._p_zw = new HashMap<String, float[]>();
}
- public void train(@Nonnull final String[][] miniBatch) {
+ protected void train(@Nonnull final String[][] miniBatch) {
initMiniBatch(miniBatch, _miniBatchDocs);
this._miniBatchSize = _miniBatchDocs.size();
@@ -211,7 +211,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
return (diff / _K) < _delta;
}
- public float computePerplexity() {
+ protected float computePerplexity() {
double numer = 0.d;
double denom = 0.d;
@@ -241,7 +241,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
}
@Nonnull
- public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
+ protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
final SortedMap<Float, List<String>> res = new TreeMap<Float, List<String>>(
Collections.reverseOrder());
@@ -261,7 +261,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
}
@Nonnull
- public float[] getTopicDistribution(@Nonnull final String[] doc) {
+ protected float[] getTopicDistribution(@Nonnull final String[] doc) {
train(new String[][] {doc});
return _p_dz.get(0);
}
@@ -271,7 +271,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
return _p_zw.get(w)[z];
}
- public void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
+ protected void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
float[] prob_label = _p_zw.get(w);
if (prob_label == null) {
prob_label = newRandomFloatArray(_K, _rnd);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 41386a4..9bac908 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])"
+ " - Returns a relation consists of <int topic, string word, float score>")
-public class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
- private static final Log logger = LogFactory.getLog(LDAUDTF.class);
+public final class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
public static final double DEFAULT_DELTA = 1E-3d;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 4a7531c..6a8d6db 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -38,6 +38,9 @@ import org.apache.commons.math3.special.Gamma;
public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
+ private static final double SHAPE = 100.d;
+ private static final double SCALE = 1.d / SHAPE;
+
// ---------------------------------
// HyperParameters
@@ -72,7 +75,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
private final boolean _isAutoD;
// parameters
- @Nonnull
private List<Map<String, float[]>> _phi;
private float[][] _gamma;
@Nonnull
@@ -81,8 +83,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
// random number generator
@Nonnull
private final GammaDistribution _gd;
- private static final double SHAPE = 100.d;
- private static final double SCALE = 1.d / SHAPE;
// for computing perplexity
private float _docRatio = 1.f;
@@ -121,7 +121,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
}
@Override
- public void accumulateDocCount() {
+ protected void accumulateDocCount() {
/*
* In a truly online setting, total number of documents equals to the number of documents that have ever seen.
* In that case, users need to manually set the current max number of documents via this method.
@@ -133,7 +133,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
}
}
- public void train(@Nonnull final String[][] miniBatch) {
+ protected void train(@Nonnull final String[][] miniBatch) {
preprocessMiniBatch(miniBatch);
initParams(true);
@@ -341,7 +341,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
/**
* Calculate approximate perplexity for the current mini-batch.
*/
- public float computePerplexity() {
+ protected float computePerplexity() {
double bound = computeApproxBound();
double perWordBound = bound / (_docRatio * _valueSum);
return (float) Math.exp(-1.d * perWordBound);
@@ -449,7 +449,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
return lambda_label[k];
}
- public void setWordScore(@Nonnull final String label, @Nonnegative final int k,
+ protected void setWordScore(@Nonnull final String label, @Nonnegative final int k,
final float lambda_k) {
float[] lambda_label = _lambda.get(label);
if (lambda_label == null) {
@@ -460,7 +460,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
}
@Nonnull
- public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
+ protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
return getTopicWords(k, _lambda.keySet().size());
}
@@ -501,7 +501,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
}
@Nonnull
- public float[] getTopicDistribution(@Nonnull final String[] doc) {
+ protected float[] getTopicDistribution(@Nonnull final String[] doc) {
preprocessMiniBatch(new String[][] {doc});
initParams(false);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index e1d8797..9c5a0ea 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])"
+ " - Returns a relation consists of <int topic, string word, float score>")
-public class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
- private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
+public final class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
public static final float DEFAULT_ALPHA = 0.5f;
public static final double DEFAULT_DELTA = 1E-3d;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
index cff076e..c3dab89 100644
--- a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
@@ -27,6 +27,19 @@ import hivemall.utils.io.NioStatefullSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
@@ -44,13 +57,6 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.*;
-
public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);
@@ -143,6 +149,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
+ @Nonnull
protected abstract AbstractProbabilisticTopicModel createModel();
@Override
@@ -157,7 +164,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
for (int i = 0; i < length; i++) {
Object o = wordCountsOI.getListElement(args[0], i);
if (o == null) {
- throw new HiveException("Given feature vector contains invalid elements");
+ throw new HiveException("Given feature vector contains invalid null elements");
}
String s = o.toString();
wordCounts[j] = s;
@@ -167,7 +174,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
return;
}
- model.accumulateDocCount();;
+ model.accumulateDocCount();
update(wordCounts);