You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/06/27 06:44:59 UTC
[1/3] incubator-hivemall git commit: Close #89: [HIVEMALL-120]
Refactor on LDA/pLSA's mini-batch & buffered iteration logic
Repository: incubator-hivemall
Updated Branches:
refs/heads/master bfc5b75b0 -> 9f01ebf20
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
index addacbc..e5045a5 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAUDTFTest.java
@@ -54,7 +54,7 @@ public class PLSAUDTFTest {
udtf.process(new Object[] {Arrays.asList(doc1)});
udtf.process(new Object[] {Arrays.asList(doc2)});
- udtf.closeWithoutModelReset();
+ udtf.finalizeTraining();
SortedMap<Float, List<String>> topicWords;
@@ -93,10 +93,10 @@ public class PLSAUDTFTest {
Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
- udtf.getProbability("vegetables", k1) > udtf.getProbability("flu", k1));
+ udtf.getWordScore("vegetables", k1) > udtf.getWordScore("flu", k1));
Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
- udtf.getProbability("avocados", k2) > udtf.getProbability("healthy", k2));
+ udtf.getWordScore("avocados", k2) > udtf.getWordScore("healthy", k2));
}
@Test
@@ -107,7 +107,7 @@ public class PLSAUDTFTest {
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
ObjectInspectorUtils.getConstantObjectInspector(
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- "-topics 2 -alpha 0.1 -delta 0.00001 -iter 10000")};
+ "-topics 2 -alpha 0.1 -delta 0.00001 -iter 10000 -mini_batch_size 1")};
udtf.initialize(argOIs);
@@ -117,7 +117,7 @@ public class PLSAUDTFTest {
udtf.process(new Object[] {Arrays.asList(doc1)});
udtf.process(new Object[] {Arrays.asList(doc2)});
- udtf.closeWithoutModelReset();
+ udtf.finalizeTraining();
SortedMap<Float, List<String>> topicWords;
@@ -156,10 +156,10 @@ public class PLSAUDTFTest {
Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ "and `野菜` SHOULD be more suitable topic word than `インフルエンザ` in the topic",
- udtf.getProbability("野菜", k1) > udtf.getProbability("インフルエンザ", k1));
+ udtf.getWordScore("野菜", k1) > udtf.getWordScore("インフルエンザ", k1));
Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ "and `アボカド` SHOULD be more suitable topic word than `健康` in the topic",
- udtf.getProbability("アボカド", k2) > udtf.getProbability("健康", k2));
+ udtf.getWordScore("アボカド", k2) > udtf.getWordScore("健康", k2));
}
private static void println(String msg) {
[3/3] incubator-hivemall git commit: Applied refactoring for
topicmodel module
Posted by my...@apache.org.
Applied refactoring for topicmodel module
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9f01ebf2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9f01ebf2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9f01ebf2
Branch: refs/heads/master
Commit: 9f01ebf20c74559be8a50d459103118a51c229bf
Parents: 0495ffa
Author: Makoto Yui <my...@apache.org>
Authored: Tue Jun 27 15:44:31 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Tue Jun 27 15:44:31 2017 +0900
----------------------------------------------------------------------
.../AbstractProbabilisticTopicModel.java | 26 +++++++++++++-------
.../topicmodel/IncrementalPLSAModel.java | 16 ++++++------
.../main/java/hivemall/topicmodel/LDAUDTF.java | 5 +---
.../hivemall/topicmodel/OnlineLDAModel.java | 18 +++++++-------
.../main/java/hivemall/topicmodel/PLSAUDTF.java | 5 +---
.../ProbabilisticTopicModelBaseUDTF.java | 25 ++++++++++++-------
6 files changed, 52 insertions(+), 43 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
index 3c097e2..1b7f3e8 100644
--- a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
+++ b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
@@ -21,9 +21,14 @@ package hivemall.topicmodel;
import hivemall.annotations.VisibleForTesting;
import hivemall.model.FeatureValue;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
-import java.util.*;
public abstract class AbstractProbabilisticTopicModel {
@@ -31,6 +36,7 @@ public abstract class AbstractProbabilisticTopicModel {
protected final int _K;
// total number of documents
+ @Nonnegative
protected long _D;
// for mini-batch
@@ -38,7 +44,7 @@ public abstract class AbstractProbabilisticTopicModel {
protected final List<Map<String, Float>> _miniBatchDocs;
protected int _miniBatchSize;
- public AbstractProbabilisticTopicModel(int K) {
+ public AbstractProbabilisticTopicModel(@Nonnegative int K) {
this._K = K;
this._D = 0L;
this._miniBatchDocs = new ArrayList<Map<String, Float>>();
@@ -73,26 +79,28 @@ public abstract class AbstractProbabilisticTopicModel {
}
}
- public void accumulateDocCount() {
+ protected void accumulateDocCount() {
this._D += 1;
}
- public long getDocCount() {
+ @Nonnegative
+ protected long getDocCount() {
return _D;
}
- public abstract void train(@Nonnull final String[][] miniBatch);
+ protected abstract void train(@Nonnull final String[][] miniBatch);
- public abstract float computePerplexity();
+ protected abstract float computePerplexity();
@Nonnull
- public abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
+ protected abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
@Nonnull
- public abstract float[] getTopicDistribution(@Nonnull final String[] doc);
+ protected abstract float[] getTopicDistribution(@Nonnull final String[] doc);
@VisibleForTesting
abstract float getWordScore(@Nonnull final String word, @Nonnegative final int topic);
- public abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, final float score);
+ protected abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic,
+ final float score);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index b99e670..6419664 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -20,7 +20,6 @@ package hivemall.topicmodel;
import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
import static hivemall.utils.math.MathUtils.l1normalize;
-
import hivemall.annotations.VisibleForTesting;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
@@ -59,9 +58,10 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair
// optimized in the M step
- @Nonnull
private List<float[]> _p_dz; // P(z|d) probability of topics for documents
- private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
+
+ @Nonnull
+ private final Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
public IncrementalPLSAModel(int K, float alpha, double delta) {
super(K);
@@ -74,7 +74,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
this._p_zw = new HashMap<String, float[]>();
}
- public void train(@Nonnull final String[][] miniBatch) {
+ protected void train(@Nonnull final String[][] miniBatch) {
initMiniBatch(miniBatch, _miniBatchDocs);
this._miniBatchSize = _miniBatchDocs.size();
@@ -211,7 +211,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
return (diff / _K) < _delta;
}
- public float computePerplexity() {
+ protected float computePerplexity() {
double numer = 0.d;
double denom = 0.d;
@@ -241,7 +241,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
}
@Nonnull
- public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
+ protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) {
final SortedMap<Float, List<String>> res = new TreeMap<Float, List<String>>(
Collections.reverseOrder());
@@ -261,7 +261,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
}
@Nonnull
- public float[] getTopicDistribution(@Nonnull final String[] doc) {
+ protected float[] getTopicDistribution(@Nonnull final String[] doc) {
train(new String[][] {doc});
return _p_dz.get(0);
}
@@ -271,7 +271,7 @@ public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel
return _p_zw.get(w)[z];
}
- public void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
+ protected void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
float[] prob_label = _p_zw.get(w);
if (prob_label == null) {
prob_label = newRandomFloatArray(_K, _rnd);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index 41386a4..9bac908 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])"
+ " - Returns a relation consists of <int topic, string word, float score>")
-public class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
- private static final Log logger = LogFactory.getLog(LDAUDTF.class);
+public final class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
public static final double DEFAULT_DELTA = 1E-3d;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 4a7531c..6a8d6db 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -38,6 +38,9 @@ import org.apache.commons.math3.special.Gamma;
public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
+ private static final double SHAPE = 100.d;
+ private static final double SCALE = 1.d / SHAPE;
+
// ---------------------------------
// HyperParameters
@@ -72,7 +75,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
private final boolean _isAutoD;
// parameters
- @Nonnull
private List<Map<String, float[]>> _phi;
private float[][] _gamma;
@Nonnull
@@ -81,8 +83,6 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
// random number generator
@Nonnull
private final GammaDistribution _gd;
- private static final double SHAPE = 100.d;
- private static final double SCALE = 1.d / SHAPE;
// for computing perplexity
private float _docRatio = 1.f;
@@ -121,7 +121,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
}
@Override
- public void accumulateDocCount() {
+ protected void accumulateDocCount() {
/*
* In a truly online setting, total number of documents equals to the number of documents that have ever seen.
* In that case, users need to manually set the current max number of documents via this method.
@@ -133,7 +133,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
}
}
- public void train(@Nonnull final String[][] miniBatch) {
+ protected void train(@Nonnull final String[][] miniBatch) {
preprocessMiniBatch(miniBatch);
initParams(true);
@@ -341,7 +341,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
/**
* Calculate approximate perplexity for the current mini-batch.
*/
- public float computePerplexity() {
+ protected float computePerplexity() {
double bound = computeApproxBound();
double perWordBound = bound / (_docRatio * _valueSum);
return (float) Math.exp(-1.d * perWordBound);
@@ -449,7 +449,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
return lambda_label[k];
}
- public void setWordScore(@Nonnull final String label, @Nonnegative final int k,
+ protected void setWordScore(@Nonnull final String label, @Nonnegative final int k,
final float lambda_k) {
float[] lambda_label = _lambda.get(label);
if (lambda_label == null) {
@@ -460,7 +460,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
}
@Nonnull
- public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
+ protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k) {
return getTopicWords(k, _lambda.keySet().size());
}
@@ -501,7 +501,7 @@ public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
}
@Nonnull
- public float[] getTopicDistribution(@Nonnull final String[] doc) {
+ protected float[] getTopicDistribution(@Nonnull final String[] doc) {
preprocessMiniBatch(new String[][] {doc});
initParams(false);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index e1d8797..9c5a0ea 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -22,16 +22,13 @@ import hivemall.utils.lang.Primitives;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])"
+ " - Returns a relation consists of <int topic, string word, float score>")
-public class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
- private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
+public final class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
public static final float DEFAULT_ALPHA = 0.5f;
public static final double DEFAULT_DELTA = 1E-3d;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9f01ebf2/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
index cff076e..c3dab89 100644
--- a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
@@ -27,6 +27,19 @@ import hivemall.utils.io.NioStatefullSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
@@ -44,13 +57,6 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.*;
-
public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);
@@ -143,6 +149,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
+ @Nonnull
protected abstract AbstractProbabilisticTopicModel createModel();
@Override
@@ -157,7 +164,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
for (int i = 0; i < length; i++) {
Object o = wordCountsOI.getListElement(args[0], i);
if (o == null) {
- throw new HiveException("Given feature vector contains invalid elements");
+ throw new HiveException("Given feature vector contains invalid null elements");
}
String s = o.toString();
wordCounts[j] = s;
@@ -167,7 +174,7 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
return;
}
- model.accumulateDocCount();;
+ model.accumulateDocCount();
update(wordCounts);
[2/3] incubator-hivemall git commit: Close #89: [HIVEMALL-120]
Refactor on LDA/pLSA's mini-batch & buffered iteration logic
Posted by my...@apache.org.
Close #89: [HIVEMALL-120] Refactor on LDA/pLSA's mini-batch & buffered iteration logic
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/0495ffad
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/0495ffad
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/0495ffad
Branch: refs/heads/master
Commit: 0495ffadbc42bffa36cb583622708ae1fa65a44e
Parents: bfc5b75
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Tue Jun 27 13:53:45 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Tue Jun 27 13:53:45 2017 +0900
----------------------------------------------------------------------
.../AbstractProbabilisticTopicModel.java | 98 ++++
.../topicmodel/IncrementalPLSAModel.java | 51 +-
.../hivemall/topicmodel/LDAPredictUDAF.java | 2 +-
.../main/java/hivemall/topicmodel/LDAUDTF.java | 503 +------------------
.../hivemall/topicmodel/OnlineLDAModel.java | 82 +--
.../hivemall/topicmodel/PLSAPredictUDAF.java | 2 +-
.../main/java/hivemall/topicmodel/PLSAUDTF.java | 490 +-----------------
.../ProbabilisticTopicModelBaseUDTF.java | 487 ++++++++++++++++++
.../topicmodel/IncrementalPLSAModelTest.java | 8 +-
.../java/hivemall/topicmodel/LDAUDTFTest.java | 14 +-
.../hivemall/topicmodel/OnlineLDAModelTest.java | 4 +-
.../java/hivemall/topicmodel/PLSAUDTFTest.java | 14 +-
12 files changed, 653 insertions(+), 1102 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
new file mode 100644
index 0000000..3c097e2
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/AbstractProbabilisticTopicModel.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.annotations.VisibleForTesting;
+import hivemall.model.FeatureValue;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.*;
+
+public abstract class AbstractProbabilisticTopicModel {
+
+ // number of topics
+ protected final int _K;
+
+ // total number of documents
+ protected long _D;
+
+ // for mini-batch
+ @Nonnull
+ protected final List<Map<String, Float>> _miniBatchDocs;
+ protected int _miniBatchSize;
+
+ public AbstractProbabilisticTopicModel(int K) {
+ this._K = K;
+ this._D = 0L;
+ this._miniBatchDocs = new ArrayList<Map<String, Float>>();
+ }
+
+ protected static void initMiniBatch(@Nonnull final String[][] miniBatch,
+ @Nonnull final List<Map<String, Float>> docs) {
+ docs.clear();
+
+ final FeatureValue probe = new FeatureValue();
+
+ // parse document
+ for (final String[] e : miniBatch) {
+ if (e == null || e.length == 0) {
+ continue;
+ }
+
+ final Map<String, Float> doc = new HashMap<String, Float>();
+
+ // parse features
+ for (String fv : e) {
+ if (fv == null) {
+ continue;
+ }
+ FeatureValue.parseFeatureAsString(fv, probe);
+ String label = probe.getFeatureAsString();
+ float value = probe.getValueAsFloat();
+ doc.put(label, Float.valueOf(value));
+ }
+
+ docs.add(doc);
+ }
+ }
+
+ public void accumulateDocCount() {
+ this._D += 1;
+ }
+
+ public long getDocCount() {
+ return _D;
+ }
+
+ public abstract void train(@Nonnull final String[][] miniBatch);
+
+ public abstract float computePerplexity();
+
+ @Nonnull
+ public abstract SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int k);
+
+ @Nonnull
+ public abstract float[] getTopicDistribution(@Nonnull final String[] doc);
+
+ @VisibleForTesting
+ abstract float getWordScore(@Nonnull final String word, @Nonnegative final int topic);
+
+ public abstract void setWordScore(@Nonnull final String word, @Nonnegative final int topic, final float score);
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
index 6eef23e..b99e670 100644
--- a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java
@@ -20,9 +20,10 @@ package hivemall.topicmodel;
import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
import static hivemall.utils.math.MathUtils.l1normalize;
+
+import hivemall.annotations.VisibleForTesting;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
-import hivemall.model.FeatureValue;
import hivemall.utils.math.MathUtils;
import java.util.ArrayList;
@@ -37,14 +38,11 @@ import java.util.TreeMap;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
-public final class IncrementalPLSAModel {
+public final class IncrementalPLSAModel extends AbstractProbabilisticTopicModel {
// ---------------------------------
// HyperParameters
- // number of topics
- private final int _K;
-
// control how much P(w|z) update is affected by the last value
private final float _alpha;
@@ -65,20 +63,15 @@ public final class IncrementalPLSAModel {
private List<float[]> _p_dz; // P(z|d) probability of topics for documents
private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic
- @Nonnull
- private final List<Map<String, Float>> _miniBatchDocs;
- private int _miniBatchSize;
-
public IncrementalPLSAModel(int K, float alpha, double delta) {
- this._K = K;
+ super(K);
+
this._alpha = alpha;
this._delta = delta;
this._rnd = RandomNumberGeneratorFactory.createPRNG(1001);
this._p_zw = new HashMap<String, float[]>();
-
- this._miniBatchDocs = new ArrayList<Map<String, Float>>();
}
public void train(@Nonnull final String[][] miniBatch) {
@@ -106,35 +99,6 @@ public final class IncrementalPLSAModel {
}
}
- private static void initMiniBatch(@Nonnull final String[][] miniBatch,
- @Nonnull final List<Map<String, Float>> docs) {
- docs.clear();
-
- final FeatureValue probe = new FeatureValue();
-
- // parse document
- for (final String[] e : miniBatch) {
- if (e == null || e.length == 0) {
- continue;
- }
-
- final Map<String, Float> doc = new HashMap<String, Float>();
-
- // parse features
- for (String fv : e) {
- if (fv == null) {
- continue;
- }
- FeatureValue.parseFeatureAsString(fv, probe);
- String word = probe.getFeatureAsString();
- float value = probe.getValueAsFloat();
- doc.put(word, Float.valueOf(value));
- }
-
- docs.add(doc);
- }
- }
-
private void initParams() {
final List<float[]> p_dz = new ArrayList<float[]>();
final List<Map<String, float[]>> p_dwz = new ArrayList<Map<String, float[]>>();
@@ -302,11 +266,12 @@ public final class IncrementalPLSAModel {
return _p_dz.get(0);
}
- public float getProbability(@Nonnull final String w, @Nonnegative final int z) {
+ @VisibleForTesting
+ float getWordScore(@Nonnull final String w, @Nonnegative final int z) {
return _p_zw.get(w)[z];
}
- public void setProbability(@Nonnull final String w, @Nonnegative final int z, final float prob) {
+ public void setWordScore(@Nonnull final String w, @Nonnegative final int z, final float prob) {
float[] prob_label = _p_zw.get(w);
if (prob_label == null) {
prob_label = newRandomFloatArray(_K, _rnd);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index 03779b0..94d510a 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -471,7 +471,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
for (int k = 0; k < topics; k++) {
final float lambda_k = lambda_word.get(k).floatValue();
if (lambda_k != -1.f) {
- model.setLambda(word, k, lambda_k);
+ model.setWordScore(word, k, lambda_k);
}
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
index de57518..41386a4 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java
@@ -18,27 +18,7 @@
*/
package hivemall.topicmodel;
-import hivemall.UDTFWithOptions;
-import hivemall.annotations.VisibleForTesting;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.io.FileUtils;
-import hivemall.utils.io.NIOUtils;
-import hivemall.utils.io.NioStatefullSegment;
-import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
-import hivemall.utils.lang.SizeOf;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
@@ -46,96 +26,52 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.FloatWritable;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.Counters;
-import org.apache.hadoop.mapred.Reporter;
@Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string options])"
+ " - Returns a relation consists of <int topic, string word, float score>")
-public class LDAUDTF extends UDTFWithOptions {
+public class LDAUDTF extends ProbabilisticTopicModelBaseUDTF {
private static final Log logger = LogFactory.getLog(LDAUDTF.class);
- public static final int DEFAULT_TOPICS = 10;
public static final double DEFAULT_DELTA = 1E-3d;
// Options
- protected int topics;
protected float alpha;
protected float eta;
protected long numDocs;
protected double tau0;
protected double kappa;
- protected int iterations;
protected double delta;
- protected double eps;
- protected int miniBatchSize;
-
- // if `num_docs` option is not given, this flag will be true
- // in that case, UDTF automatically sets `count` value to the _D parameter in an online LDA model
- protected boolean isAutoD;
-
- // number of proceeded training samples
- protected long count;
-
- protected String[][] miniBatch;
- protected int miniBatchCount;
-
- protected transient OnlineLDAModel model;
-
- protected ListObjectInspector wordCountsOI;
-
- // for iterations
- protected NioStatefullSegment fileIO;
- protected ByteBuffer inputBuf;
public LDAUDTF() {
- this.topics = DEFAULT_TOPICS;
+ super();
+
this.alpha = 1.f / topics;
this.eta = 1.f / topics;
this.numDocs = -1L;
this.tau0 = 64.d;
this.kappa = 0.7;
- this.iterations = 10;
this.delta = DEFAULT_DELTA;
- this.eps = 1E-1d;
- this.miniBatchSize = 128; // if 1, truly online setting
}
@Override
protected Options getOptions() {
- Options opts = new Options();
- opts.addOption("k", "topics", true, "The number of topics [default: 10]");
+ Options opts = super.getOptions();
opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]");
opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]");
opts.addOption("tau", "tau0", true,
"The parameter which downweights early iterations [default: 64.0]");
opts.addOption("kappa", true, "Exponential decay rate (i.e., learning rate) [default: 0.7]");
- opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
- opts.addOption("eps", "epsilon", true,
- "Check convergence based on the difference of perplexity [default: 1E-1]");
- opts.addOption("s", "mini_batch_size", true,
- "Repeat model updating per mini-batch [default: 128]");
return opts;
}
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = null;
+ CommandLine cl = super.processOptions(argOIs);
- if (argOIs.length >= 2) {
- String rawArgs = HiveUtils.getConstString(argOIs[1]);
- cl = parseOptions(rawArgs);
- this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
+ if (cl != null) {
this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics);
this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topics);
this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L);
@@ -147,436 +83,13 @@ public class LDAUDTF extends UDTFWithOptions {
if (kappa <= 0.5 || kappa > 1.d) {
throw new UDFArgumentException("'-kappa' must be in (0.5, 1.0]: " + kappa);
}
- this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
- if (iterations < 1) {
- throw new UDFArgumentException(
- "'-iterations' must be greater than or equals to 1: " + iterations);
- }
this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA);
- this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
- this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
}
return cl;
}
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- if (argOIs.length < 1) {
- throw new UDFArgumentException(
- "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
- }
-
- this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
- HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
-
- processOptions(argOIs);
-
- this.model = null;
- this.count = 0L;
- this.isAutoD = (numDocs < 0L);
- this.miniBatch = new String[miniBatchSize][];
- this.miniBatchCount = 0;
-
- ArrayList<String> fieldNames = new ArrayList<String>();
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
- fieldNames.add("topic");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
- fieldNames.add("word");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
- fieldNames.add("score");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
-
- return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
- }
-
- protected void initModel() {
- this.model = new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta);
- }
-
- @Override
- public void process(Object[] args) throws HiveException {
- if (model == null) {
- initModel();
- }
-
- final int length = wordCountsOI.getListLength(args[0]);
- final String[] wordCounts = new String[length];
- int j = 0;
- for (int i = 0; i < length; i++) {
- Object o = wordCountsOI.getListElement(args[0], i);
- if (o == null) {
- throw new HiveException("Given feature vector contains invalid elements");
- }
- String s = o.toString();
- wordCounts[j] = s;
- j++;
- }
- if (j == 0) {// avoid empty documents
- return;
- }
-
- count++;
- if (isAutoD) {
- model.setNumTotalDocs(count);
- }
-
- recordTrainSampleToTempFile(wordCounts);
-
- miniBatch[miniBatchCount] = wordCounts;
- miniBatchCount++;
-
- if (miniBatchCount == miniBatchSize) {
- model.train(miniBatch);
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
- }
- }
-
- protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
- throws HiveException {
- if (iterations == 1) {
- return;
- }
-
- ByteBuffer buf = inputBuf;
- NioStatefullSegment dst = fileIO;
-
- if (buf == null) {
- final File file;
- try {
- file = File.createTempFile("hivemall_lda", ".sgmt");
- file.deleteOnExit();
- if (!file.canWrite()) {
- throw new UDFArgumentException("Cannot write a temporary file: "
- + file.getAbsolutePath());
- }
- logger.info("Record training samples to a file: " + file.getAbsolutePath());
- } catch (IOException ioe) {
- throw new UDFArgumentException(ioe);
- } catch (Throwable e) {
- throw new UDFArgumentException(e);
- }
- this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
- this.fileIO = dst = new NioStatefullSegment(file, false);
- }
-
- // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
- int wcLengthTotal = 0;
- for (String wc : wordCounts) {
- if (wc == null) {
- continue;
- }
- wcLengthTotal += wc.length();
- }
- int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * wordCounts.length + wcLengthTotal
- * SizeOf.CHAR;
-
- int remain = buf.remaining();
- if (remain < requiredRecordBytes) {
- writeBuffer(buf, dst);
- }
-
- buf.putInt(requiredRecordBytes);
- buf.putInt(wordCounts.length);
- for (String wc : wordCounts) {
- NIOUtils.putString(wc, buf);
- }
- }
-
- private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
- throws HiveException {
- srcBuf.flip();
- try {
- dst.write(srcBuf);
- } catch (IOException e) {
- throw new HiveException("Exception causes while writing a buffer to file", e);
- }
- srcBuf.clear();
- }
-
- @Override
- public void close() throws HiveException {
- if (count == 0) {
- this.model = null;
- return;
- }
- if (miniBatchCount > 0) { // update for remaining samples
- model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
- }
- if (iterations > 1) {
- runIterativeTraining(iterations);
- }
- forwardModel();
- this.model = null;
- }
-
- protected final void runIterativeTraining(@Nonnegative final int iterations)
- throws HiveException {
- final ByteBuffer buf = this.inputBuf;
- final NioStatefullSegment dst = this.fileIO;
- assert (buf != null);
- assert (dst != null);
- final long numTrainingExamples = count;
-
- final Reporter reporter = getReporter();
- final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
- "hivemall.lda.OnlineLDA$Counter", "iteration");
-
- try {
- if (dst.getPosition() == 0L) {// run iterations w/o temporary file
- if (buf.position() == 0) {
- return; // no training example
- }
- buf.flip();
-
- int iter = 2;
- float perplexityPrev = Float.MAX_VALUE;
- float perplexity;
- int numTrain;
- for (; iter <= iterations; iter++) {
- perplexity = 0.f;
- numTrain = 0;
-
- reportProgress(reporter);
- setCounterValue(iterCounter, iter);
-
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
-
- while (buf.remaining() > 0) {
- int recordBytes = buf.getInt();
- assert (recordBytes > 0) : recordBytes;
- int wcLength = buf.getInt();
- final String[] wordCounts = new String[wcLength];
- for (int j = 0; j < wcLength; j++) {
- wordCounts[j] = NIOUtils.getString(buf);
- }
-
- miniBatch[miniBatchCount] = wordCounts;
- miniBatchCount++;
-
- if (miniBatchCount == miniBatchSize) {
- model.train(miniBatch);
- perplexity += model.computePerplexity();
- numTrain++;
-
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
- }
- }
- buf.rewind();
-
- // update for remaining samples
- if (miniBatchCount > 0) { // update for remaining samples
- model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
- perplexity += model.computePerplexity();
- numTrain++;
- }
-
- logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
- perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
- if (Math.abs(perplexityPrev - perplexity) < eps) {
- break;
- }
- perplexityPrev = perplexity;
- }
- logger.info("Performed "
- + Math.min(iter, iterations)
- + " iterations of "
- + NumberUtils.formatNumber(numTrainingExamples)
- + " training examples on memory (thus "
- + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
- + " training updates in total) ");
- } else {// read training examples in the temporary file and invoke train for each example
- // write training examples in buffer to a temporary file
- if (buf.remaining() > 0) {
- writeBuffer(buf, dst);
- }
- try {
- dst.flush();
- } catch (IOException e) {
- throw new HiveException("Failed to flush a file: "
- + dst.getFile().getAbsolutePath(), e);
- }
- if (logger.isInfoEnabled()) {
- File tmpFile = dst.getFile();
- logger.info("Wrote " + numTrainingExamples
- + " records to a temporary file for iterative training: "
- + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
- + ")");
- }
-
- // run iterations
- int iter = 2;
- float perplexityPrev = Float.MAX_VALUE;
- float perplexity;
- int numTrain;
- for (; iter <= iterations; iter++) {
- perplexity = 0.f;
- numTrain = 0;
-
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
-
- setCounterValue(iterCounter, iter);
-
- buf.clear();
- dst.resetPosition();
- while (true) {
- reportProgress(reporter);
- // TODO prefetch
- // writes training examples to a buffer in the temporary file
- final int bytesRead;
- try {
- bytesRead = dst.read(buf);
- } catch (IOException e) {
- throw new HiveException("Failed to read a file: "
- + dst.getFile().getAbsolutePath(), e);
- }
- if (bytesRead == 0) { // reached file EOF
- break;
- }
- assert (bytesRead > 0) : bytesRead;
-
- // reads training examples from a buffer
- buf.flip();
- int remain = buf.remaining();
- if (remain < SizeOf.INT) {
- throw new HiveException("Illegal file format was detected");
- }
- while (remain >= SizeOf.INT) {
- int pos = buf.position();
- int recordBytes = buf.getInt() - SizeOf.INT;
- remain -= SizeOf.INT;
- if (remain < recordBytes) {
- buf.position(pos);
- break;
- }
-
- int wcLength = buf.getInt();
- final String[] wordCounts = new String[wcLength];
- for (int j = 0; j < wcLength; j++) {
- wordCounts[j] = NIOUtils.getString(buf);
- }
-
- miniBatch[miniBatchCount] = wordCounts;
- miniBatchCount++;
-
- if (miniBatchCount == miniBatchSize) {
- model.train(miniBatch);
- perplexity += model.computePerplexity();
- numTrain++;
-
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
- }
-
- remain -= recordBytes;
- }
- buf.compact();
- }
-
- // update for remaining samples
- if (miniBatchCount > 0) { // update for remaining samples
- model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
- perplexity += model.computePerplexity();
- numTrain++;
- }
-
- logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
- perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
- if (Math.abs(perplexityPrev - perplexity) < eps) {
- break;
- }
- perplexityPrev = perplexity;
- }
- logger.info("Performed "
- + Math.min(iter, iterations)
- + " iterations of "
- + NumberUtils.formatNumber(numTrainingExamples)
- + " training examples on a secondary storage (thus "
- + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
- + " training updates in total)");
- }
- } catch (Throwable e) {
- throw new HiveException("Exception caused in the iterative training", e);
- } finally {
- // delete the temporary file and release resources
- try {
- dst.close(true);
- } catch (IOException e) {
- throw new HiveException("Failed to close a file: "
- + dst.getFile().getAbsolutePath(), e);
- }
- this.inputBuf = null;
- this.fileIO = null;
- }
- }
-
- protected void forwardModel() throws HiveException {
- final IntWritable topicIdx = new IntWritable();
- final Text word = new Text();
- final FloatWritable score = new FloatWritable();
-
- final Object[] forwardObjs = new Object[3];
- forwardObjs[0] = topicIdx;
- forwardObjs[1] = word;
- forwardObjs[2] = score;
-
- for (int k = 0; k < topics; k++) {
- topicIdx.set(k);
-
- final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
- for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
- score.set(e.getKey());
- List<String> words = e.getValue();
- for (int i = 0; i < words.size(); i++) {
- word.set(words.get(i));
- forward(forwardObjs);
- }
- }
- }
-
- logger.info("Forwarded topic words each of " + topics + " topics");
- }
-
- /*
- * For testing:
- */
-
- @VisibleForTesting
- void closeWithoutModelReset() throws HiveException {
- // launch close(), but not forward & clear model
- if (count == 0) {
- this.model = null;
- return;
- }
- if (miniBatchCount > 0) { // update for remaining samples
- model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
- }
- if (iterations > 1) {
- runIterativeTraining(iterations);
- }
- }
-
- @VisibleForTesting
- double getLambda(String label, int k) {
- return model.getLambda(label, k);
- }
-
- @VisibleForTesting
- SortedMap<Float, List<String>> getTopicWords(int k) {
- return model.getTopicWords(k);
- }
-
- @VisibleForTesting
- SortedMap<Float, List<String>> getTopicWords(int k, int topN) {
- return model.getTopicWords(k, topN);
- }
-
- @VisibleForTesting
- float[] getTopicDistribution(@Nonnull String[] doc) {
- return model.getTopicDistribution(doc);
+ protected AbstractProbabilisticTopicModel createModel() {
+ return new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
index 8fef10c..4a7531c 100644
--- a/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
+++ b/core/src/main/java/hivemall/topicmodel/OnlineLDAModel.java
@@ -19,7 +19,6 @@
package hivemall.topicmodel;
import hivemall.annotations.VisibleForTesting;
-import hivemall.model.FeatureValue;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.math.MathUtils;
@@ -37,24 +36,17 @@ import javax.annotation.Nonnull;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.special.Gamma;
-public final class OnlineLDAModel {
+public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {
// ---------------------------------
// HyperParameters
- // number of topics
- private final int _K;
-
// prior on weight vectors "theta ~ Dir(alpha_)"
private final float _alpha;
// prior on topics "beta"
private final float _eta;
- // total number of documents
- // in the truly online setting, this can be an estimate of the maximum number of documents that could ever seen
- private long _D = -1L;
-
// positive value which downweights early iterations
@Nonnegative
private final double _tau0;
@@ -75,6 +67,10 @@ public final class OnlineLDAModel {
// controls how much old lambda is forgotten
private double _rhot;
+ // if `num_docs` option is not given, this flag will be true
+ // in that case, UDTF automatically sets `count` value to the _D parameter in an online LDA model
+ private final boolean _isAutoD;
+
// parameters
@Nonnull
private List<Map<String, float[]>> _phi;
@@ -88,11 +84,6 @@ public final class OnlineLDAModel {
private static final double SHAPE = 100.d;
private static final double SCALE = 1.d / SHAPE;
- // for mini-batch
- @Nonnull
- private final List<Map<String, Float>> _miniBatchDocs;
- private int _miniBatchSize;
-
// for computing perplexity
private float _docRatio = 1.f;
private double _valueSum = 0.d;
@@ -103,6 +94,8 @@ public final class OnlineLDAModel {
public OnlineLDAModel(int K, float alpha, float eta, long D, double tau0, double kappa,
double delta) {
+ super(K);
+
if (tau0 < 0.d) {
throw new IllegalArgumentException("tau0 MUST be positive: " + tau0);
}
@@ -110,7 +103,6 @@ public final class OnlineLDAModel {
throw new IllegalArgumentException("kappa MUST be in (0.5, 1.0]: " + kappa);
}
- this._K = K;
this._alpha = alpha;
this._eta = eta;
this._D = D;
@@ -118,31 +110,30 @@ public final class OnlineLDAModel {
this._kappa = kappa;
this._delta = delta;
+ this._isAutoD = (_D < 0L);
+
// initialize a random number generator
this._gd = new GammaDistribution(SHAPE, SCALE);
_gd.reseedRandomGenerator(1001);
// initialize the parameters
this._lambda = new HashMap<String, float[]>(100);
-
- this._miniBatchDocs = new ArrayList<Map<String, Float>>();
}
- /**
- * In a truly online setting, total number of documents corresponds to the number of documents that have ever seen. In that case, users need to
- * manually set the current max number of documents via this method. Note that, since the same set of documents could be repeatedly passed to
- * `train()`, simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
- */
- public void setNumTotalDocs(@Nonnegative long D) {
- this._D = D;
+ @Override
+ public void accumulateDocCount() {
+ /*
+ * In a truly online setting, total number of documents equals to the number of documents that have ever seen.
+ * In that case, users need to manually set the current max number of documents via this method.
+ * Note that, since the same set of documents could be repeatedly passed to `train()`,
+ * simply accumulating `_miniBatchSize`s as estimated `_D` is not sufficient.
+ */
+ if (_isAutoD) {
+ this._D += 1;
+ }
}
public void train(@Nonnull final String[][] miniBatch) {
- if (_D <= 0L) {
- throw new IllegalStateException(
- "Total number of documents MUST be set via `setNumTotalDocs()`");
- }
-
preprocessMiniBatch(miniBatch);
initParams(true);
@@ -175,35 +166,6 @@ public final class OnlineLDAModel {
this._docRatio = (float) ((double) _D / _miniBatchSize);
}
- private static void initMiniBatch(@Nonnull final String[][] miniBatch,
- @Nonnull final List<Map<String, Float>> docs) {
- docs.clear();
-
- final FeatureValue probe = new FeatureValue();
-
- // parse document
- for (final String[] e : miniBatch) {
- if (e == null || e.length == 0) {
- continue;
- }
-
- final Map<String, Float> doc = new HashMap<String, Float>();
-
- // parse features
- for (String fv : e) {
- if (fv == null) {
- continue;
- }
- FeatureValue.parseFeatureAsString(fv, probe);
- String label = probe.getFeatureAsString();
- float value = probe.getValueAsFloat();
- doc.put(label, Float.valueOf(value));
- }
-
- docs.add(doc);
- }
- }
-
private void initParams(final boolean gammaWithRandom) {
final List<Map<String, float[]>> phi = new ArrayList<Map<String, float[]>>();
final float[][] gamma = new float[_miniBatchSize][];
@@ -475,7 +437,7 @@ public final class OnlineLDAModel {
}
@VisibleForTesting
- double getLambda(@Nonnull final String label, @Nonnegative final int k) {
+ float getWordScore(@Nonnull final String label, @Nonnegative final int k) {
final float[] lambda_label = _lambda.get(label);
if (lambda_label == null) {
throw new IllegalArgumentException("Word `" + label + "` is not in the corpus.");
@@ -487,7 +449,7 @@ public final class OnlineLDAModel {
return lambda_label[k];
}
- public void setLambda(@Nonnull final String label, @Nonnegative final int k,
+ public void setWordScore(@Nonnull final String label, @Nonnegative final int k,
final float lambda_k) {
float[] lambda_label = _lambda.get(label);
if (lambda_label == null) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
index ff29236..7702945 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -474,7 +474,7 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
for (int k = 0; k < topics; k++) {
final float prob_k = prob_word.get(k).floatValue();
if (prob_k != -1.f) {
- model.setProbability(word, k, prob_k);
+ model.setWordScore(word, k, prob_k);
}
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
index 46f731f..e1d8797 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java
@@ -18,27 +18,7 @@
*/
package hivemall.topicmodel;
-import hivemall.UDTFWithOptions;
-import hivemall.annotations.VisibleForTesting;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.io.FileUtils;
-import hivemall.utils.io.NIOUtils;
-import hivemall.utils.io.NioStatefullSegment;
-import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
-import hivemall.utils.lang.SizeOf;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
@@ -46,503 +26,49 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.FloatWritable;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.Counters;
-import org.apache.hadoop.mapred.Reporter;
@Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])"
+ " - Returns a relation consists of <int topic, string word, float score>")
-public class PLSAUDTF extends UDTFWithOptions {
+public class PLSAUDTF extends ProbabilisticTopicModelBaseUDTF {
private static final Log logger = LogFactory.getLog(PLSAUDTF.class);
- public static final int DEFAULT_TOPICS = 10;
public static final float DEFAULT_ALPHA = 0.5f;
public static final double DEFAULT_DELTA = 1E-3d;
// Options
- protected int topics;
protected float alpha;
- protected int iterations;
protected double delta;
- protected double eps;
- protected int miniBatchSize;
-
- // number of proceeded training samples
- protected long count;
-
- protected String[][] miniBatch;
- protected int miniBatchCount;
-
- protected transient IncrementalPLSAModel model;
-
- protected ListObjectInspector wordCountsOI;
-
- // for iterations
- protected NioStatefullSegment fileIO;
- protected ByteBuffer inputBuf;
public PLSAUDTF() {
- this.topics = DEFAULT_TOPICS;
+ super();
+
this.alpha = DEFAULT_ALPHA;
- this.iterations = 10;
this.delta = DEFAULT_DELTA;
- this.eps = 1E-1d;
- this.miniBatchSize = 128;
}
@Override
protected Options getOptions() {
- Options opts = new Options();
- opts.addOption("k", "topics", true, "The number of topics [default: 10]");
+ Options opts = super.getOptions();
opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]");
- opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
- opts.addOption("eps", "epsilon", true,
- "Check convergence based on the difference of perplexity [default: 1E-1]");
- opts.addOption("s", "mini_batch_size", true,
- "Repeat model updating per mini-batch [default: 128]");
return opts;
}
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = null;
+ CommandLine cl = super.processOptions(argOIs);
- if (argOIs.length >= 2) {
- String rawArgs = HiveUtils.getConstString(argOIs[1]);
- cl = parseOptions(rawArgs);
- this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
+ if (cl != null) {
this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), DEFAULT_ALPHA);
- this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
- if (iterations < 1) {
- throw new UDFArgumentException(
- "'-iterations' must be greater than or equals to 1: " + iterations);
- }
this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA);
- this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
- this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
}
return cl;
}
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- if (argOIs.length < 1) {
- throw new UDFArgumentException(
- "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
- }
-
- this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
- HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
-
- processOptions(argOIs);
-
- this.model = null;
- this.count = 0L;
- this.miniBatch = new String[miniBatchSize][];
- this.miniBatchCount = 0;
-
- ArrayList<String> fieldNames = new ArrayList<String>();
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
- fieldNames.add("topic");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
- fieldNames.add("word");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
- fieldNames.add("score");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
-
- return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
- }
-
- protected void initModel() {
- this.model = new IncrementalPLSAModel(topics, alpha, delta);
- }
-
- @Override
- public void process(Object[] args) throws HiveException {
- if (model == null) {
- initModel();
- }
-
- int length = wordCountsOI.getListLength(args[0]);
- String[] wordCounts = new String[length];
- int j = 0;
- for (int i = 0; i < length; i++) {
- Object o = wordCountsOI.getListElement(args[0], i);
- if (o == null) {
- throw new HiveException("Given feature vector contains invalid elements");
- }
- String s = o.toString();
- wordCounts[j] = s;
- j++;
- }
- if (j == 0) {// avoid empty documents
- return;
- }
-
- count++;
-
- recordTrainSampleToTempFile(wordCounts);
-
- miniBatch[miniBatchCount] = wordCounts;
- miniBatchCount++;
-
- if (miniBatchCount == miniBatchSize) {
- model.train(miniBatch);
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
- }
- }
-
- protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
- throws HiveException {
- if (iterations == 1) {
- return;
- }
-
- ByteBuffer buf = inputBuf;
- NioStatefullSegment dst = fileIO;
-
- if (buf == null) {
- final File file;
- try {
- file = File.createTempFile("hivemall_plsa", ".sgmt");
- file.deleteOnExit();
- if (!file.canWrite()) {
- throw new UDFArgumentException("Cannot write a temporary file: "
- + file.getAbsolutePath());
- }
- logger.info("Record training samples to a file: " + file.getAbsolutePath());
- } catch (IOException ioe) {
- throw new UDFArgumentException(ioe);
- } catch (Throwable e) {
- throw new UDFArgumentException(e);
- }
- this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
- this.fileIO = dst = new NioStatefullSegment(file, false);
- }
-
- // requiredRecordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
- int wcLengthTotal = 0;
- for (String wc : wordCounts) {
- if (wc == null) {
- continue;
- }
- wcLengthTotal += wc.length();
- }
- int requiredRecordBytes = SizeOf.INT * 2 + SizeOf.INT * wordCounts.length + wcLengthTotal
- * SizeOf.CHAR;
-
- int remain = buf.remaining();
- if (remain < requiredRecordBytes) {
- writeBuffer(buf, dst);
- }
-
- buf.putInt(requiredRecordBytes);
- buf.putInt(wordCounts.length);
- for (String wc : wordCounts) {
- NIOUtils.putString(wc, buf);
- }
- }
-
- private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
- throws HiveException {
- srcBuf.flip();
- try {
- dst.write(srcBuf);
- } catch (IOException e) {
- throw new HiveException("Exception causes while writing a buffer to file", e);
- }
- srcBuf.clear();
+ protected AbstractProbabilisticTopicModel createModel() {
+ return new IncrementalPLSAModel(topics, alpha, delta);
}
- @Override
- public void close() throws HiveException {
- if (count == 0) {
- this.model = null;
- return;
- }
- if (miniBatchCount > 0) { // update for remaining samples
- model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
- }
- if (iterations > 1) {
- runIterativeTraining(iterations);
- }
- forwardModel();
- this.model = null;
- }
-
- protected final void runIterativeTraining(@Nonnegative final int iterations)
- throws HiveException {
- final ByteBuffer buf = this.inputBuf;
- final NioStatefullSegment dst = this.fileIO;
- assert (buf != null);
- assert (dst != null);
- final long numTrainingExamples = count;
-
- final Reporter reporter = getReporter();
- final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
- "hivemall.plsa.IncrementalPLSA$Counter", "iteration");
-
- try {
- if (dst.getPosition() == 0L) {// run iterations w/o temporary file
- if (buf.position() == 0) {
- return; // no training example
- }
- buf.flip();
-
- int iter = 2;
- float perplexityPrev = Float.MAX_VALUE;
- float perplexity;
- int numTrain;
- for (; iter <= iterations; iter++) {
- perplexity = 0.f;
- numTrain = 0;
-
- reportProgress(reporter);
- setCounterValue(iterCounter, iter);
-
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
-
- while (buf.remaining() > 0) {
- int recordBytes = buf.getInt();
- assert (recordBytes > 0) : recordBytes;
- int wcLength = buf.getInt();
- final String[] wordCounts = new String[wcLength];
- for (int j = 0; j < wcLength; j++) {
- wordCounts[j] = NIOUtils.getString(buf);
- }
-
- miniBatch[miniBatchCount] = wordCounts;
- miniBatchCount++;
-
- if (miniBatchCount == miniBatchSize) {
- model.train(miniBatch);
- perplexity += model.computePerplexity();
- numTrain++;
-
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
- }
- }
- buf.rewind();
-
- // update for remaining samples
- if (miniBatchCount > 0) { // update for remaining samples
- model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
- perplexity += model.computePerplexity();
- numTrain++;
- }
-
- logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
- perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
- if (Math.abs(perplexityPrev - perplexity) < eps) {
- break;
- }
- perplexityPrev = perplexity;
- }
- logger.info("Performed "
- + Math.min(iter, iterations)
- + " iterations of "
- + NumberUtils.formatNumber(numTrainingExamples)
- + " training examples on memory (thus "
- + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
- + " training updates in total) ");
- } else {// read training examples in the temporary file and invoke train for each example
-
- // write training examples in buffer to a temporary file
- if (buf.remaining() > 0) {
- writeBuffer(buf, dst);
- }
- try {
- dst.flush();
- } catch (IOException e) {
- throw new HiveException("Failed to flush a file: "
- + dst.getFile().getAbsolutePath(), e);
- }
- if (logger.isInfoEnabled()) {
- File tmpFile = dst.getFile();
- logger.info("Wrote " + numTrainingExamples
- + " records to a temporary file for iterative training: "
- + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
- + ")");
- }
-
- // run iterations
- int iter = 2;
- float perplexityPrev = Float.MAX_VALUE;
- float perplexity;
- int numTrain;
- for (; iter <= iterations; iter++) {
- perplexity = 0.f;
- numTrain = 0;
-
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
-
- setCounterValue(iterCounter, iter);
-
- buf.clear();
- dst.resetPosition();
- while (true) {
- reportProgress(reporter);
- // TODO prefetch
- // writes training examples to a buffer in the temporary file
- final int bytesRead;
- try {
- bytesRead = dst.read(buf);
- } catch (IOException e) {
- throw new HiveException("Failed to read a file: "
- + dst.getFile().getAbsolutePath(), e);
- }
- if (bytesRead == 0) { // reached file EOF
- break;
- }
- assert (bytesRead > 0) : bytesRead;
-
- // reads training examples from a buffer
- buf.flip();
- int remain = buf.remaining();
- if (remain < SizeOf.INT) {
- throw new HiveException("Illegal file format was detected");
- }
- while (remain >= SizeOf.INT) {
- int pos = buf.position();
- int recordBytes = buf.getInt() - SizeOf.INT;
- remain -= SizeOf.INT;
- if (remain < recordBytes) {
- buf.position(pos);
- break;
- }
-
- int wcLength = buf.getInt();
- final String[] wordCounts = new String[wcLength];
- for (int j = 0; j < wcLength; j++) {
- wordCounts[j] = NIOUtils.getString(buf);
- }
-
- miniBatch[miniBatchCount] = wordCounts;
- miniBatchCount++;
-
- if (miniBatchCount == miniBatchSize) {
- model.train(miniBatch);
- perplexity += model.computePerplexity();
- numTrain++;
-
- Arrays.fill(miniBatch, null); // clear
- miniBatchCount = 0;
- }
-
- remain -= recordBytes;
- }
- buf.compact();
- }
-
- // update for remaining samples
- if (miniBatchCount > 0) { // update for remaining samples
- model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
- perplexity += model.computePerplexity();
- numTrain++;
- }
-
- logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain);
- perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches
- if (Math.abs(perplexityPrev - perplexity) < eps) {
- break;
- }
- perplexityPrev = perplexity;
- }
- logger.info("Performed "
- + Math.min(iter, iterations)
- + " iterations of "
- + NumberUtils.formatNumber(numTrainingExamples)
- + " training examples on a secondary storage (thus "
- + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
- + " training updates in total)");
- }
- } catch (Throwable e) {
- throw new HiveException("Exception caused in the iterative training", e);
- } finally {
- // delete the temporary file and release resources
- try {
- dst.close(true);
- } catch (IOException e) {
- throw new HiveException("Failed to close a file: "
- + dst.getFile().getAbsolutePath(), e);
- }
- this.inputBuf = null;
- this.fileIO = null;
- }
- }
-
- protected void forwardModel() throws HiveException {
- final IntWritable topicIdx = new IntWritable();
- final Text word = new Text();
- final FloatWritable score = new FloatWritable();
-
- final Object[] forwardObjs = new Object[3];
- forwardObjs[0] = topicIdx;
- forwardObjs[1] = word;
- forwardObjs[2] = score;
-
- for (int k = 0; k < topics; k++) {
- topicIdx.set(k);
-
- final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
- for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
- score.set(e.getKey());
- List<String> words = e.getValue();
- for (int i = 0; i < words.size(); i++) {
- word.set(words.get(i));
- forward(forwardObjs);
- }
- }
- }
-
- logger.info("Forwarded topic words each of " + topics + " topics");
- }
-
- /*
- * For testing:
- */
-
- @VisibleForTesting
- public void closeWithoutModelReset() throws HiveException {
- // launch close(), but not forward & clear model
- if (count == 0) {
- this.model = null;
- return;
- }
- if (miniBatchCount > 0) { // update for remaining samples
- model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
- }
- if (iterations > 1) {
- runIterativeTraining(iterations);
- }
- }
-
- @VisibleForTesting
- double getProbability(String label, int k) {
- return model.getProbability(label, k);
- }
-
- @VisibleForTesting
- SortedMap<Float, List<String>> getTopicWords(int k) {
- return model.getTopicWords(k);
- }
-
- @VisibleForTesting
- float[] getTopicDistribution(@Nonnull String[] doc) {
- return model.getTopicDistribution(doc);
- }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
new file mode 100644
index 0000000..cff076e
--- /dev/null
+++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.topicmodel;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NIOUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
+ private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);
+
+ public static final int DEFAULT_TOPICS = 10;
+
+ // Options
+ protected int topics;
+ protected int iterations;
+ protected double eps;
+ protected int miniBatchSize;
+
+ protected String[][] miniBatch;
+ protected int miniBatchCount;
+
+ protected transient AbstractProbabilisticTopicModel model;
+
+ protected ListObjectInspector wordCountsOI;
+
+ // for iterations
+ protected NioStatefullSegment fileIO;
+ protected ByteBuffer inputBuf;
+
+ private float cumPerplexity;
+
+    /**
+     * Creates a UDTF preloaded with the default hyper-parameters; they remain in effect
+     * unless an option string overrides them in {@code processOptions}.
+     */
+    public ProbabilisticTopicModelBaseUDTF() {
+        this.miniBatchSize = 128; // if 1, truly online setting
+        this.eps = 0.1d;
+        this.iterations = 10;
+        this.topics = DEFAULT_TOPICS;
+    }
+
+    /**
+     * Declares the options accepted in the UDTF's option string.
+     *
+     * @return definitions of {@code -topics}, {@code -iterations}, {@code -epsilon} and
+     *         {@code -mini_batch_size} together with their short aliases
+     */
+    @Override
+    protected Options getOptions() {
+        return new Options()
+            .addOption("k", "topics", true, "The number of topics [default: 10]")
+            .addOption("iter", "iterations", true,
+                "The maximum number of iterations [default: 10]")
+            .addOption("eps", "epsilon", true,
+                "Check convergence based on the difference of perplexity [default: 1E-1]")
+            .addOption("s", "mini_batch_size", true,
+                "Repeat model updating per mini-batch [default: 128]");
+    }
+
+    /**
+     * Parses the optional second argument (a constant option string) and applies it to the
+     * hyper-parameter fields. Without an option string the constructor defaults are kept.
+     *
+     * @param argOIs argument inspectors; {@code argOIs[1]}, when present, carries the options
+     * @return the parsed command line, or {@code null} when no option string was supplied
+     * @throws UDFArgumentException if {@code -topics}, {@code -iterations} or
+     *         {@code -mini_batch_size} is smaller than 1
+     */
+    @Override
+    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length < 2) {
+            return null;
+        }
+
+        final String rawArgs = HiveUtils.getConstString(argOIs[1]);
+        final CommandLine cl = parseOptions(rawArgs);
+
+        this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS);
+        if (topics < 1) {
+            throw new UDFArgumentException("'-topics' must be greater than or equal to 1: "
+                    + topics);
+        }
+        this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
+        if (iterations < 1) {
+            throw new UDFArgumentException(
+                "'-iterations' must be greater than or equals to 1: " + iterations);
+        }
+        this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
+        this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
+        // the mini-batch array is sized by this value and filled until it is reached,
+        // so zero or a negative number can only fail later with an array exception
+        if (miniBatchSize < 1) {
+            throw new UDFArgumentException(
+                "'-mini_batch_size' must be greater than or equal to 1: " + miniBatchSize);
+        }
+
+        return cl;
+    }
+
+    /**
+     * Validates the argument inspectors, applies the options, resets the training state and
+     * declares the output schema.
+     *
+     * @param argOIs inspectors for {@code array<string> words} and the optional option string
+     * @return a struct inspector with the columns {@code topic} (int), {@code word} (string)
+     *         and {@code score} (float)
+     * @throws UDFArgumentException if no argument is given or the options are rejected
+     */
+    @Override
+    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length < 1) {
+            throw new UDFArgumentException(
+                "_FUNC_ takes 1 arguments: array<string> words [, const string options]");
+        }
+
+        final ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
+        this.wordCountsOI = listOI;
+        HiveUtils.validateFeatureOI(listOI.getListElementObjectInspector());
+
+        processOptions(argOIs);
+
+        // fresh training state; the buffer is sized by the (possibly overridden) batch size
+        this.cumPerplexity = 0.f;
+        this.miniBatchCount = 0;
+        this.miniBatch = new String[miniBatchSize][];
+        this.model = null;
+
+        final ArrayList<String> columnNames = new ArrayList<String>();
+        columnNames.add("topic");
+        columnNames.add("word");
+        columnNames.add("score");
+
+        final ArrayList<ObjectInspector> columnOIs = new ArrayList<ObjectInspector>();
+        columnOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+        columnOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+        columnOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        return ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, columnOIs);
+    }
+
+ protected abstract AbstractProbabilisticTopicModel createModel(); // factory hook: process() calls this once, on the first row, while `model` is still null
+
+    /**
+     * Consumes one document: trains on it as part of the current mini-batch and, when more
+     * than one iteration is configured, records it for the later epochs.
+     *
+     * @param args {@code args[0]} is the document as a list of feature strings; a NULL or
+     *        empty list is skipped
+     * @throws HiveException if the list contains a null element or recording the sample fails
+     */
+    @Override
+    public void process(Object[] args) throws HiveException {
+        if (model == null) {
+            this.model = createModel();
+        }
+
+        final Object wordCountList = args[0];
+        // getListLength() reports -1 for a NULL list; treat that like an empty document
+        // instead of failing on a negative array size
+        final int length = wordCountsOI.getListLength(wordCountList);
+        if (length <= 0) {// avoid empty documents
+            return;
+        }
+
+        final String[] wordCounts = new String[length];
+        for (int i = 0; i < length; i++) {
+            final Object o = wordCountsOI.getListElement(wordCountList, i);
+            if (o == null) {
+                throw new HiveException("Given feature vector contains invalid elements");
+            }
+            wordCounts[i] = o.toString();
+        }
+
+        model.accumulateDocCount();
+
+        update(wordCounts);
+
+        recordTrainSampleToTempFile(wordCounts);
+    }
+
+    /**
+     * Appends one document to the replay buffer used by the 2nd and later iterations,
+     * spilling the buffer to a temporary file whenever the next record would not fit.
+     * Does nothing when only a single iteration is configured.
+     *
+     * @param wordCounts the document to record
+     * @throws HiveException if the temporary file cannot be created or written, or if the
+     *         record is larger than the whole buffer
+     */
+    protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts)
+            throws HiveException {
+        if (iterations == 1) {
+            return;
+        }
+
+        ByteBuffer buf = inputBuf;
+        NioStatefullSegment dst = fileIO;
+
+        if (buf == null) {// first record: lazily set up the buffer and its backing file
+            final File file;
+            try {
+                file = File.createTempFile("hivemall_topicmodel", ".sgmt");
+                file.deleteOnExit();
+                if (!file.canWrite()) {
+                    throw new UDFArgumentException("Cannot write a temporary file: "
+                            + file.getAbsolutePath());
+                }
+                logger.info("Record training samples to a file: " + file.getAbsolutePath());
+            } catch (IOException ioe) {
+                throw new UDFArgumentException(ioe);
+            } catch (Throwable e) {
+                throw new UDFArgumentException(e);
+            }
+            this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB
+            this.fileIO = dst = new NioStatefullSegment(file, false);
+        }
+
+        // wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ...
+        int wcLengthTotal = 0;
+        for (String wc : wordCounts) {
+            if (wc == null) {
+                continue;
+            }
+            wcLengthTotal += wc.length();
+        }
+        final int recordBytes = SizeOf.INT + SizeOf.INT * wordCounts.length + wcLengthTotal
+                * SizeOf.CHAR;
+        final int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself
+
+        // Flushing cannot make room for a record bigger than the buffer itself; report that
+        // explicitly rather than letting the puts below overflow half-way through a record.
+        if (requiredBytes > buf.capacity()) {
+            throw new HiveException("A training example of " + requiredBytes
+                    + " bytes does not fit in the " + buf.capacity() + "-byte buffer");
+        }
+
+        if (buf.remaining() < requiredBytes) {
+            writeBuffer(buf, dst);
+        }
+
+        buf.putInt(recordBytes);
+        buf.putInt(wordCounts.length);
+        for (String wc : wordCounts) {
+            NIOUtils.putString(wc, buf);
+        }
+    }
+
+    /**
+     * Queues one document in the mini-batch buffer and triggers {@link #train()} as soon as
+     * the buffer holds {@code miniBatchSize} documents.
+     */
+    private void update(@Nonnull final String[] wordCounts) {
+        miniBatch[miniBatchCount] = wordCounts;
+        miniBatchCount += 1;
+
+        if (miniBatchCount != miniBatchSize) {
+            return; // batch not full yet
+        }
+        train();
+    }
+
+    /**
+     * Trains the model on the buffered mini-batch, adds the resulting perplexity to the
+     * running total and empties the buffer. A call with an empty buffer is a no-op.
+     */
+    protected void train() {
+        if (miniBatchCount != 0) {
+            model.train(miniBatch);
+            this.cumPerplexity += model.computePerplexity();
+
+            // clear
+            miniBatchCount = 0;
+            Arrays.fill(miniBatch, null);
+        }
+    }
+
+    /**
+     * Drains a buffer that is in write mode into the segment file and leaves the buffer
+     * cleared, ready to be filled again.
+     *
+     * @throws HiveException if the underlying write fails
+     */
+    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst)
+            throws HiveException {
+        srcBuf.flip(); // switch from filling to draining
+        try {
+            dst.write(srcBuf);
+        } catch (IOException ioe) {
+            throw new HiveException("Exception causes while writing a buffer to file", ioe);
+        }
+        srcBuf.clear();
+    }
+
+    /**
+     * Finishes training and forwards the learned topic words, then drops the model.
+     * Nothing is forwarded when no document was ever trained on.
+     *
+     * @throws HiveException if the final training pass or forwarding fails
+     */
+    @Override
+    public void close() throws HiveException {
+        if (model == null) {
+            return; // process() never ran, so there is nothing to train or forward
+        }
+        finalizeTraining();
+        if (model == null) {
+            return; // finalizeTraining() discards a model that saw no documents
+        }
+        forwardModel();
+        this.model = null;
+    }
+
+    /**
+     * Completes training without forwarding or resetting the model: trains on the documents
+     * still waiting in the mini-batch buffer and then runs the remaining iterations.
+     * A missing model, or one that saw no documents, is left as {@code null}.
+     *
+     * @throws HiveException if iterative training fails
+     */
+    @VisibleForTesting
+    void finalizeTraining() throws HiveException {
+        if (model == null) {
+            return; // no row was processed, so no model was ever created
+        }
+        if (model.getDocCount() == 0L) {
+            this.model = null;
+            return;
+        }
+        if (miniBatchCount > 0) { // update for remaining samples
+            model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+        }
+        if (iterations > 1) {
+            runIterativeTraining(iterations);
+        }
+    }
+
+    /**
+     * Replays the recorded documents for the 2nd to {@code iterations}-th epoch, training in
+     * mini-batches, and stops early once the mean perplexity over the mini-batches of an epoch
+     * differs from the previous one by less than {@code eps}.
+     *
+     * <p>Documents are replayed from the in-memory buffer when nothing has been spilled yet,
+     * otherwise from the temporary file. Every epoch starts with an empty mini-batch buffer
+     * and ends by training on the trailing partial mini-batch, so each recorded document is
+     * trained on exactly once per epoch. The temporary file is deleted and the buffer
+     * released on exit, whether or not training succeeded.
+     *
+     * @param iterations the maximum number of epochs, counting the pass already done in
+     *        {@code process()}
+     * @throws HiveException if reading, writing or closing the temporary file fails, or if any
+     *         other error occurs while training
+     */
+    protected final void runIterativeTraining(@Nonnegative final int iterations)
+            throws HiveException {
+        final ByteBuffer buf = this.inputBuf;
+        final NioStatefullSegment dst = this.fileIO;
+        assert (buf != null);
+        assert (dst != null);
+        final long numTrainingExamples = model.getDocCount();
+
+        // number of mini-batches per epoch, counting a trailing partial one
+        long numTrain = numTrainingExamples / miniBatchSize;
+        if (numTrainingExamples % miniBatchSize != 0L) {
+            numTrain++;
+        }
+
+        final Reporter reporter = getReporter();
+        final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
+            "hivemall.topicmodel.ProbabilisticTopicModel$Counter", "iteration");
+
+        try {
+            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
+                if (buf.position() == 0) {
+                    return; // no training example
+                }
+                buf.flip();
+
+                int iter = 2;
+                // perplexity accumulated by train() while the documents were first processed
+                float perplexity = cumPerplexity / numTrain;
+                float perplexityPrev;
+                for (; iter <= iterations; iter++) {
+                    perplexityPrev = perplexity;
+                    cumPerplexity = 0.f;
+                    resetMiniBatch();
+
+                    reportProgress(reporter);
+                    setCounterValue(iterCounter, iter);
+
+                    while (buf.remaining() > 0) {
+                        int recordBytes = buf.getInt();
+                        assert (recordBytes > 0) : recordBytes;
+                        int wcLength = buf.getInt();
+                        final String[] wordCounts = new String[wcLength];
+                        for (int j = 0; j < wcLength; j++) {
+                            wordCounts[j] = NIOUtils.getString(buf);
+                        }
+                        update(wordCounts);
+                    }
+                    buf.rewind();
+
+                    // update for remaining samples
+                    trainRemainingMiniBatch();
+
+                    // mean perplexity over `numTrain` mini-batches
+                    perplexity = cumPerplexity / numTrain;
+                    logger.info("Mean perplexity over mini-batches: " + perplexity);
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break;
+                    }
+                }
+                logger.info("Performed "
+                        + Math.min(iter, iterations)
+                        + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on memory (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total) ");
+            } else {// read training examples in the temporary file and invoke train for each example
+                // Write the examples still pending in the buffer to the temporary file. The
+                // buffer is in write mode here, so position() (not remaining(), which is the
+                // free space) tells whether anything is pending.
+                if (buf.position() > 0) {
+                    writeBuffer(buf, dst);
+                }
+                try {
+                    dst.flush();
+                } catch (IOException e) {
+                    throw new HiveException("Failed to flush a file: "
+                            + dst.getFile().getAbsolutePath(), e);
+                }
+                if (logger.isInfoEnabled()) {
+                    File tmpFile = dst.getFile();
+                    logger.info("Wrote " + numTrainingExamples
+                            + " records to a temporary file for iterative training: "
+                            + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
+                            + ")");
+                }
+
+                // run iterations
+                int iter = 2;
+                float perplexity = cumPerplexity / numTrain;
+                float perplexityPrev;
+                for (; iter <= iterations; iter++) {
+                    perplexityPrev = perplexity;
+                    cumPerplexity = 0.f;
+                    resetMiniBatch();
+
+                    setCounterValue(iterCounter, iter);
+
+                    buf.clear();
+                    dst.resetPosition();
+                    while (true) {
+                        reportProgress(reporter);
+                        // TODO prefetch
+                        // loads the next chunk of the temporary file into the buffer
+                        final int bytesRead;
+                        try {
+                            bytesRead = dst.read(buf);
+                        } catch (IOException e) {
+                            throw new HiveException("Failed to read a file: "
+                                    + dst.getFile().getAbsolutePath(), e);
+                        }
+                        if (bytesRead == 0) { // reached file EOF
+                            break;
+                        }
+                        assert (bytesRead > 0) : bytesRead;
+
+                        // reads training examples from a buffer
+                        buf.flip();
+                        int remain = buf.remaining();
+                        if (remain < SizeOf.INT) {
+                            throw new HiveException("Illegal file format was detected");
+                        }
+                        while (remain >= SizeOf.INT) {
+                            int pos = buf.position();
+                            int recordBytes = buf.getInt();
+                            remain -= SizeOf.INT;
+                            if (remain < recordBytes) {
+                                // the record is cut off by the chunk boundary: rewind to its
+                                // start so that compact() keeps it for the next read
+                                buf.position(pos);
+                                break;
+                            }
+
+                            int wcLength = buf.getInt();
+                            final String[] wordCounts = new String[wcLength];
+                            for (int j = 0; j < wcLength; j++) {
+                                wordCounts[j] = NIOUtils.getString(buf);
+                            }
+                            update(wordCounts);
+
+                            remain -= recordBytes;
+                        }
+                        buf.compact();
+                    }
+
+                    // update for remaining samples
+                    trainRemainingMiniBatch();
+
+                    // mean perplexity over `numTrain` mini-batches
+                    perplexity = cumPerplexity / numTrain;
+                    logger.info("Mean perplexity over mini-batches: " + perplexity);
+                    if (Math.abs(perplexityPrev - perplexity) < eps) {
+                        break;
+                    }
+                }
+                logger.info("Performed "
+                        + Math.min(iter, iterations)
+                        + " iterations of "
+                        + NumberUtils.formatNumber(numTrainingExamples)
+                        + " training examples on a secondary storage (thus "
+                        + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations))
+                        + " training updates in total)");
+            }
+        } catch (Throwable e) {
+            throw new HiveException("Exception caused in the iterative training", e);
+        } finally {
+            // delete the temporary file and release resources
+            try {
+                dst.close(true);
+            } catch (IOException e) {
+                throw new HiveException("Failed to close a file: "
+                        + dst.getFile().getAbsolutePath(), e);
+            }
+            this.inputBuf = null;
+            this.fileIO = null;
+        }
+    }
+
+    /** Empties the mini-batch buffer so that an epoch never starts with leftover documents. */
+    private void resetMiniBatch() {
+        Arrays.fill(miniBatch, null); // clear
+        miniBatchCount = 0;
+    }
+
+    /**
+     * Trains on the partial mini-batch left at the end of an epoch, adds its perplexity to
+     * the running total and empties the buffer; a no-op when the buffer is empty.
+     */
+    private void trainRemainingMiniBatch() {
+        if (miniBatchCount == 0) {
+            return;
+        }
+        // copy only the filled slots: the tail of the buffer still holds nulls
+        model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
+        this.cumPerplexity += model.computePerplexity();
+        resetMiniBatch();
+    }
+
+    /**
+     * Emits one {@code (topic, word, score)} row for every word the model reports for each
+     * topic, reusing a single set of writable objects across all rows.
+     *
+     * @throws HiveException if forwarding a row fails
+     */
+    protected void forwardModel() throws HiveException {
+        final IntWritable topicCol = new IntWritable();
+        final Text wordCol = new Text();
+        final FloatWritable scoreCol = new FloatWritable();
+        final Object[] row = new Object[] {topicCol, wordCol, scoreCol};
+
+        for (int topic = 0; topic < topics; topic++) {
+            topicCol.set(topic);
+
+            final SortedMap<Float, List<String>> wordsByScore = model.getTopicWords(topic);
+            for (Map.Entry<Float, List<String>> entry : wordsByScore.entrySet()) {
+                scoreCol.set(entry.getKey());
+                // every word listed under this entry shares the same score
+                for (String w : entry.getValue()) {
+                    wordCol.set(w);
+                    forward(row);
+                }
+            }
+        }
+
+        logger.info("Forwarded topic words each of " + topics + " topics");
+    }
+
+ @VisibleForTesting
+ float getWordScore(String label, int k) { // test hook: returns the model's score for word `label` in topic `k`
+ return model.getWordScore(label, k);
+ }
+
+ @VisibleForTesting
+ SortedMap<Float, List<String>> getTopicWords(int k) { // test hook: returns the model's score-to-words map for topic `k`
+ return model.getTopicWords(k);
+ }
+
+ @VisibleForTesting
+ float[] getTopicDistribution(@Nonnull String[] doc) { // test hook: returns the model's topic distribution for document `doc`
+ return model.getTopicDistribution(doc);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
index 79be3a7..96bbe64 100644
--- a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
+++ b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java
@@ -110,10 +110,10 @@ public class IncrementalPLSAModelTest {
}
Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
- model.getProbability("vegetables", k1) > model.getProbability("flu", k1));
+ model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1));
Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
- model.getProbability("avocados", k2) > model.getProbability("healthy", k2));
+ model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2));
}
@Test
@@ -177,10 +177,10 @@ public class IncrementalPLSAModelTest {
}
Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
- model.getProbability("vegetables", k1) > model.getProbability("flu", k1));
+ model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1));
Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
- model.getProbability("avocados", k2) > model.getProbability("healthy", k2));
+ model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2));
}
@Test
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
index a934ba3..4cbb668 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java
@@ -53,7 +53,7 @@ public class LDAUDTFTest {
udtf.process(new Object[] {Arrays.asList(doc1)});
udtf.process(new Object[] {Arrays.asList(doc2)});
- udtf.closeWithoutModelReset();
+ udtf.finalizeTraining();
SortedMap<Float, List<String>> topicWords;
@@ -92,10 +92,10 @@ public class LDAUDTFTest {
Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
- udtf.getLambda("vegetables", k1) > udtf.getLambda("flu", k1));
+ udtf.getWordScore("vegetables", k1) > udtf.getWordScore("flu", k1));
Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
- udtf.getLambda("avocados", k2) > udtf.getLambda("healthy", k2));
+ udtf.getWordScore("avocados", k2) > udtf.getWordScore("healthy", k2));
}
@Test
@@ -106,7 +106,7 @@ public class LDAUDTFTest {
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
ObjectInspectorUtils.getConstantObjectInspector(
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3")};
+ "-topics 2 -num_docs 2 -s 1 -iter 32 -eps 1e-3 -mini_batch_size 1")};
udtf.initialize(argOIs);
@@ -116,7 +116,7 @@ public class LDAUDTFTest {
udtf.process(new Object[] {Arrays.asList(doc1)});
udtf.process(new Object[] {Arrays.asList(doc2)});
- udtf.closeWithoutModelReset();
+ udtf.finalizeTraining();
SortedMap<Float, List<String>> topicWords;
@@ -155,10 +155,10 @@ public class LDAUDTFTest {
Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ "and `野菜` SHOULD be more suitable topic word than `インフルエンザ` in the topic",
- udtf.getLambda("野菜", k1) > udtf.getLambda("インフルエンザ", k1));
+ udtf.getWordScore("野菜", k1) > udtf.getWordScore("インフルエンザ", k1));
Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ "and `アボカド` SHOULD be more suitable topic word than `健康` in the topic",
- udtf.getLambda("アボカド", k2) > udtf.getLambda("健康", k2));
+ udtf.getWordScore("アボカド", k2) > udtf.getWordScore("健康", k2));
}
private static void println(String msg) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/0495ffad/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
index 5b0a8c2..68f251a 100644
--- a/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
+++ b/core/src/test/java/hivemall/topicmodel/OnlineLDAModelTest.java
@@ -108,10 +108,10 @@ public class OnlineLDAModelTest {
}
Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), "
+ "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic",
- model.getLambda("vegetables", k1) > model.getLambda("flu", k1));
+ model.getWordScore("vegetables", k1) > model.getWordScore("flu", k1));
Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), "
+ "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic",
- model.getLambda("avocados", k2) > model.getLambda("healthy", k2));
+ model.getWordScore("avocados", k2) > model.getWordScore("healthy", k2));
}
@Test