You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sm...@apache.org on 2014/01/26 04:50:55 UTC
svn commit: r1561440 - in /mahout/trunk: ./
core/src/main/java/org/apache/mahout/clustering/classify/
core/src/test/java/org/apache/mahout/clustering/classify/
Author: smarthi
Date: Sun Jan 26 03:50:55 2014
New Revision: 1561440
URL: http://svn.apache.org/r1561440
Log:
MAHOUT-1410: clusteredPoints do not contain a vector id
Modified:
mahout/trunk/CHANGELOG
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1561440&r1=1561439&r2=1561440&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Sun Jan 26 03:50:55 2014
@@ -2,6 +2,8 @@ Mahout Change Log
Release 0.9 - unreleased
+ MAHOUT-1410: clusteredPoints do not contain a vector id (smarthi, Andrew Musselman)
+
MAHOUT-1409: MatrixVectorView has index check error (tdunning)
MAHOUT-1402: Zero clusters using streaming k-means option in cluster-reuters.sh (smarthi)
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1561440&r1=1561439&r2=1561440&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java Sun Jan 26 03:50:55 2014
@@ -42,11 +42,13 @@ import org.apache.mahout.clustering.Clus
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.clustering.iterator.ClusteringPolicy;
import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
@@ -186,7 +188,11 @@ public final class ClusterClassification
* @param output
* the path to store classified data
* @param clusterClassificationThreshold
+ * the threshold value of probability distribution function from 0.0
+ * to 1.0. Any vector with pdf less than this threshold will not be
+ * classified for the cluster
* @param emitMostLikely
+ * emit the vectors with the max pdf values per cluster
* @throws IOException
*/
private static void selectCluster(Path input, List<Cluster> clusterModels, ClusterClassifier clusterClassifier,
@@ -194,11 +200,20 @@ public final class ClusterClassification
Configuration conf = new Configuration();
SequenceFile.Writer writer = new SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(output,
"part-m-" + 0), IntWritable.class, WeightedPropertyVectorWritable.class);
- for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(input, PathType.LIST,
+ for (Pair<Writable, VectorWritable> vw : new SequenceFileDirIterable<Writable, VectorWritable>(input, PathType.LIST,
PathFilters.logsCRCFilter(), conf)) {
- Vector pdfPerCluster = clusterClassifier.classify(vw.get());
+ Writable key = vw.getFirst();
+ Vector vector = vw.getSecond().get();
+ if (!(vector instanceof NamedVector)) {
+ if (key instanceof Text) {
+ vector = new NamedVector(vector, key.toString());
+ } else if (key instanceof IntWritable) {
+ vector = new NamedVector(vector, Integer.toString(((IntWritable) key).get()));
+ }
+ }
+ Vector pdfPerCluster = clusterClassifier.classify(vector);
if (shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
- classifyAndWrite(clusterModels, clusterClassificationThreshold, emitMostLikely, writer, vw, pdfPerCluster);
+ classifyAndWrite(clusterModels, clusterClassificationThreshold, emitMostLikely, writer, new VectorWritable(vector), pdfPerCluster);
}
}
writer.close();
@@ -209,8 +224,9 @@ public final class ClusterClassification
Map<Text, Text> props = Maps.newHashMap();
if (emitMostLikely) {
int maxValueIndex = pdfPerCluster.maxValueIndex();
- WeightedPropertyVectorWritable wpvw = new WeightedPropertyVectorWritable(pdfPerCluster.maxValue(), vw.get(), props);
- write(clusterModels, writer, wpvw, maxValueIndex);
+ WeightedPropertyVectorWritable weightedPropertyVectorWritable =
+ new WeightedPropertyVectorWritable(pdfPerCluster.maxValue(), vw.get(), props);
+ write(clusterModels, writer, weightedPropertyVectorWritable, maxValueIndex);
} else {
writeAllAboveThreshold(clusterModels, clusterClassificationThreshold, writer, vw, pdfPerCluster);
}
@@ -218,19 +234,23 @@ public final class ClusterClassification
private static void writeAllAboveThreshold(List<Cluster> clusterModels, Double clusterClassificationThreshold,
SequenceFile.Writer writer, VectorWritable vw, Vector pdfPerCluster) throws IOException {
+ Map<Text, Text> props = Maps.newHashMap();
for (Element pdf : pdfPerCluster.nonZeroes()) {
if (pdf.get() >= clusterClassificationThreshold) {
- WeightedVectorWritable wvw = new WeightedVectorWritable(pdf.get(), vw.get());
+ WeightedPropertyVectorWritable wvw = new WeightedPropertyVectorWritable(pdf.get(), vw.get(), props);
int clusterIndex = pdf.index();
write(clusterModels, writer, wvw, clusterIndex);
}
}
}
- private static void write(List<Cluster> clusterModels, SequenceFile.Writer writer, WeightedVectorWritable wvw,
+ private static void write(List<Cluster> clusterModels, SequenceFile.Writer writer,
+ WeightedPropertyVectorWritable weightedPropertyVectorWritable,
int maxValueIndex) throws IOException {
Cluster cluster = clusterModels.get(maxValueIndex);
- writer.append(new IntWritable(cluster.getId()), wvw);
+ double d = Math.sqrt(cluster.getCenter().getDistanceSquared(weightedPropertyVectorWritable.getVector()));
+ weightedPropertyVectorWritable.getProperties().put(new Text("distance"), new Text(Double.toString(d)));
+ writer.append(new IntWritable(cluster.getId()), weightedPropertyVectorWritable);
}
/**
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java?rev=1561440&r1=1561439&r2=1561440&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java Sun Jan 26 03:50:55 2014
@@ -39,6 +39,7 @@ import org.apache.mahout.clustering.iter
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
@@ -83,13 +84,21 @@ public class ClusterClassificationMapper
protected void map(WritableComparable<?> key, VectorWritable vw, Context context)
throws IOException, InterruptedException {
if (!clusterModels.isEmpty()) {
- Vector pdfPerCluster = clusterClassifier.classify(vw.get());
+ Vector vector = vw.get();
+ if (!(vector instanceof NamedVector)) {
+ if (key instanceof Text) {
+ vector = new NamedVector(vector, key.toString());
+ } else if (key instanceof IntWritable) {
+ vector = new NamedVector(vector, Integer.toString(((IntWritable) key).get()));
+ }
+ }
+ Vector pdfPerCluster = clusterClassifier.classify(vector);
if (shouldClassify(pdfPerCluster)) {
if (emitMostLikely) {
int maxValueIndex = pdfPerCluster.maxValueIndex();
- write(vw, context, maxValueIndex, 1.0);
+ write(new VectorWritable(vector), context, maxValueIndex, 1.0);
} else {
- writeAllAboveThreshold(vw, context, pdfPerCluster);
+ writeAllAboveThreshold(new VectorWritable(vector), context, pdfPerCluster);
}
}
}
@@ -109,9 +118,9 @@ public class ClusterClassificationMapper
throws IOException, InterruptedException {
Cluster cluster = clusterModels.get(clusterIndex);
clusterId.set(cluster.getId());
- double d = cluster.getCenter().getDistanceSquared(vw.get());
+ double d = Math.sqrt(cluster.getCenter().getDistanceSquared(vw.get()));
Map<Text, Text> props = Maps.newHashMap();
- props.put(new Text("distance-squared"), new Text(Double.toString(d)));
+ props.put(new Text("distance"), new Text(Double.toString(d)));
context.write(clusterId, new WeightedPropertyVectorWritable(weight, vw.get(), props));
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1561440&r1=1561439&r2=1561440&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java Sun Jan 26 03:50:55 2014
@@ -37,6 +37,7 @@ import org.apache.mahout.common.HadoopUt
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -235,9 +236,15 @@ public class ClusterClassificationDriver
} else {
singletonCnt++;
assertEquals("expecting only singleton clusters; got size=" + vList.size(), 1, vList.size());
- Assert.assertTrue("not expecting cluster:" + vList.get(0).asFormatString(),
- reference.contains(vList.get(0).asFormatString()));
- reference.remove(vList.get(0).asFormatString());
+ if (vList.get(0) instanceof NamedVector) {
+ Assert.assertTrue("not expecting cluster:" + ((NamedVector) vList.get(0)).getDelegate().asFormatString(),
+ reference.contains(((NamedVector) vList.get(0)).getDelegate().asFormatString()));
+ reference.remove(((NamedVector)vList.get(0)).getDelegate().asFormatString());
+ } else if (vList.get(0) instanceof RandomAccessSparseVector) {
+ Assert.assertTrue("not expecting cluster:" + vList.get(0).asFormatString(),
+ reference.contains(vList.get(0).asFormatString()));
+ reference.remove(vList.get(0).asFormatString());
+ }
}
}
Assert.assertEquals("Different number of empty clusters than expected!", 1, emptyCnt);