You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/03/01 22:50:25 UTC

svn commit: r917742 - in /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda: LDADriver.java LDAInference.java

Author: robinanil
Date: Mon Mar  1 21:50:25 2010
New Revision: 917742

URL: http://svn.apache.org/viewvc?rev=917742&view=rev
Log:
Cleanup for 0.3 release. LDA was using HashMap; changed it to OpenIntIntHashMap for better performance. The plus() function was taking more time due to excessive clone() calls, and repeated assign() was slow due to the check of f.apply(0, val). Inlined the mutable addition operation in eStep.

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=917742&r1=917741&r2=917742&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Mon Mar  1 21:50:25 2010
@@ -52,16 +52,13 @@
  */
 public final class LDADriver {
   
-  static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
-  
+  static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";  
   static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
-  static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
-  
+  static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";  
   static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
   
   static final int LOG_LIKELIHOOD_KEY = -2;
   static final int TOPIC_SUM_KEY = -1;
-  
   static final double OVERALL_CONVERGENCE = 1.0E-5;
   
   private static final Logger log = LoggerFactory.getLogger(LDADriver.class);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=917742&r1=917741&r2=917742&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Mon Mar  1 21:50:25 2010
@@ -17,9 +17,7 @@
 
 package org.apache.mahout.clustering.lda;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
 
 import org.apache.commons.math.special.Gamma;
 import org.apache.mahout.math.DenseMatrix;
@@ -27,6 +25,7 @@
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.BinaryFunction;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
 
 /**
  * Class for performing infererence on a document, which involves computing (an approximation to)
@@ -50,14 +49,14 @@
     private final Vector wordCounts;
     private final Vector gamma; // p(topic)
     private final Matrix mphi; // log p(columnMap(w)|t)
-    private final Map<Integer,Integer> columnMap; // maps words into the matrix's column map
+    private final OpenIntIntHashMap columnMap; // maps words into the matrix's column map
     public final double logLikelihood;
     
     public double phi(int k, int w) {
       return mphi.getQuick(k, columnMap.get(w));
     }
     
-    InferredDocument(Vector wordCounts, Vector gamma, Map<Integer,Integer> columnMap, Matrix phi, double ll) {
+    InferredDocument(Vector wordCounts, Vector gamma, OpenIntIntHashMap columnMap, Matrix phi, double ll) {
       this.wordCounts = wordCounts;
       this.gamma = gamma;
       this.mphi = phi;
@@ -94,7 +93,7 @@
     double digammaSumGamma = digamma(gamma.zSum());
     digammaGamma = digammaGamma.plus(-digammaSumGamma);
     
-    Map<Integer,Integer> columnMap = new HashMap<Integer,Integer>();
+    OpenIntIntHashMap columnMap = new OpenIntIntHashMap();
     
     int iteration = 0;
     
@@ -145,7 +144,7 @@
   private final LDAState state;
   
   private double computeLikelihood(Vector wordCounts,
-                                   Map<Integer,Integer> columnMap,
+                                   OpenIntIntHashMap columnMap,
                                    Matrix phi,
                                    Vector gamma,
                                    Vector digammaGamma) {
@@ -174,8 +173,8 @@
       // now for each topic:
       for (int k = 0; k < state.numTopics; k++) {
         double llPart = 0.0;
-        llPart += Math.exp(phi.get(k, mapping))
-                  * (digammaGamma.get(k) - phi.get(k, mapping) + state.logProbWordGivenTopic(w, k));
+        llPart += Math.exp(phi.getQuick(k, mapping))
+                  * (digammaGamma.get(k) - phi.getQuick(k, mapping) + state.logProbWordGivenTopic(w, k));
         
         ll += llPart * n;
         
@@ -202,7 +201,10 @@
       assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
       assert !Double.isNaN(digammaGamma.get(k));
     }
-    return phi.plus(-phiTotal); // log normalize
+    for (int i = 0; i < state.numTopics; i++) {
+      phi.setQuick(i, phi.getQuick(i) - phiTotal);// log normalize
+    }
+    return phi;
   }
   
   private static Vector digamma(Vector v) {