You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sm...@apache.org on 2014/01/26 04:50:55 UTC

svn commit: r1561440 - in /mahout/trunk: ./ core/src/main/java/org/apache/mahout/clustering/classify/ core/src/test/java/org/apache/mahout/clustering/classify/

Author: smarthi
Date: Sun Jan 26 03:50:55 2014
New Revision: 1561440

URL: http://svn.apache.org/r1561440
Log:
MAHOUT-1410: clusteredPoints do not contain a vector id

Modified:
    mahout/trunk/CHANGELOG
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1561440&r1=1561439&r2=1561440&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Sun Jan 26 03:50:55 2014
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 0.9 - unreleased
 
+  MAHOUT-1410: clusteredPoints do not contain a vector id (smarthi, Andrew Musselman)
+
   MAHOUT-1409: MatrixVectorView has index check error (tdunning)
 
   MAHOUT-1402: Zero clusters using streaming k-means option in cluster-reuters.sh (smarthi)

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1561440&r1=1561439&r2=1561440&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java Sun Jan 26 03:50:55 2014
@@ -42,11 +42,13 @@ import org.apache.mahout.clustering.Clus
 import org.apache.mahout.clustering.iterator.ClusterWritable;
 import org.apache.mahout.clustering.iterator.ClusteringPolicy;
 import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.Pair;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.apache.mahout.common.iterator.sequencefile.PathFilters;
 import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.Vector.Element;
 import org.apache.mahout.math.VectorWritable;
@@ -186,7 +188,11 @@ public final class ClusterClassification
    * @param output
    *          the path to store classified data
    * @param clusterClassificationThreshold
+   *          the threshold value of probability distribution function from 0.0
+   *          to 1.0. Any vector with pdf less than this threshold will not be
+   *          classified for the cluster
     * @param emitMostLikely
+   *          if true, emit each vector only to the cluster with the maximum pdf
+   *          value; otherwise emit it to every cluster whose pdf meets the threshold
    * @throws IOException
    */
   private static void selectCluster(Path input, List<Cluster> clusterModels, ClusterClassifier clusterClassifier,
@@ -194,11 +200,20 @@ public final class ClusterClassification
     Configuration conf = new Configuration();
     SequenceFile.Writer writer = new SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(output,
         "part-m-" + 0), IntWritable.class, WeightedPropertyVectorWritable.class);
-    for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(input, PathType.LIST,
+    for (Pair<Writable, VectorWritable> vw : new SequenceFileDirIterable<Writable, VectorWritable>(input, PathType.LIST,
         PathFilters.logsCRCFilter(), conf)) {
-      Vector pdfPerCluster = clusterClassifier.classify(vw.get());
+      Writable key = vw.getFirst();
+      Vector vector = vw.getSecond().get();
+      if (!(vector instanceof NamedVector)) {
+        if (key instanceof Text) {
+          vector = new NamedVector(vector, key.toString());
+        } else if (key instanceof IntWritable) {
+          vector = new NamedVector(vector, Integer.toString(((IntWritable) key).get()));
+        }
+      }
+      Vector pdfPerCluster = clusterClassifier.classify(vector);
       if (shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
-        classifyAndWrite(clusterModels, clusterClassificationThreshold, emitMostLikely, writer, vw, pdfPerCluster);
+        classifyAndWrite(clusterModels, clusterClassificationThreshold, emitMostLikely, writer, new VectorWritable(vector), pdfPerCluster);
       }
     }
     writer.close();
@@ -209,8 +224,9 @@ public final class ClusterClassification
     Map<Text, Text> props = Maps.newHashMap();
     if (emitMostLikely) {
       int maxValueIndex = pdfPerCluster.maxValueIndex();
-      WeightedPropertyVectorWritable wpvw = new WeightedPropertyVectorWritable(pdfPerCluster.maxValue(), vw.get(), props);
-      write(clusterModels, writer, wpvw, maxValueIndex);
+      WeightedPropertyVectorWritable weightedPropertyVectorWritable =
+          new WeightedPropertyVectorWritable(pdfPerCluster.maxValue(), vw.get(), props);
+      write(clusterModels, writer, weightedPropertyVectorWritable, maxValueIndex);
     } else {
       writeAllAboveThreshold(clusterModels, clusterClassificationThreshold, writer, vw, pdfPerCluster);
     }
@@ -218,19 +234,23 @@ public final class ClusterClassification
   
   private static void writeAllAboveThreshold(List<Cluster> clusterModels, Double clusterClassificationThreshold,
       SequenceFile.Writer writer, VectorWritable vw, Vector pdfPerCluster) throws IOException {
+    Map<Text, Text> props = Maps.newHashMap();
     for (Element pdf : pdfPerCluster.nonZeroes()) {
       if (pdf.get() >= clusterClassificationThreshold) {
-        WeightedVectorWritable wvw = new WeightedVectorWritable(pdf.get(), vw.get());
+        WeightedPropertyVectorWritable wvw = new WeightedPropertyVectorWritable(pdf.get(), vw.get(), props);
         int clusterIndex = pdf.index();
         write(clusterModels, writer, wvw, clusterIndex);
       }
     }
   }
 
-  private static void write(List<Cluster> clusterModels, SequenceFile.Writer writer, WeightedVectorWritable wvw,
+  private static void write(List<Cluster> clusterModels, SequenceFile.Writer writer,
+      WeightedPropertyVectorWritable weightedPropertyVectorWritable,
       int maxValueIndex) throws IOException {
     Cluster cluster = clusterModels.get(maxValueIndex);
-    writer.append(new IntWritable(cluster.getId()), wvw);
+    double d = Math.sqrt(cluster.getCenter().getDistanceSquared(weightedPropertyVectorWritable.getVector()));
+    weightedPropertyVectorWritable.getProperties().put(new Text("distance"), new Text(Double.toString(d)));
+    writer.append(new IntWritable(cluster.getId()), weightedPropertyVectorWritable);
   }
   
   /**

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java?rev=1561440&r1=1561439&r2=1561440&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java Sun Jan 26 03:50:55 2014
@@ -39,6 +39,7 @@ import org.apache.mahout.clustering.iter
 import org.apache.mahout.common.iterator.sequencefile.PathFilters;
 import org.apache.mahout.common.iterator.sequencefile.PathType;
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.Vector.Element;
 import org.apache.mahout.math.VectorWritable;
@@ -83,13 +84,21 @@ public class ClusterClassificationMapper
   protected void map(WritableComparable<?> key, VectorWritable vw, Context context)
     throws IOException, InterruptedException {
     if (!clusterModels.isEmpty()) {
-      Vector pdfPerCluster = clusterClassifier.classify(vw.get());
+      Vector vector = vw.get();
+      if (!(vector instanceof NamedVector)) {
+        if (key instanceof Text) {
+          vector = new NamedVector(vector, key.toString());
+        } else if (key instanceof IntWritable) {
+          vector = new NamedVector(vector, Integer.toString(((IntWritable) key).get()));
+        }
+      }
+      Vector pdfPerCluster = clusterClassifier.classify(vector);
       if (shouldClassify(pdfPerCluster)) {
         if (emitMostLikely) {
           int maxValueIndex = pdfPerCluster.maxValueIndex();
-          write(vw, context, maxValueIndex, 1.0);
+          write(new VectorWritable(vector), context, maxValueIndex, 1.0);
         } else {
-          writeAllAboveThreshold(vw, context, pdfPerCluster);
+          writeAllAboveThreshold(new VectorWritable(vector), context, pdfPerCluster);
         }
       }
     }
@@ -109,9 +118,9 @@ public class ClusterClassificationMapper
     throws IOException, InterruptedException {
     Cluster cluster = clusterModels.get(clusterIndex);
     clusterId.set(cluster.getId());
-    double d = cluster.getCenter().getDistanceSquared(vw.get());
+    double d = Math.sqrt(cluster.getCenter().getDistanceSquared(vw.get()));
     Map<Text, Text> props = Maps.newHashMap();
-    props.put(new Text("distance-squared"), new Text(Double.toString(d)));
+    props.put(new Text("distance"), new Text(Double.toString(d)));
     context.write(clusterId, new WeightedPropertyVectorWritable(weight, vw.get(), props));
   }
   

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1561440&r1=1561439&r2=1561440&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java Sun Jan 26 03:50:55 2014
@@ -37,6 +37,7 @@ import org.apache.mahout.common.HadoopUt
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
@@ -235,9 +236,15 @@ public class ClusterClassificationDriver
       } else {
         singletonCnt++;
         assertEquals("expecting only singleton clusters; got size=" + vList.size(), 1, vList.size());
-        Assert.assertTrue("not expecting cluster:" + vList.get(0).asFormatString(),
-                          reference.contains(vList.get(0).asFormatString()));
-        reference.remove(vList.get(0).asFormatString());
+        if (vList.get(0) instanceof NamedVector) {
+          Assert.assertTrue("not expecting cluster:" + ((NamedVector) vList.get(0)).getDelegate().asFormatString(),
+                  reference.contains(((NamedVector)  vList.get(0)).getDelegate().asFormatString()));
+          reference.remove(((NamedVector)vList.get(0)).getDelegate().asFormatString());
+        } else if (vList.get(0) instanceof RandomAccessSparseVector) {
+          Assert.assertTrue("not expecting cluster:" + vList.get(0).asFormatString(),
+                  reference.contains(vList.get(0).asFormatString()));
+          reference.remove(vList.get(0).asFormatString());
+        }
       }
     }
     Assert.assertEquals("Different number of empty clusters than expected!", 1, emptyCnt);