You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/20 12:02:40 UTC
[5/5] incubator-hivemall git commit: Refactored LDA implementation
Refactored LDA implementation
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e4e1531e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e4e1531e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e4e1531e
Branch: refs/heads/master
Commit: e4e1531e16de51bb934910ec7269895ca51fab4f
Parents: 9669c9d
Author: myui <yu...@gmail.com>
Authored: Thu Apr 20 21:01:36 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Thu Apr 20 21:01:36 2017 +0900
----------------------------------------------------------------------
.../main/java/hivemall/topicmodel/LDAUDTF.java | 7 +-
.../hivemall/topicmodel/OnlineLDAModel.java | 161 +++++++++----------
.../java/hivemall/utils/math/MathUtils.java | 84 +++++++---
3 files changed, 148 insertions(+), 104 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e4e1531e/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 91ee7a2..9aa15e2 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -196,8 +196,8 @@ public class LDAUDTF extends UDTFWithOptions {
initModel();
}
- int length = wordCountsOI.getListLength(args[0]);
- String[] wordCounts = new String[length];
+ final int length = wordCountsOI.getListLength(args[0]);
+ final String[] wordCounts = new String[length];
int j = 0;
for (int i = 0; i < length; i++) {
Object o = wordCountsOI.getListElement(args[0], i);
@@ -208,6 +208,9 @@ public class LDAUDTF extends UDTFWithOptions {
wordCounts[j] = s;
j++;
}
+ if (j == 0) {// avoid empty documents
+ return;
+ }
count++;
if (isAutoD) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e4e1531e/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 890adac..8fef10c 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -90,12 +90,12 @@ public final class OnlineLDAModel {
// for mini-batch
@Nonnull
- private final List<Map<String, Float>> _miniBatchMap;
+ private final List<Map<String, Float>> _miniBatchDocs;
private int _miniBatchSize;
// for computing perplexity
private float _docRatio = 1.f;
- private long _wordCount = 0L;
+ private double _valueSum = 0.d;
public OnlineLDAModel(int K, float alpha, double delta) { // for E step only instantiation
this(K, alpha, 1 / 20.f, -1L, 1020, 0.7, delta);
@@ -125,15 +125,13 @@ public final class OnlineLDAModel {
// initialize the parameters
this._lambda = new HashMap<String, float[]>(100);
- this._miniBatchMap = new ArrayList<Map<String, Float>>();
+ this._miniBatchDocs = new ArrayList<Map<String, Float>>();
}
/**
- * In a truly online setting, total number of documents corresponds to the number of documents
- * that have ever seen. In that case, users need to manually set the current max number of documents
- * via this method.
- * Note that, since the same set of documents could be repeatedly passed to `train()`,
- * simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
+ * In a truly online setting, the total number of documents corresponds to the number of documents that have ever been seen. In that case, users need to
+ * manually set the current max number of documents via this method. Note that, since the same set of documents could be repeatedly passed to
+ * `train()`, simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
*/
public void setNumTotalDocs(@Nonnegative long D) {
this._D = D;
@@ -161,34 +159,35 @@ public final class OnlineLDAModel {
}
private void preprocessMiniBatch(@Nonnull final String[][] miniBatch) {
- initMiniBatchMap(miniBatch, _miniBatchMap);
+ initMiniBatch(miniBatch, _miniBatchDocs);
- this._miniBatchSize = _miniBatchMap.size();
+ this._miniBatchSize = _miniBatchDocs.size();
// accumulate the number of words for each document
- this._wordCount = 0L;
+ double valueSum = 0.d;
for (int d = 0; d < _miniBatchSize; d++) {
- for (float n : _miniBatchMap.get(d).values()) {
- this._wordCount += n;
+ for (Float n : _miniBatchDocs.get(d).values()) {
+ valueSum += n.floatValue();
}
}
+ this._valueSum = valueSum;
this._docRatio = (float) ((double) _D / _miniBatchSize);
}
- private static void initMiniBatchMap(@Nonnull final String[][] miniBatch,
- @Nonnull final List<Map<String, Float>> map) {
- map.clear();
+ private static void initMiniBatch(@Nonnull final String[][] miniBatch,
+ @Nonnull final List<Map<String, Float>> docs) {
+ docs.clear();
final FeatureValue probe = new FeatureValue();
// parse document
for (final String[] e : miniBatch) {
- if (e == null) {
+ if (e == null || e.length == 0) {
continue;
}
- final Map<String, Float> docMap = new HashMap<String, Float>();
+ final Map<String, Float> doc = new HashMap<String, Float>();
// parse features
for (String fv : e) {
@@ -198,10 +197,10 @@ public final class OnlineLDAModel {
FeatureValue.parseFeatureAsString(fv, probe);
String label = probe.getFeatureAsString();
float value = probe.getValueAsFloat();
- docMap.put(label, value);
+ doc.put(label, Float.valueOf(value));
}
- map.add(docMap);
+ docs.add(doc);
}
}
@@ -218,7 +217,7 @@ public final class OnlineLDAModel {
final Map<String, float[]> phi_d = new HashMap<String, float[]>();
phi.add(phi_d);
- for (final String label : _miniBatchMap.get(d).keySet()) {
+ for (final String label : _miniBatchDocs.get(d).keySet()) {
phi_d.put(label, new float[_K]);
if (!_lambda.containsKey(label)) { // lambda for newly observed word
_lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd));
@@ -233,19 +232,19 @@ public final class OnlineLDAModel {
private void eStep() {
// since lambda is invariant in the expectation step,
// `digamma`s of lambda values for Elogbeta are pre-computed
- final float[] lambdaSum = new float[_K];
+ final double[] lambdaSum = new double[_K];
final Map<String, float[]> digamma_lambda = new HashMap<String, float[]>();
for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
String label = e.getKey();
float[] lambda_label = e.getValue();
// for digamma(lambdaSum)
- MathUtils.add(lambdaSum, lambda_label, _K);
+ MathUtils.add(lambda_label, lambdaSum, _K);
digamma_lambda.put(label, MathUtils.digamma(lambda_label));
}
- final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+ final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
// for each of mini-batch documents, update gamma until convergence
float[] gamma_d, gammaPrev_d;
Map<String, float[]> eLogBeta_d;
@@ -265,11 +264,11 @@ public final class OnlineLDAModel {
@Nonnull
private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative final int d,
@Nonnull final Map<String, float[]> digamma_lambda,
- @Nonnull final float[] digamma_lambdaSum) {
- // Dirichlet expectation (2d) for lambda
- final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>();
- final Map<String, Float> doc = _miniBatchMap.get(d);
+ @Nonnull final double[] digamma_lambdaSum) {
+ final Map<String, Float> doc = _miniBatchDocs.get(d);
+ // Dirichlet expectation (2d) for lambda
+ final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(doc.size());
for (final String label : doc.keySet()) {
float[] eLogBeta_label = eLogBeta_d.get(label);
if (eLogBeta_label == null) {
@@ -278,7 +277,7 @@ public final class OnlineLDAModel {
}
final float[] digamma_lambda_label = digamma_lambda.get(label);
for (int k = 0; k < _K; k++) {
- eLogBeta_label[k] = digamma_lambda_label[k] - digamma_lambdaSum[k];
+ eLogBeta_label[k] = (float) (digamma_lambda_label[k] - digamma_lambdaSum[k]);
}
}
@@ -288,28 +287,27 @@ public final class OnlineLDAModel {
private void updatePhiPerDoc(@Nonnegative final int d,
@Nonnull final Map<String, float[]> eLogBeta_d) {
// Dirichlet expectation (2d) for gamma
- final float[] eLogTheta_d = new float[_K];
final float[] gamma_d = _gamma[d];
- final float digamma_gammaSum_d = (float) Gamma.digamma(MathUtils.sum(gamma_d));
+ final double digamma_gammaSum_d = Gamma.digamma(MathUtils.sum(gamma_d));
+ final double[] eLogTheta_d = new double[_K];
for (int k = 0; k < _K; k++) {
- eLogTheta_d[k] = (float) Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
+ eLogTheta_d[k] = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
}
// updating phi w/ normalization
final Map<String, float[]> phi_d = _phi.get(d);
- final Map<String, Float> doc = _miniBatchMap.get(d);
+ final Map<String, Float> doc = _miniBatchDocs.get(d);
for (String label : doc.keySet()) {
final float[] phi_label = phi_d.get(label);
final float[] eLogBeta_label = eLogBeta_d.get(label);
- float normalizer = 0.f;
+ double normalizer = 0.d;
for (int k = 0; k < _K; k++) {
float phiVal = (float) Math.exp(eLogBeta_label[k] + eLogTheta_d[k]) + 1E-20f;
phi_label[k] = phiVal;
normalizer += phiVal;
}
- // normalize
for (int k = 0; k < _K; k++) {
phi_label[k] /= normalizer;
}
@@ -317,7 +315,7 @@ public final class OnlineLDAModel {
}
private void updateGammaPerDoc(@Nonnegative final int d) {
- final Map<String, Float> doc = _miniBatchMap.get(d);
+ final Map<String, Float> doc = _miniBatchDocs.get(d);
final Map<String, float[]> phi_d = _phi.get(d);
final float[] gamma_d = _gamma[d];
@@ -326,7 +324,7 @@ public final class OnlineLDAModel {
}
for (Map.Entry<String, Float> e : doc.entrySet()) {
final float[] phi_label = phi_d.get(e.getKey());
- final float val = e.getValue();
+ final float val = e.getValue().floatValue();
for (int k = 0; k < _K; k++) {
gamma_d[k] += phi_label[k] * val;
}
@@ -347,7 +345,7 @@ public final class OnlineLDAModel {
final Map<String, float[]> lambdaTilde = new HashMap<String, float[]>();
for (int d = 0; d < _miniBatchSize; d++) {
final Map<String, float[]> phi_d = _phi.get(d);
- for (String label : _miniBatchMap.get(d).keySet()) {
+ for (String label : _miniBatchDocs.get(d).keySet()) {
float[] lambdaTilde_label = lambdaTilde.get(label);
if (lambdaTilde_label == null) {
lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
@@ -382,73 +380,67 @@ public final class OnlineLDAModel {
* Calculate approximate perplexity for the current mini-batch.
*/
public float computePerplexity() {
- float bound = computeApproxBound();
- float perWordBound = bound / (_docRatio * _wordCount);
- return (float) Math.exp(-1.f * perWordBound);
+ double bound = computeApproxBound();
+ double perWordBound = bound / (_docRatio * _valueSum);
+ return (float) Math.exp(-1.d * perWordBound);
}
/**
* Estimates the variational bound over all documents using only the documents passed as mini-batch.
*/
- private float computeApproxBound() {
+ private double computeApproxBound() {
// prepare
- final float[] gammaSum = new float[_miniBatchSize];
+ final double[] gammaSum = new double[_miniBatchSize];
for (int d = 0; d < _miniBatchSize; d++) {
gammaSum[d] = MathUtils.sum(_gamma[d]);
}
- final float[] digamma_gammaSum = MathUtils.digamma(gammaSum);
+ final double[] digamma_gammaSum = MathUtils.digamma(gammaSum);
- final float[] lambdaSum = new float[_K];
+ final double[] lambdaSum = new double[_K];
for (float[] lambda_label : _lambda.values()) {
- MathUtils.add(lambdaSum, lambda_label, _K);
+ MathUtils.add(lambda_label, lambdaSum, _K);
}
- final float[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
+ final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);
- final float logGamma_alpha = (float) Gamma.logGamma(_alpha);
- final float logGamma_alphaSum = (float) Gamma.logGamma(_K * _alpha);
+ final double logGamma_alpha = Gamma.logGamma(_alpha);
+ final double logGamma_alphaSum = Gamma.logGamma(_K * _alpha);
- float score = 0.f;
+ double score = 0.d;
for (int d = 0; d < _miniBatchSize; d++) {
- final float digamma_gammaSum_d = digamma_gammaSum[d];
+ final double digamma_gammaSum_d = digamma_gammaSum[d];
+ final float[] gamma_d = _gamma[d];
// E[log p(doc | theta, beta)]
- for (Map.Entry<String, Float> e : _miniBatchMap.get(d).entrySet()) {
+ for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) {
final float[] lambda_label = _lambda.get(e.getKey());
// logsumexp( Elogthetad + Elogbetad )
- final float[] temp = new float[_K];
- float max = Float.MIN_VALUE;
+ final double[] temp = new double[_K];
+ double max = Double.MIN_VALUE;
for (int k = 0; k < _K; k++) {
- final float eLogTheta_dk = (float) Gamma.digamma(_gamma[d][k])
- - digamma_gammaSum_d;
- final float eLogBeta_kw = (float) Gamma.digamma(lambda_label[k])
- - digamma_lambdaSum[k];
-
- temp[k] = eLogTheta_dk + eLogBeta_kw;
- if (temp[k] > max) {
- max = temp[k];
+ double eLogTheta_dk = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
+ double eLogBeta_kw = Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k];
+ final double tempK = eLogTheta_dk + eLogBeta_kw;
+ if (tempK > max) {
+ max = tempK;
}
+ temp[k] = tempK;
}
- float logsumexp = 0.f;
- for (int k = 0; k < _K; k++) {
- logsumexp += (float) Math.exp(temp[k] - max);
- }
- logsumexp = max + (float) Math.log(logsumexp);
+ double logsumexp = MathUtils.logsumexp(temp, max);
// sum( word count * logsumexp(...) )
- score += e.getValue() * logsumexp;
+ score += e.getValue().floatValue() * logsumexp;
}
// E[log p(theta | alpha) - log q(theta | gamma)]
for (int k = 0; k < _K; k++) {
- final float gamma_dk = _gamma[d][k];
+ float gamma_dk = gamma_d[k];
// sum( (alpha - gammad) * Elogthetad )
- score += (_alpha - gamma_dk)
- * ((float) Gamma.digamma(gamma_dk) - digamma_gammaSum_d);
+ score += (_alpha - gamma_dk) * (Gamma.digamma(gamma_dk) - digamma_gammaSum_d);
// sum( gammaln(gammad) - gammaln(alpha) )
- score += (float) Gamma.logGamma(gamma_dk) - logGamma_alpha;
+ score += Gamma.logGamma(gamma_dk) - logGamma_alpha;
}
score += logGamma_alphaSum; // gammaln(sum(alpha))
score -= Gamma.logGamma(gammaSum[d]); // gammaln(sum(gammad))
@@ -458,25 +450,25 @@ public final class OnlineLDAModel {
// (i.e., online setting); likelihood should be always roughly on the same scale
score *= _docRatio;
- final float logGamma_eta = (float) Gamma.logGamma(_eta);
- final float logGamma_etaSum = (float) Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta
+ final double logGamma_eta = Gamma.logGamma(_eta);
+ final double logGamma_etaSum = Gamma.logGamma(_eta * _lambda.size()); // vocabulary size * eta
// E[log p(beta | eta) - log q (beta | lambda)]
- for (float[] lambda_label : _lambda.values()) {
+ for (final float[] lambda_label : _lambda.values()) {
for (int k = 0; k < _K; k++) {
- final float lambda_k = lambda_label[k];
+ float lambda_label_k = lambda_label[k];
// sum( (eta - lambda) * Elogbeta )
- score += (_eta - lambda_k)
- * (float) (Gamma.digamma(lambda_k) - digamma_lambdaSum[k]);
+ score += (_eta - lambda_label_k)
+ * (Gamma.digamma(lambda_label_k) - digamma_lambdaSum[k]);
// sum( gammaln(lambda) - gammaln(eta) )
- score += (float) Gamma.logGamma(lambda_k) - logGamma_eta;
+ score += Gamma.logGamma(lambda_label_k) - logGamma_eta;
}
}
for (int k = 0; k < _K; k++) {
// sum( gammaln(etaSum) - gammaln( lambdaSum_k )
- score += logGamma_etaSum - (float) Gamma.logGamma(lambdaSum[k]);
+ score += logGamma_etaSum - Gamma.logGamma(lambdaSum[k]);
}
return score;
@@ -513,7 +505,7 @@ public final class OnlineLDAModel {
@Nonnull
public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k,
@Nonnegative int topN) {
- float lambdaSum = 0.f;
+ double lambdaSum = 0.d;
final SortedMap<Float, List<String>> sortedLambda = new TreeMap<Float, List<String>>(
Collections.reverseOrder());
@@ -535,7 +527,8 @@ public final class OnlineLDAModel {
topN = Math.min(topN, _lambda.keySet().size());
int tt = 0;
for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) {
- ret.put(e.getKey() / lambdaSum, e.getValue());
+ float key = (float) (e.getKey().floatValue() / lambdaSum);
+ ret.put(Float.valueOf(key), e.getValue());
if (++tt == topN) {
break;
@@ -556,9 +549,9 @@ public final class OnlineLDAModel {
// normalize topic distribution
final float[] topicDistr = new float[_K];
final float[] gamma0 = _gamma[0];
- final float gammaSum = MathUtils.sum(gamma0);
+ final double gammaSum = MathUtils.sum(gamma0);
for (int k = 0; k < _K; k++) {
- topicDistr[k] = gamma0[k] / gammaSum;
+ topicDistr[k] = (float) (gamma0[k] / gammaSum);
}
return topicDistr;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e4e1531e/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 7fdea55..71d0270 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -314,44 +314,92 @@ public final class MathUtils {
return perm;
}
- public static float sum(@Nullable final float[] a) {
- if (a == null) {
- return 0.f;
+ public static double sum(@Nullable final float[] arr) {
+ if (arr == null) {
+ return 0.d;
}
- float sum = 0.f;
- for (float v : a) {
+ double sum = 0.d;
+ for (float v : arr) {
sum += v;
}
return sum;
}
- public static float sum(@Nullable final float[] a, @Nonnegative final int size) {
- if (a == null) {
- return 0.f;
- }
-
- float sum = 0.f;
+ public static void add(@Nonnull final float[] src, @Nonnull final float[] dst, final int size) {
for (int i = 0; i < size; i++) {
- sum += a[i];
+ dst[i] += src[i];
}
- return sum;
}
- public static void add(@Nonnull final float[] dst, @Nonnull final float[] toAdd, final int size) {
+ public static void add(@Nonnull final float[] src, @Nonnull final double[] dst, final int size) {
for (int i = 0; i < size; i++) {
- dst[i] += toAdd[i];
+ dst[i] += src[i];
}
}
@Nonnull
- public static float[] digamma(@Nonnull final float[] a) {
- final int k = a.length;
+ public static float[] digamma(@Nonnull final float[] arr) {
+ final int k = arr.length;
final float[] ret = new float[k];
for (int i = 0; i < k; i++) {
- ret[i] = (float) Gamma.digamma(a[i]);
+ ret[i] = (float) Gamma.digamma(arr[i]);
}
return ret;
}
+ @Nonnull
+ public static double[] digamma(@Nonnull final double[] arr) {
+ final int k = arr.length;
+ final double[] ret = new double[k];
+ for (int i = 0; i < k; i++) {
+ ret[i] = Gamma.digamma(arr[i]);
+ }
+ return ret;
+ }
+
+ public static float logsumexp(@Nonnull final float[] arr) {
+ if (arr.length == 0) {
+ return 0.f;
+ }
+ float max = 0.f;
+ for (final float v : arr) {
+ if (v > max) {
+ max = v;
+ }
+ }
+ return logsumexp(arr, max);
+ }
+
+ public static float logsumexp(@Nonnull final float[] arr, final float max) {
+ double logsumexp = 0.d;
+ for (final float v : arr) {
+ logsumexp += Math.exp(v - max);
+ }
+ logsumexp = Math.log(logsumexp) + max;
+ return (float) logsumexp;
+ }
+
+ public static double logsumexp(@Nonnull final double[] arr) {
+ if (arr.length == 0) {
+ return 0.d;
+ }
+ double max = 0.d;
+ for (final double v : arr) {
+ if (v > max) {
+ max = v;
+ }
+ }
+ return logsumexp(arr, max);
+ }
+
+ public static double logsumexp(@Nonnull final double[] arr, final double max) {
+ double logsumexp = 0.d;
+ for (final double v : arr) {
+ logsumexp += Math.exp(v - max);
+ }
+ logsumexp = Math.log(logsumexp) + max;
+ return logsumexp;
+ }
+
}