You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/04/08 02:08:01 UTC
svn commit: r931732 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/
core/src/main/java/org/apache/mahout/clustering/dirichlet/
core/src/main/java/org/apache/mahout/clustering/dirichlet/models/
core/src/test/java/org/apache/m...
Author: jeastman
Date: Thu Apr 8 00:08:00 2010
New Revision: 931732
URL: http://svn.apache.org/viewvc?rev=931732&view=rev
Log:
MAHOUT-270: completed update to ClusterDumper and added a unit test of Canopy, KMeans and Dirichlet all using the ClusterDumper basic functionality. More to polish but this is working
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java
lucene/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
Removed:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java
lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java?rev=931732&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java Thu Apr 8 00:08:00 2010
@@ -0,0 +1,65 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Implementations of this interface have a printable representation and certain
+ * attributes that are common across all clustering implementations
+ *
+ */
+public interface Cluster {
+
+ /**
+ * Get the id of the Cluster
+ *
+ * @return a unique integer
+ */
+ public int getId();
+
+ /**
+ * Get the "center" of the Cluster as a Vector
+ *
+ * @return a Vector
+ */
+ public Vector getCenter();
+
+ /**
+ * Get an integer denoting the number of points observed by this cluster
+ * @return an integer
+ */
+ public int getNumPoints();
+
+ /**
+ * Produce a custom, human-friendly, printable representation of the Cluster.
+ *
+ * @param bindings
+ * an optional String[] containing labels used to format the primary Vector/s of this
+ * implementation.
+ * @return a String
+ */
+ String asFormatString(String[] bindings);
+
+ /**
+ * Produce a textual representation of the Cluster using Json format. (Label bindings are transient and not part
+ * of the Json representation)
+ *
+ * @return a Json String
+ */
+ String asJsonString();
+
+}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java?rev=931732&r1=931731&r2=931732&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java Thu Apr 8 00:08:00 2010
@@ -34,15 +34,20 @@ import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
-public abstract class ClusterBase implements Writable, Printable {
+/**
+ * ClusterBase is an abstract base class for several clustering implementations
+ * that share common implementations of various attributes
+ *
+ */
+public abstract class ClusterBase implements Writable, Cluster {
// this cluster's clusterId
- private int id;
+ int id;
// the current cluster center
- private Vector center = new RandomAccessSparseVector(0);
+ Vector center = new RandomAccessSparseVector(0);
// the number of points in the cluster
- private int numPoints;
+ int numPoints;
// the Vector total of all points added to the cluster
private Vector pointTotal;
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java?rev=931732&r1=931731&r2=931732&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java Thu Apr 8 00:08:00 2010
@@ -22,7 +22,7 @@ import java.io.IOException;
import java.lang.reflect.Type;
import org.apache.hadoop.io.Writable;
-import org.apache.mahout.clustering.Printable;
+import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.math.Vector;
@@ -30,7 +30,7 @@ import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
-public class DirichletCluster<O> implements Writable, Printable {
+public class DirichletCluster<O> implements Writable, Cluster {
@Override
public void readFields(DataInput in) throws IOException {
@@ -114,5 +114,17 @@ public class DirichletCluster<O> impleme
Gson gson = builder.create();
return gson.toJson(this, clusterType);
}
+
+ public int getId() {
+ return model.getId();
+ }
+
+ public Vector getCenter() {
+ return model.getCenter();
+ }
+
+ public int getNumPoints() {
+ return model.getNumPoints();
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=931732&r1=931731&r2=931732&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Thu Apr 8 00:08:00 2010
@@ -36,6 +36,8 @@ public class AsymmetricSampledNormalMode
private static final double sqrt2pi = Math.sqrt(2.0 * Math.PI);
+ private int id;
+
// the parameters
private Vector mean;
@@ -147,7 +149,6 @@ public class AsymmetricSampledNormalMode
return asFormatString(null);
}
- @Override
public String asFormatString(String[] bindings) {
StringBuilder buf = new StringBuilder(50);
buf.append("asnm{n=").append(s0).append(" m=");
@@ -185,11 +186,25 @@ public class AsymmetricSampledNormalMode
VectorWritable.writeVector(out, s2);
}
- @Override
public String asJsonString() {
GsonBuilder builder = new GsonBuilder();
builder.registerTypeAdapter(Model.class, new JsonModelAdapter());
Gson gson = builder.create();
return gson.toJson(this, modelType);
}
+
+ @Override
+ public Vector getCenter() {
+ return mean;
+ }
+
+ @Override
+ public int getId() {
+ return id;
+ }
+
+ @Override
+ public int getNumPoints() {
+ return s0;
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=931732&r1=931731&r2=931732&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java Thu Apr 8 00:08:00 2010
@@ -33,68 +33,73 @@ import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
public class L1Model implements Model<VectorWritable> {
-
+
private static final DistanceMeasure measure = new ManhattanDistanceMeasure();
-
+
+ private int id;
+
public L1Model() {
super();
}
-
+
public L1Model(Vector v) {
observed = v.like();
coefficients = v;
}
-
+
private Vector coefficients;
-
+
private int count = 0;
-
+
private Vector observed;
-
- private static final Type modelType = new TypeToken<Model<Vector>>() { }.getType();
-
+
+ private static final Type modelType = new TypeToken<Model<Vector>>() {
+ }.getType();
+
@Override
public void computeParameters() {
coefficients = observed.divide(count);
}
-
+
@Override
public int count() {
return count;
}
-
+
@Override
public void observe(VectorWritable x) {
count++;
x.get().addTo(observed);
}
-
+
@Override
public double pdf(VectorWritable x) {
return Math.exp(-L1Model.measure.distance(x.get(), coefficients));
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
+ count = in.readInt();
VectorWritable temp = new VectorWritable();
temp.readFields(in);
coefficients = temp.get();
}
-
+
@Override
public void write(DataOutput out) throws IOException {
+ out.writeInt(count);
VectorWritable.writeVector(out, coefficients);
}
-
+
public L1Model sample() {
return new L1Model(coefficients.clone());
}
-
+
@Override
public String toString() {
return asFormatString(null);
}
-
+
@Override
public String asFormatString(String[] bindings) {
StringBuilder buf = new StringBuilder();
@@ -105,7 +110,7 @@ public class L1Model implements Model<Ve
buf.append('}');
return buf.toString();
}
-
+
/*
* (non-Javadoc)
*
@@ -118,5 +123,20 @@ public class L1Model implements Model<Ve
Gson gson = builder.create();
return gson.toJson(this, modelType);
}
-
+
+ @Override
+ public Vector getCenter() {
+ return coefficients;
+ }
+
+ @Override
+ public int getId() {
+ return id;
+ }
+
+ @Override
+ public int getNumPoints() {
+ return count;
+ }
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java?rev=931732&r1=931731&r2=931732&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java Thu Apr 8 00:08:00 2010
@@ -18,13 +18,13 @@
package org.apache.mahout.clustering.dirichlet.models;
import org.apache.hadoop.io.Writable;
-import org.apache.mahout.clustering.Printable;
+import org.apache.mahout.clustering.Cluster;
/**
* A model is a probability distribution over observed data points and allows the probability of any data
* point to be computed.
*/
-public interface Model<O> extends Writable, Printable {
+public interface Model<O> extends Writable, Cluster {
/**
* Observe the given observation, retaining information about it
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=931732&r1=931731&r2=931732&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Thu Apr 8 00:08:00 2010
@@ -37,6 +37,8 @@ public class NormalModel implements Mode
private static final double sqrt2pi = Math.sqrt(2.0 * Math.PI);
+ private int id;
+
// the parameters
private Vector mean;
@@ -172,4 +174,19 @@ public class NormalModel implements Mode
Gson gson = builder.create();
return gson.toJson(this, modelType);
}
+
+ @Override
+ public Vector getCenter() {
+ return mean;
+ }
+
+ @Override
+ public int getId() {
+ return id;
+ }
+
+ @Override
+ public int getNumPoints() {
+ return s0;
+ }
}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java?rev=931732&r1=931731&r2=931732&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java Thu Apr 8 00:08:00 2010
@@ -27,7 +27,7 @@ import org.apache.mahout.clustering.diri
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.NormalModel;
import org.apache.mahout.clustering.dirichlet.models.SampledNormalModel;
-import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.meanshift.MeanShiftCanopy;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
@@ -50,7 +50,7 @@ public class TestPrintableInterface exte
public void testDirichletNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable model = new NormalModel(m, 0.75);
+ Cluster model = new NormalModel(m, 0.75);
String format = model.asFormatString(null);
assertEquals("format", "nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
String json = model.asJsonString();
@@ -64,7 +64,7 @@ public class TestPrintableInterface exte
public void testDirichletSampledNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable model = new SampledNormalModel(m, 0.75);
+ Cluster model = new SampledNormalModel(m, 0.75);
String format = model.asFormatString(null);
assertEquals("format", "snm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
String json = model.asJsonString();
@@ -78,7 +78,7 @@ public class TestPrintableInterface exte
public void testDirichletASNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable model = new AsymmetricSampledNormalModel(m, m);
+ Cluster model = new AsymmetricSampledNormalModel(m, m);
String format = model.asFormatString(null);
assertEquals("format", "asnm{n=0 m=[1.100, 2.200, 3.300] sd=[1.100, 2.200, 3.300]}", format);
String json = model.asJsonString();
@@ -92,7 +92,7 @@ public class TestPrintableInterface exte
public void testDirichletL1Model() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable model = new L1Model(m);
+ Cluster model = new L1Model(m);
String format = model.asFormatString(null);
assertEquals("format", "l1m{n=0 c=[1.100, 2.200, 3.300]}", format);
String json = model.asJsonString();
@@ -107,7 +107,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
NormalModel model = new NormalModel(m, 0.75);
- Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+ Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String format = cluster.asFormatString(null);
assertEquals("format", "nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
}
@@ -116,7 +116,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
NormalModel model = new NormalModel(m, 0.75);
- Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+ Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String json = cluster.asJsonString();
GsonBuilder builder = new GsonBuilder();
builder.registerTypeAdapter(Model.class, new JsonModelAdapter());
@@ -130,7 +130,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(m, m);
- Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+ Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String format = cluster.asFormatString(null);
assertEquals("format", "asnm{n=0 m=[1.100, 2.200, 3.300] sd=[1.100, 2.200, 3.300]}", format);
}
@@ -139,7 +139,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(m, m);
- Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+ Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String json = cluster.asJsonString();
GsonBuilder builder = new GsonBuilder();
@@ -154,7 +154,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
L1Model model = new L1Model(m);
- Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+ Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String format = cluster.asFormatString(null);
assertEquals("format", "l1m{n=0 c=[1.100, 2.200, 3.300]}", format);
}
@@ -163,7 +163,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
L1Model model = new L1Model(m);
- Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+ Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String json = cluster.asJsonString();
GsonBuilder builder = new GsonBuilder();
@@ -177,7 +177,7 @@ public class TestPrintableInterface exte
public void testCanopyAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable cluster = new Canopy(m, 123);
+ Cluster cluster = new Canopy(m, 123);
String formatString = cluster.asFormatString(null);
System.out.println(formatString);
assertEquals("format", "C123: [1.100, 2.200, 3.300]", formatString);
@@ -187,7 +187,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
m.assign(d);
- Printable cluster = new Canopy(m, 123);
+ Cluster cluster = new Canopy(m, 123);
String formatString = cluster.asFormatString(null);
System.out.println(formatString);
assertEquals("format", "C123: [0:1.100, 2:3.300]", formatString);
@@ -196,7 +196,7 @@ public class TestPrintableInterface exte
public void testCanopyAsFormatStringWithBindings() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable cluster = new Canopy(m, 123);
+ Cluster cluster = new Canopy(m, 123);
String[] bindings = { "fee", null, null };
String formatString = cluster.asFormatString(bindings);
System.out.println(formatString);
@@ -207,7 +207,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
m.assign(d);
- Printable cluster = new Canopy(m, 123);
+ Cluster cluster = new Canopy(m, 123);
String formatString = cluster.asFormatString(null);
System.out.println(formatString);
assertEquals("format", "C123: [0:1.100, 2:3.300]", formatString);
@@ -216,7 +216,7 @@ public class TestPrintableInterface exte
public void testClusterAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable cluster = new Cluster(m, 123);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Cluster(m, 123);
String formatString = cluster.asFormatString(null);
System.out.println(formatString);
assertEquals("format", "C123: [1.100, 2.200, 3.300]", formatString);
@@ -226,7 +226,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
m.assign(d);
- Printable cluster = new Cluster(m, 123);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Cluster(m, 123);
String formatString = cluster.asFormatString(null);
System.out.println(formatString);
assertEquals("format", "C123: [0:1.100, 2:3.300]", formatString);
@@ -235,7 +235,7 @@ public class TestPrintableInterface exte
public void testClusterAsFormatStringWithBindings() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable cluster = new Cluster(m, 123);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Cluster(m, 123);
String[] bindings = { "fee", null, "foo" };
String formatString = cluster.asFormatString(bindings);
System.out.println(formatString);
@@ -246,7 +246,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
m.assign(d);
- Printable cluster = new Cluster(m, 123);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Cluster(m, 123);
String formatString = cluster.asFormatString(null);
System.out.println(formatString);
assertEquals("format", "C123: [0:1.100, 2:3.300]", formatString);
@@ -255,7 +255,7 @@ public class TestPrintableInterface exte
public void testMSCanopyAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable cluster = new MeanShiftCanopy(m, 123);
+ Cluster cluster = new MeanShiftCanopy(m, 123);
String formatString = cluster.asFormatString(null);
System.out.println(formatString);
assertEquals("format", "C123: [1.100, 2.200, 3.300]", formatString);
@@ -265,7 +265,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
m.assign(d);
- Printable cluster = new MeanShiftCanopy(m, 123);
+ Cluster cluster = new MeanShiftCanopy(m, 123);
String formatString = cluster.asFormatString(null);
System.out.println(formatString);
assertEquals("format", "C123: [0:1.100, 2:3.300]", formatString);
@@ -274,7 +274,7 @@ public class TestPrintableInterface exte
public void testMSCanopyAsFormatStringWithBindings() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Printable cluster = new MeanShiftCanopy(m, 123);
+ Cluster cluster = new MeanShiftCanopy(m, 123);
String[] bindings = { "fee", null, "foo" };
String formatString = cluster.asFormatString(bindings);
System.out.println(formatString);
@@ -285,7 +285,7 @@ public class TestPrintableInterface exte
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
m.assign(d);
- Printable cluster = new MeanShiftCanopy(m, 123);
+ Cluster cluster = new MeanShiftCanopy(m, 123);
String[] bindings = { "fee", null, "foo" };
String formatString = cluster.asFormatString(bindings);
System.out.println(formatString);
Modified: lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java?rev=931732&r1=931731&r2=931732&view=diff
==============================================================================
--- lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java (original)
+++ lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java Thu Apr 8 00:08:00 2010
@@ -50,7 +50,7 @@ import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.jobcontrol.Job;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.math.Vector;
@@ -63,13 +63,21 @@ public final class ClusterDumper {
private static final Logger log = LoggerFactory.getLogger(ClusterDumper.class);
private final String seqFileDir;
+
private final String pointsDir;
+
private String termDictionary;
+
private String dictionaryFormat;
+
private String outputFile;
+
private int subString = Integer.MAX_VALUE;
+
private int numTopFeatures = 10;
- private Map<String,List<String>> clusterIdToPoints = null;
+
+ private Map<String, List<String>> clusterIdToPoints = null;
+
private boolean useJSON = false;
public ClusterDumper(String seqFileDir, String pointsDir) throws IOException {
@@ -127,25 +135,25 @@ public final class ClusterDumper {
FileSystem fs = FileSystem.get(path.toUri(), conf);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
Writable key = (Writable) reader.getKeyClass().newInstance();
- ClusterBase value = (ClusterBase) reader.getValueClass().newInstance();
+ Writable value = (Writable) reader.getValueClass().newInstance();
while (reader.next(key, value)) {
- Vector center = value.getCenter();
- String fmtStr = useJSON ? center.asFormatString() : VectorHelper.vectorToString(center, dictionary);
- writer.append("Id: ").append(String.valueOf(value.getId())).append(":");
- writer.append("name:").append(center.getName());
- if (subString > 0) {
+ Cluster cluster = (Cluster) value;
+ String fmtStr = useJSON ? cluster.asJsonString() : cluster.asFormatString(dictionary);
+ if (subString > 0 && fmtStr.length() > subString) {
writer.append(":").append(fmtStr.substring(0, Math.min(subString, fmtStr.length())));
- }
+ } else
+ writer.append(fmtStr);
+
writer.append('\n');
if (dictionary != null) {
- String topTerms = getTopFeatures(center, dictionary, numTopFeatures);
+ String topTerms = getTopFeatures(cluster.getCenter(), dictionary, numTopFeatures);
writer.write("\tTop Terms: ");
writer.write(topTerms);
writer.write('\n');
}
- List<String> points = clusterIdToPoints.get(String.valueOf(value.getId()));
+ List<String> points = clusterIdToPoints.get(String.valueOf(cluster.getId()));
if (points != null) {
writer.write("\tPoints: ");
for (Iterator<String> iterator = points.iterator(); iterator.hasNext();) {
@@ -183,7 +191,7 @@ public final class ClusterDumper {
this.subString = subString;
}
- public Map<String,List<String>> getClusterIdToPoints() {
+ public Map<String, List<String>> getClusterIdToPoints() {
return clusterIdToPoints;
}
@@ -210,37 +218,36 @@ public final class ClusterDumper {
GroupBuilder gbuilder = new GroupBuilder();
Option seqOpt = obuilder.withLongName("seqFileDir").withRequired(false).withArgument(
- abuilder.withName("seqFileDir").withMinimum(1).withMaximum(1).create()).withDescription(
- "The directory containing Sequence Files for the Clusters").withShortName("s").create();
+ abuilder.withName("seqFileDir").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The directory containing Sequence Files for the Clusters").withShortName("s").create();
Option outputOpt = obuilder.withLongName("output").withRequired(false).withArgument(
- abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
- "The output file. If not specified, dumps to the console").withShortName("o").create();
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The output file. If not specified, dumps to the console").withShortName("o").create();
Option substringOpt = obuilder.withLongName("substring").withRequired(false).withArgument(
- abuilder.withName("substring").withMinimum(1).withMaximum(1).create()).withDescription(
- "The number of chars of the asFormatString() to print").withShortName("b").create();
+ abuilder.withName("substring").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The number of chars of the asFormatString() to print").withShortName("b").create();
Option numWordsOpt = obuilder.withLongName("numWords").withRequired(false).withArgument(
- abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
- "The number of top terms to print").withShortName("n").create();
+ abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The number of top terms to print").withShortName("n").create();
Option centroidJSonOpt = obuilder.withLongName("json").withRequired(false).withDescription(
- "Output the centroid as JSON. Otherwise it substitues in the terms for vector cell entries")
+ "Output the centroid as JSON. Otherwise it substitues in the terms for vector cell entries")
.withShortName("j").create();
Option pointsOpt = obuilder.withLongName("pointsDir").withRequired(false).withArgument(
- abuilder.withName("pointsDir").withMinimum(1).withMaximum(1).create()).withDescription(
- "The directory containing points sequence files mapping input vectors to their cluster. "
- + "If specified, then the program will output the points associated with a cluster").withShortName(
- "p").create();
+ abuilder.withName("pointsDir").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The directory containing points sequence files mapping input vectors to their cluster. "
+ + "If specified, then the program will output the points associated with a cluster").withShortName("p")
+ .create();
Option dictOpt = obuilder.withLongName("dictionary").withRequired(false).withArgument(
- abuilder.withName("dictionary").withMinimum(1).withMaximum(1).create()).withDescription(
- "The dictionary file. ").withShortName("d").create();
+ abuilder.withName("dictionary").withMinimum(1).withMaximum(1).create())
+ .withDescription("The dictionary file. ").withShortName("d").create();
Option dictTypeOpt = obuilder.withLongName("dictionaryType").withRequired(false).withArgument(
- abuilder.withName("dictionaryType").withMinimum(1).withMaximum(1).create()).withDescription(
- "The dictionary file type (text|sequencefile)").withShortName("dt").create();
- Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
- .create();
-
- Group group = gbuilder.withName("Options").withOption(helpOpt).withOption(seqOpt).withOption(outputOpt)
- .withOption(substringOpt).withOption(pointsOpt).withOption(centroidJSonOpt).withOption(dictOpt)
- .withOption(dictTypeOpt).withOption(numWordsOpt).create();
+ abuilder.withName("dictionaryType").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The dictionary file type (text|sequencefile)").withShortName("dt").create();
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
+
+ Group group = gbuilder.withName("Options").withOption(helpOpt).withOption(seqOpt).withOption(outputOpt).withOption(
+ substringOpt).withOption(pointsOpt).withOption(centroidJSonOpt).withOption(dictOpt).withOption(dictTypeOpt)
+ .withOption(numWordsOpt).create();
try {
Parser parser = new Parser();
@@ -310,8 +317,8 @@ public final class ClusterDumper {
this.useJSON = json;
}
- private static Map<String,List<String>> readPoints(String pointsPathDir, JobConf conf) throws IOException {
- SortedMap<String,List<String>> result = new TreeMap<String,List<String>>();
+ private static Map<String, List<String>> readPoints(String pointsPathDir, JobConf conf) throws IOException {
+ SortedMap<String, List<String>> result = new TreeMap<String, List<String>>();
File[] children = new File(pointsPathDir).listFiles(new FilenameFilter() {
@Override
@@ -355,6 +362,7 @@ public final class ClusterDumper {
static class TermIndexWeight {
int index = -1;
+
double weight = 0;
TermIndexWeight(int index, double weight) {
@@ -381,7 +389,7 @@ public final class ClusterDumper {
}
});
- List<Pair<String,Double>> topTerms = new LinkedList<Pair<String,Double>>();
+ List<Pair<String, Double>> topTerms = new LinkedList<Pair<String, Double>>();
for (int i = 0; (i < vectorTerms.size()) && (i < numTerms); i++) {
int index = vectorTerms.get(i).index;
@@ -390,12 +398,12 @@ public final class ClusterDumper {
log.error("Dictionary entry missing for {}", index);
continue;
}
- topTerms.add(new Pair<String,Double>(dictTerm, vectorTerms.get(i).weight));
+ topTerms.add(new Pair<String, Double>(dictTerm, vectorTerms.get(i).weight));
}
StringBuilder sb = new StringBuilder();
- for (Pair<String,Double> item : topTerms) {
+ for (Pair<String, Double> item : topTerms) {
String term = item.getFirst();
sb.append("\n\t\t");
sb.append(StringUtils.rightPad(term, 40));
Added: lucene/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java?rev=931732&view=auto
==============================================================================
--- lucene/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java (added)
+++ lucene/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java Thu Apr 8 00:08:00 2010
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import junit.framework.Assert;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.store.RAMDirectory;
+import org.apache.lucene.util.Version;
+import org.apache.mahout.clustering.canopy.CanopyClusteringJob;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.dirichlet.DirichletClusterer;
+import org.apache.mahout.clustering.dirichlet.DirichletDriver;
+import org.apache.mahout.clustering.dirichlet.models.L1ModelDistribution;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.utils.clustering.ClusterDumper;
+import org.apache.mahout.utils.vectors.TFIDF;
+import org.apache.mahout.utils.vectors.TermInfo;
+import org.apache.mahout.utils.vectors.Weight;
+import org.apache.mahout.utils.vectors.lucene.CachedTermInfo;
+import org.apache.mahout.utils.vectors.lucene.LuceneIterable;
+import org.apache.mahout.utils.vectors.lucene.TFDFMapper;
+import org.apache.mahout.utils.vectors.lucene.VectorMapper;
+
+public class TestClusterDumper extends MahoutTestCase {
+
+ private List<VectorWritable> sampleData;
+
+ private FileSystem fs;
+
+ private static final String[] DOCS = { "The quick red fox jumped over the lazy brown dogs.",
+ "The quick brown fox jumped over the lazy red dogs.", "The quick red cat jumped over the lazy brown dogs.",
+ "The quick brown cat jumped over the lazy red dogs.", "Mary had a little lamb whose fleece was white as snow.",
+ "Mary had a little goat whose fleece was white as snow.",
+ "Mary had a little lamb whose fleece was black as tar.",
+ "Dick had a little goat whose fleece was white as snow.", "Moby Dick is a story of a whale and a man obsessed.",
+ "Moby Bob is a story of a walrus and a man obsessed.", "Moby Dick is a story of a whale and a crazy man.",
+ "The robber wore a black fleece jacket and a baseball cap.",
+ "The robber wore a red fleece jacket and a baseball cap.",
+ "The robber wore a white fleece jacket and a baseball cap.",
+ "The English Springer Spaniel is the best of all dogs." };
+
+ @Override
+ protected void setUp() throws Exception {
+ super.setUp();
+ RandomUtils.useTestSeed();
+ Configuration conf = new Configuration();
+ fs = FileSystem.get(conf);
+ // Create testdata directory
+ File f = new File("testdata");
+ if (!f.exists()) {
+ f.mkdir();
+ }
+ f = new File("testdata/points");
+ if (!f.exists()) {
+ f.mkdir();
+ }
+ f = new File("output");
+ rmDir(f);
+ // Create test data
+ getSampleData(DOCS);
+ ClusteringTestUtils.writePointsToFile(sampleData, "testdata/points/file1", fs, conf);
+ // Run clustering job
+ // Run ClusterDumper test
+ }
+
+ private void rmDir(File f) {
+ if (f != null && f.exists()) {
+ if (f.isDirectory())
+ for (File g : f.listFiles())
+ rmDir(g);
+ f.delete();
+ }
+ }
+
+ private void getSampleData(String[] docs2) throws IOException {
+ sampleData = new ArrayList<VectorWritable>();
+ RAMDirectory directory = new RAMDirectory();
+ IndexWriter writer = new IndexWriter(directory, new StandardAnalyzer(Version.LUCENE_CURRENT), true,
+ IndexWriter.MaxFieldLength.UNLIMITED);
+ for (int i = 0; i < docs2.length; i++) {
+ Document doc = new Document();
+ Field id = new Field("id", "doc_" + i, Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS);
+ doc.add(id);
+ // Store both position and offset information
+ Field text = new Field("content", docs2[i], Field.Store.NO, Field.Index.ANALYZED, Field.TermVector.YES);
+ doc.add(text);
+ writer.addDocument(doc);
+ }
+ writer.close();
+ IndexReader reader = IndexReader.open(directory, true);
+ Weight weight = new TFIDF();
+ TermInfo termInfo = new CachedTermInfo(reader, "content", 1, 100);
+ VectorMapper mapper = new TFDFMapper(reader, weight, termInfo);
+ LuceneIterable iterable = new LuceneIterable(reader, "id", "content", mapper);
+
+ int i = 0;
+ for (Vector vector : iterable) {
+ Assert.assertNotNull(vector);
+ System.out.println("Vector[" + i++ + "]=" + ClusterBase.formatVector(vector, null));
+ sampleData.add(new VectorWritable(vector));
+ }
+ }
+
+ public void testCanopy() throws Exception { // now run the Job
+ CanopyClusteringJob.runJob("testdata/points", "output", EuclideanDistanceMeasure.class.getName(), 8, 4);
+ // run ClusterDumper
+ ClusterDumper clusterDumper = new ClusterDumper("output/canopies", null);
+ clusterDumper.printClusters();
+ }
+
+ public void testKmeans() throws Exception {
+ // now run the Canopy job to prime kMeans canopies
+ CanopyDriver.runJob("testdata/points", "testdata/canopies", EuclideanDistanceMeasure.class.getName(), 8, 4);
+ // now run the KMeans job
+ KMeansDriver.runJob("testdata/points", "testdata/canopies", "output", EuclideanDistanceMeasure.class.getName(),
+ 0.001, 10, 1);
+ // run ClusterDumper
+ ClusterDumper clusterDumper = new ClusterDumper("output/clusters-1", null);
+ clusterDumper.printClusters();
+ }
+
+ public void testDirichlet() throws Exception {
+ Vector prototype = sampleData.get(0).get();
+ DirichletDriver.runJob("testdata/points", "output",
+ new L1ModelDistribution(sampleData.get(0)).getClass().getName(), prototype.getClass().getName(), prototype
+ .size(), 15, 10, 1.0, 1);
+ // run ClusterDumper
+ ClusterDumper clusterDumper = new ClusterDumper("output/state-10", null);
+ clusterDumper.printClusters();
+ }
+}