You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2009/03/16 01:17:08 UTC
svn commit: r754797 [2/2] - in /lucene/mahout/trunk: ./ core/
core/src/main/java/org/apache/mahout/clustering/dirichlet/
core/src/main/java/org/apache/mahout/clustering/dirichlet/models/
core/src/test/java/org/apache/mahout/clustering/dirichlet/ exampl...
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,57 @@
+package org.apache.mahout.clustering.dirichlet.models;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+/**
+ * A {@link ModelDistribution} over {@link Vector} observations that is meant
+ * for exercising the Dirichlet clustering algorithm in tests. Prior models are
+ * {@link SampledNormalModel}s with a two-component mean whose components are
+ * each drawn from N(0, 1), and a standard deviation fixed at 1.
+ */
+public class SampledNormalDistribution extends NormalModelDistribution
+    implements ModelDistribution<Vector> {
+
+  /**
+   * Draw fresh models from the prior.
+   *
+   * @param howMany the number of models to create
+   * @return {@code howMany} SampledNormalModels; each mean's x then y component
+   *         is drawn with {@code UncommonDistributions.rNorm(0, 1)} and each
+   *         model is given a standard deviation of 1
+   */
+  public Model<Vector>[] sampleFromPrior(int howMany) {
+    Model<Vector>[] priors = new SampledNormalModel[howMany];
+    int next = 0;
+    while (next < howMany) {
+      // draw x before y, matching the left-to-right order of an array initializer
+      double x = UncommonDistributions.rNorm(0, 1);
+      double y = UncommonDistributions.rNorm(0, 1);
+      priors[next++] = new SampledNormalModel(new DenseVector(new double[] { x, y }), 1);
+    }
+    return priors;
+  }
+
+  /**
+   * Produce one new model per posterior model by calling
+   * {@code SampledNormalModel.sample()} on it.
+   *
+   * @param posterior the current models; every element must be a
+   *        SampledNormalModel (a ClassCastException is thrown otherwise)
+   * @return a new array of the same length holding the sampled models
+   */
+  public Model<Vector>[] sampleFromPosterior(Model<Vector>[] posterior) {
+    Model<Vector>[] sampled = new SampledNormalModel[posterior.length];
+    for (int k = 0; k < sampled.length; k++) {
+      sampled[k] = ((SampledNormalModel) posterior[k]).sample();
+    }
+    return sampled;
+  }
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.dirichlet.models;
+
+import org.apache.mahout.matrix.Vector;
+
+/**
+ * A {@link NormalModel} whose {@link #sample()} simply returns a new
+ * SampledNormalModel built from this model's mean and standard deviation.
+ */
+public class SampledNormalModel extends NormalModel implements Model<Vector> {
+
+  public SampledNormalModel() {
+    super();
+  }
+
+  /**
+   * @param mean the model mean
+   * @param sd the model standard deviation
+   */
+  public SampledNormalModel(Vector mean, double sd) {
+    super(mean, sd);
+  }
+
+  /**
+   * Render as {@code snm{n=<s0> m=[<m0>, <m1>, ...] sd=<sd>}}, with the mean
+   * components and sd printed to two decimals. A fixed locale is used so the
+   * decimal separator can never be confused with the ", " element separator.
+   */
+  @Override
+  public String toString() {
+    StringBuilder buf = new StringBuilder();
+    buf.append("snm{n=").append(s0).append(" m=[");
+    if (mean != null) {
+      for (int i = 0; i < mean.cardinality(); i++) {
+        if (i > 0) {
+          buf.append(", "); // separator only between elements, no trailing one
+        }
+        buf.append(String.format(java.util.Locale.ENGLISH, "%.2f", mean.get(i)));
+      }
+    }
+    buf.append("] sd=").append(String.format(java.util.Locale.ENGLISH, "%.2f", sd)).append('}');
+    return buf.toString();
+  }
+
+  /**
+   * Return a new instance constructed with this model's mean and sd.
+   *
+   * @return a SampledNormalModel
+   */
+  NormalModel sample() {
+    return new SampledNormalModel(mean, sd);
+  }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/JsonModelHolderAdapter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/JsonModelHolderAdapter.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/JsonModelHolderAdapter.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/JsonModelHolderAdapter.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.dirichlet;
+
+import java.lang.reflect.Type;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.matrix.Vector;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParseException;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+import com.google.gson.reflect.TypeToken;
+
+@SuppressWarnings("unchecked")
+public class JsonModelHolderAdapter implements JsonSerializer<ModelHolder>,
+ JsonDeserializer<ModelHolder> {
+
+ Type typeOfModel = new TypeToken<Model<Vector>>() {
+ }.getType();
+
+ public JsonElement serialize(ModelHolder src, Type typeOfSrc,
+ JsonSerializationContext context) {
+ GsonBuilder builder = new GsonBuilder();
+ builder.registerTypeAdapter(Model.class, new JsonModelAdapter());
+ Gson gson = builder.create();
+ JsonObject obj = new JsonObject();
+ obj.add("model", new JsonPrimitive(gson.toJson(src.model, typeOfModel)));
+ return obj;
+ }
+
+ public ModelHolder deserialize(JsonElement json, Type typeOfT,
+ JsonDeserializationContext context) throws JsonParseException {
+ GsonBuilder builder = new GsonBuilder();
+ builder.registerTypeAdapter(Model.class, new JsonModelAdapter());
+ Gson gson = builder.create();
+ JsonObject obj = json.getAsJsonObject();
+ String value = obj.get("model").getAsString();
+ Model m = (Model) gson.fromJson(value, typeOfModel);
+ return new ModelHolder(m);
+ }
+
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/ModelHolder.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/ModelHolder.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/ModelHolder.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/ModelHolder.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,14 @@
+package org.apache.mahout.clustering.dirichlet;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+
+/**
+ * Simple mutable wrapper around a single {@link Model}; TestMapReduce uses it
+ * to round-trip models through Gson with JsonModelHolderAdapter.
+ *
+ * @param <Observation> the observation type of the held model
+ */
+class ModelHolder<Observation> {
+
+  /** The wrapped model; {@code null} for a holder made with the no-arg constructor. */
+  public Model<Observation> model;
+
+  /** Create a holder with no model. */
+  public ModelHolder() {
+    this(null);
+  }
+
+  /** Create a holder wrapping {@code model}. */
+  public ModelHolder(Model<Observation> model) {
+    this.model = model;
+  }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,194 @@
+package org.apache.mahout.clustering.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.util.ArrayList;
+import java.util.List;
+
+import junit.framework.TestCase;
+
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+/**
+ * Runs {@link DirichletClusterer} over synthetic two-dimensional data (one
+ * broad cluster plus two tight ones) with each model distribution at three
+ * data-set sizes, printing the significant models of every result row.
+ */
+public class TestDirichletClustering extends TestCase {
+
+  private List<Vector> sampleData;
+
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    // constant init bytes (presumably the RNG seed); explicit charset so the
+    // bytes do not depend on the platform default encoding
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes("UTF-8"));
+    sampleData = new ArrayList<Vector>();
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData
+   * @param num int number of samples to generate
+   * @param mx double x-value of the sample mean
+   * @param my double y-value of the sample mean
+   * @param sd double standard deviation of the samples
+   */
+  private void generateSamples(int num, double mx, double my, double sd) {
+    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
+        + "] sd=" + sd);
+    for (int i = 0; i < num; i++) {
+      sampleData.add(new DenseVector(new double[] {
+          UncommonDistributions.rNorm(mx, sd),
+          UncommonDistributions.rNorm(my, sd) }));
+    }
+  }
+
+  /**
+   * Add 40*scale points around [1, 1] with sd 3, then 30*scale points around
+   * [1, 0] and 30*scale points around [0, 1], both with sd 0.1.
+   */
+  private void generateThreeClusters(int scale) {
+    generateSamples(40 * scale, 1, 1, 3);
+    generateSamples(30 * scale, 1, 0, 0.1);
+    generateSamples(30 * scale, 0, 1, 0.1);
+  }
+
+  /**
+   * Run {@code dc.cluster(30)}, check that a result was produced and print
+   * every model whose count() exceeds {@code significant}.
+   */
+  private void clusterAndPrint(DirichletClusterer<Vector> dc, int significant) {
+    List<Model<Vector>[]> result = dc.cluster(30);
+    // assert before printing: printResults would NPE on a null result first
+    assertNotNull(result);
+    printResults(result, significant);
+  }
+
+  private void printResults(List<Model<Vector>[]> result, int significant) {
+    int row = 0;
+    for (Model<Vector>[] r : result) {
+      System.out.print("sample[" + row++ + "]= ");
+      for (Model<Vector> model : r) {
+        if (model.count() > significant) {
+          System.out.print(model.toString() + ", ");
+        }
+      }
+      System.out.println();
+    }
+    System.out.println();
+  }
+
+  public void testDirichletCluster100() {
+    System.out.println("testDirichletCluster100");
+    generateThreeClusters(1);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new NormalModelDistribution(), 1.0, 10, 1, 0), 2);
+  }
+
+  public void testDirichletCluster100s() {
+    System.out.println("testDirichletCluster100s");
+    generateThreeClusters(1);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new SampledNormalDistribution(), 1.0, 10, 1, 0), 2);
+  }
+
+  public void testDirichletCluster100as() {
+    System.out.println("testDirichletCluster100as");
+    generateThreeClusters(1);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new AsymmetricSampledNormalDistribution(), 1.0, 10, 1, 0), 2);
+  }
+
+  public void testDirichletCluster1000() {
+    System.out.println("testDirichletCluster1000");
+    generateThreeClusters(10);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new NormalModelDistribution(), 1.0, 10, 1, 0), 20);
+  }
+
+  public void testDirichletCluster1000s() {
+    System.out.println("testDirichletCluster1000s");
+    generateThreeClusters(10);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new SampledNormalDistribution(), 1.0, 10, 1, 0), 20);
+  }
+
+  public void testDirichletCluster1000as() {
+    System.out.println("testDirichletCluster1000as");
+    generateThreeClusters(10);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new AsymmetricSampledNormalDistribution(), 1.0, 10, 1, 0), 20);
+  }
+
+  public void testDirichletCluster10000() {
+    System.out.println("testDirichletCluster10000");
+    generateThreeClusters(100);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new NormalModelDistribution(), 1.0, 10, 1, 0), 200);
+  }
+
+  public void testDirichletCluster10000as() {
+    System.out.println("testDirichletCluster10000as");
+    generateThreeClusters(100);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new AsymmetricSampledNormalDistribution(), 1.0, 10, 1, 0), 200);
+  }
+
+  public void testDirichletCluster10000s() {
+    System.out.println("testDirichletCluster10000s");
+    generateThreeClusters(100);
+    clusterAndPrint(new DirichletClusterer<Vector>(sampleData,
+        new SampledNormalDistribution(), 1.0, 10, 1, 0), 200);
+  }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDistributions.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDistributions.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDistributions.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDistributions.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,119 @@
+package org.apache.mahout.clustering.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import junit.framework.TestCase;
+
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+/**
+ * Smoke tests for {@link UncommonDistributions}: each test draws values from
+ * one distribution and prints them for visual inspection. No test asserts on
+ * the drawn values.
+ */
+public class TestDistributions extends TestCase {
+
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    // constant init bytes (presumably the RNG seed); explicit charset so the
+    // bytes do not depend on the platform default encoding
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes("UTF-8"));
+  }
+
+  /**
+   * Print one row of {@code (int) (dNorm(i * 0.1, 0, sd) * 100)} stars for
+   * each i in [from, to), giving a sideways histogram of the density.
+   */
+  private static void printDensityBars(int from, int to, double sd) {
+    for (int i = from; i < to; i++) {
+      double d = i * 0.1;
+      int width = (int) (UncommonDistributions.dNorm(d, 0, sd) * 100);
+      StringBuilder bar = new StringBuilder(width);
+      for (int j = 0; j < width; j++) {
+        bar.append('*');
+      }
+      System.out.println(bar);
+    }
+  }
+
+  /**
+   * Sum 100 draws of rMultinom(100, probabilities), then print the sum
+   * divided by 100.
+   */
+  private static void printAverageMultinom(double[] probabilities, String label) {
+    Vector v = new DenseVector(probabilities);
+    Vector t = v.like();
+    for (int i = 1; i <= 100; i++) {
+      t = t.plus(UncommonDistributions.rMultinom(100, v));
+    }
+    System.out.println("sum(rMultinom(" + 100 + ", " + label + "))/100="
+        + t.divide(100).asFormatString());
+  }
+
+  public void testRbeta() {
+    for (double i = 0.01; i < 20; i += 0.25) {
+      System.out.println("rBeta(6,1," + i + ")="
+          + UncommonDistributions.rBeta(6, 1, i).asFormatString());
+    }
+  }
+
+  public void testRchisq() {
+    for (int i = 0; i < 50; i++) {
+      System.out.println("rChisq(" + i + ")=" + UncommonDistributions.rChisq(i));
+    }
+  }
+
+  public void testRnorm() {
+    for (int i = 1; i < 50; i++) {
+      // label now matches the actual call rNorm(1, i)
+      System.out.println("rNorm(1," + i + ")="
+          + UncommonDistributions.rNorm(1, i));
+    }
+  }
+
+  public void testDnorm() {
+    printDensityBars(-30, 30, 1);
+  }
+
+  public void testDnorm2() {
+    printDensityBars(-30, 30, 2);
+  }
+
+  public void testDnorm1() {
+    printDensityBars(-10, 10, 0.2);
+  }
+
+  public void testRmultinom1() {
+    printAverageMultinom(new double[] { 0.4, 0.6 }, "[0.4, 0.6]");
+  }
+
+  public void testRmultinom2() {
+    printAverageMultinom(new double[] { 0.1, 0.2, 0.7 }, "[ 0.1, 0.2, 0.7 ]");
+  }
+
+  public void testRmultinom() {
+    // NOTE(review): these probabilities sum to 1.1, not 1 -- confirm whether
+    // [0.1, 0.2, 0.7] was intended or unnormalized input is being exercised.
+    double[] b = { 0.1, 0.2, 0.8 };
+    Vector v = new DenseVector(b);
+    for (int i = 1; i <= 100; i++) {
+      System.out.println("rMultinom(" + 100 + ", [0.1, 0.2, 0.8])="
+          + UncommonDistributions.rMultinom(100, v).asFormatString());
+    }
+  }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,514 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.dirichlet;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.NormalModel;
+import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution;
+import org.apache.mahout.clustering.dirichlet.models.SampledNormalModel;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.utils.DummyOutputCollector;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+
+public class TestMapReduce extends TestCase {
+
+  private List<Vector> sampleData = new ArrayList<Vector>();
+
+  /**
+   * Append {@code num} two-dimensional points to sampleData. Each point's x is
+   * drawn with rNorm(mx, sd) and then its y with rNorm(my, sd).
+   * @param num int number of samples to generate
+   * @param mx double x-value of the sample mean
+   * @param my double y-value of the sample mean
+   * @param sd double standard deviation of the samples
+   */
+  private void generateSamples(int num, double mx, double my, double sd) {
+    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
+        + "] sd=" + sd);
+    int remaining = num;
+    while (remaining-- > 0) {
+      double x = UncommonDistributions.rNorm(mx, sd);
+      double y = UncommonDistributions.rNorm(my, sd);
+      sampleData.add(new DenseVector(new double[] { x, y }));
+    }
+  }
+
+  /**
+   * Write each point's format string, followed by {@code payload} and a '\n',
+   * to {@code fileName} (replacing any existing contents) as UTF-8.
+   * The file is closed even when a write fails.
+   *
+   * @param points the vectors to write, one per line
+   * @param fileName the file to create or overwrite
+   * @param payload text appended to every line after the vector
+   * @throws IOException if the file cannot be opened or written
+   */
+  public static void writePointsToFileWithPayload(List<Vector> points,
+      String fileName, String payload) throws IOException {
+    BufferedWriter output = new BufferedWriter(new OutputStreamWriter(
+        new FileOutputStream(fileName), Charset.forName("UTF-8")));
+    try {
+      for (Vector point : points) {
+        output.write(point.asFormatString());
+        output.write(payload);
+        output.write('\n');
+      }
+      output.flush();
+    } finally {
+      // previously leaked the file handle if any write threw
+      output.close();
+    }
+  }
+
+  /**
+   * Initialise UncommonDistributions from constant bytes and make sure an
+   * "input" directory exists in the working directory.
+   *
+   * @throws IOException if "input" does not exist and cannot be created
+   */
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    // explicit charset so the init bytes do not depend on the platform default
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes("UTF-8"));
+    File f = new File("input");
+    if (!f.exists() && !f.mkdir()) {
+      // fail here rather than later with an obscure file-not-found error
+      throw new IOException("Could not create directory " + f.getAbsolutePath());
+    }
+  }
+
+  /**
+   * Test the basic Mapper: map ten samples drawn around the origin through a
+   * DirichletMapper configured from a NormalModelDistribution state and check
+   * the number of distinct output keys.
+   * @throws Exception
+   */
+  public void testMapper() throws Exception {
+    generateSamples(10, 0, 0, 1);
+    DirichletState<Vector> state = new DirichletState<Vector>(
+        new NormalModelDistribution(), 5, 1, 0, 0);
+    DirichletMapper mapper = new DirichletMapper();
+    mapper.configure(state);
+
+    DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
+    for (Vector v : sampleData) {
+      mapper.map(null, new Text(v.asFormatString()), collector, null);
+    }
+    Map<String, List<Text>> data = collector.getData();
+    // with the constant seed set in setUp these samples happen to produce
+    // three partitions (distinct output keys), but they work
+    assertEquals("output size", 3, data.size());
+  }
+
+  /**
+   * Test the basic Reducer: map 400 samples from four sd-1 clusters centred on
+   * [0,0], [2,0], [0,2] and [2,2], reduce every resulting key with a
+   * DirichletReducer, and feed the reducer's new models back into the state.
+   * @throws Exception
+   */
+  public void testReducer() throws Exception {
+    generateSamples(100, 0, 0, 1);
+    generateSamples(100, 2, 0, 1);
+    generateSamples(100, 0, 2, 1);
+    generateSamples(100, 2, 2, 1);
+    DirichletState<Vector> state = new DirichletState<Vector>(
+        new SampledNormalDistribution(), 20, 1, 1, 0);
+    DirichletMapper mapper = new DirichletMapper();
+    mapper.configure(state);
+
+    DummyOutputCollector<Text, Text> mapCollector = new DummyOutputCollector<Text, Text>();
+    for (Vector v : sampleData) {
+      mapper.map(null, new Text(v.asFormatString()), mapCollector, null);
+    }
+    Map<String, List<Text>> data = mapCollector.getData();
+    // with the constant seed set in setUp these samples happen to produce
+    // seven partitions (distinct output keys), but they work
+    assertEquals("output size", 7, data.size());
+
+    DirichletReducer reducer = new DirichletReducer();
+    reducer.configure(state);
+    DummyOutputCollector<Text, Text> reduceCollector = new DummyOutputCollector<Text, Text>();
+    for (String key : mapCollector.getKeys()) {
+      reducer.reduce(new Text(key), mapCollector.getValue(key).iterator(),
+          reduceCollector, null);
+    }
+
+    Model<Vector>[] newModels = reducer.newModels;
+    assertNotNull("newModels", newModels);
+    state.update(newModels);
+  }
+
+  /**
+   * Print one line per result row, listing (prefixed with "m" and the model
+   * index) every model whose count() exceeds {@code significant}; a blank
+   * line follows the last row.
+   */
+  private void printModels(List<Model<Vector>[]> results, int significant) {
+    int sampleIndex = 0;
+    for (Model<Vector>[] models : results) {
+      System.out.print("sample[" + sampleIndex + "]= ");
+      sampleIndex++;
+      for (int k = 0; k < models.length; k++) {
+        Model<Vector> candidate = models[k];
+        if (candidate.count() > significant) {
+          System.out.print("m" + k + candidate.toString() + ", ");
+        }
+      }
+      System.out.println();
+    }
+    System.out.println();
+  }
+
+ /**
+ * Test the Mapper and Reducer in an iteration loop
+ * @throws Exception
+ */
+ public void testMRIterations() throws Exception {
+ generateSamples(100, 0, 0, 1);
+ generateSamples(100, 2, 0, 1);
+ generateSamples(100, 0, 2, 1);
+ generateSamples(100, 2, 2, 1);
+ DirichletState<Vector> state = new DirichletState<Vector>(
+ new SampledNormalDistribution(), 20, 1.0, 1, 0);
+
+ List<Model<Vector>[]> models = new ArrayList<Model<Vector>[]>();
+
+ for (int iteration = 0; iteration < 10; iteration++) {
+ DirichletMapper mapper = new DirichletMapper();
+ mapper.configure(state);
+ DummyOutputCollector<Text, Text> mapCollector = new DummyOutputCollector<Text, Text>();
+ for (Vector v : sampleData)
+ mapper.map(null, new Text(v.asFormatString()), mapCollector, null);
+
+ DirichletReducer reducer = new DirichletReducer();
+ reducer.configure(state);
+ DummyOutputCollector<Text, Text> reduceCollector = new DummyOutputCollector<Text, Text>();
+ for (String key : mapCollector.getKeys())
+ reducer.reduce(new Text(key), mapCollector.getValue(key).iterator(),
+ reduceCollector, null);
+
+ Model<Vector>[] newModels = reducer.newModels;
+ state.update(newModels);
+ models.add(newModels);
+ }
+ printModels(models, 0);
+ }
+
+  /** Round-trip a NormalModel through Gson and compare toString() output. */
+  @SuppressWarnings("unchecked")
+  public void testNormalModelSerialization() {
+    Model original = new NormalModel(new DenseVector(new double[] { 1.1, 2.2 }), 3.3);
+    Gson gson = new GsonBuilder()
+        .registerTypeAdapter(Vector.class, new JsonVectorAdapter())
+        .create();
+    Model restored = gson.fromJson(gson.toJson(original), NormalModel.class);
+    assertEquals("models", original.toString(), restored.toString());
+  }
+
+  /** Round-trip 20 prior NormalModels through Gson as an array. */
+  @SuppressWarnings("unchecked")
+  public void testNormalModelDistributionSerialization() {
+    Model[] originals = new NormalModelDistribution().sampleFromPrior(20);
+    Gson gson = new GsonBuilder()
+        .registerTypeAdapter(Vector.class, new JsonVectorAdapter())
+        .create();
+    Model[] restored = gson.fromJson(gson.toJson(originals), NormalModel[].class);
+    assertEquals("models", originals.length, restored.length);
+    for (int idx = 0; idx < originals.length; idx++) {
+      assertEquals("model[" + idx + "]", originals[idx].toString(),
+          restored[idx].toString());
+    }
+  }
+
+  /** Round-trip a SampledNormalModel through Gson and compare toString() output. */
+  @SuppressWarnings("unchecked")
+  public void testSampledNormalModelSerialization() {
+    Model original = new SampledNormalModel(new DenseVector(new double[] { 1.1, 2.2 }), 3.3);
+    Gson gson = new GsonBuilder()
+        .registerTypeAdapter(Vector.class, new JsonVectorAdapter())
+        .create();
+    Model restored = gson.fromJson(gson.toJson(original), SampledNormalModel.class);
+    assertEquals("models", original.toString(), restored.toString());
+  }
+
+  /** Round-trip 20 prior SampledNormalModels through Gson as an array. */
+  @SuppressWarnings("unchecked")
+  public void testSampledNormalDistributionSerialization() {
+    Model[] originals = new SampledNormalDistribution().sampleFromPrior(20);
+    Gson gson = new GsonBuilder()
+        .registerTypeAdapter(Vector.class, new JsonVectorAdapter())
+        .create();
+    Model[] restored = gson.fromJson(gson.toJson(originals), SampledNormalModel[].class);
+    assertEquals("models", originals.length, restored.length);
+    for (int idx = 0; idx < originals.length; idx++) {
+      assertEquals("model[" + idx + "]", originals[idx].toString(),
+          restored[idx].toString());
+    }
+  }
+
+  /** Round-trip an AsymmetricSampledNormalModel through Gson and compare toString() output. */
+  @SuppressWarnings("unchecked")
+  public void testAsymmetricSampledNormalModelSerialization() {
+    Vector mean = new DenseVector(new double[] { 1.1, 2.2 });
+    Vector stdDev = new DenseVector(new double[] { 3.3, 4.4 });
+    Model original = new AsymmetricSampledNormalModel(mean, stdDev);
+    Gson gson = new GsonBuilder()
+        .registerTypeAdapter(Vector.class, new JsonVectorAdapter())
+        .create();
+    Model restored = gson.fromJson(gson.toJson(original),
+        AsymmetricSampledNormalModel.class);
+    assertEquals("models", original.toString(), restored.toString());
+  }
+
+  /** Round-trip 20 prior AsymmetricSampledNormalModels through Gson as an array. */
+  @SuppressWarnings("unchecked")
+  public void testAsymmetricSampledNormalDistributionSerialization() {
+    Model[] originals = new AsymmetricSampledNormalDistribution().sampleFromPrior(20);
+    Gson gson = new GsonBuilder()
+        .registerTypeAdapter(Vector.class, new JsonVectorAdapter())
+        .create();
+    Model[] restored = gson.fromJson(gson.toJson(originals),
+        AsymmetricSampledNormalModel[].class);
+    assertEquals("models", originals.length, restored.length);
+    for (int idx = 0; idx < originals.length; idx++) {
+      assertEquals("model[" + idx + "]", originals[idx].toString(),
+          restored[idx].toString());
+    }
+  }
+
+  /**
+   * Round-trip a ModelHolder wrapping a NormalModel through Gson, using the
+   * polymorphic JsonModelHolderAdapter, and compare the wrapped models.
+   */
+  @SuppressWarnings("unchecked")
+  public void testModelHolderSerialization() {
+    GsonBuilder gsonBuilder = new GsonBuilder();
+    gsonBuilder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    gsonBuilder.registerTypeAdapter(ModelHolder.class,
+        new JsonModelHolderAdapter());
+    Gson gson = gsonBuilder.create();
+
+    ModelHolder holder = new ModelHolder(new NormalModel(new DenseVector(
+        new double[] { 1.1, 2.2 }), 3.3));
+    String json = gson.toJson(holder);
+    System.out.println(json);
+
+    ModelHolder restored = gson.fromJson(json, ModelHolder.class);
+    assertEquals("mh", holder.model.toString(), restored.model.toString());
+  }
+
+  /**
+   * Round-trip a ModelHolder wrapping an AsymmetricSampledNormalModel through
+   * Gson, using the polymorphic JsonModelHolderAdapter, and compare the
+   * wrapped models.
+   */
+  @SuppressWarnings("unchecked")
+  public void testModelHolderSerialization2() {
+    GsonBuilder gsonBuilder = new GsonBuilder();
+    gsonBuilder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    gsonBuilder.registerTypeAdapter(ModelHolder.class,
+        new JsonModelHolderAdapter());
+    Gson gson = gsonBuilder.create();
+
+    DenseVector mean = new DenseVector(new double[] { 1.1, 2.2 });
+    DenseVector sd = new DenseVector(new double[] { 3.3, 4.4 });
+    ModelHolder holder = new ModelHolder(new AsymmetricSampledNormalModel(mean,
+        sd));
+    String json = gson.toJson(holder);
+    System.out.println(json);
+
+    ModelHolder restored = gson.fromJson(json, ModelHolder.class);
+    assertEquals("mh", holder.model.toString(), restored.model.toString());
+  }
+
+  /**
+   * Round-trip a freshly constructed DirichletState through Gson (using the
+   * custom JsonDirichletStateAdapter). Compares numClusters, the model factory
+   * class, the number of clusters, the mixture cardinality and the offset.
+   */
+  @SuppressWarnings("unchecked")
+  public void testStateSerialization() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(DirichletState.class,
+        new JsonDirichletStateAdapter());
+    Gson gson = builder.create();
+    DirichletState state = new DirichletState(new SampledNormalDistribution(),
+        20, 1, 1, 0);
+    String format = gson.toJson(state);
+    System.out.println(format);
+    DirichletState state2 = gson.fromJson(format, DirichletState.class);
+    assertNotNull("State2 null", state2);
+    assertEquals("numClusters", state.numClusters, state2.numClusters);
+    assertEquals("modelFactory", state.modelFactory.getClass().getName(),
+        state2.modelFactory.getClass().getName());
+    assertEquals("clusters", state.clusters.size(), state2.clusters.size());
+    assertEquals("mixture", state.mixture.cardinality(), state2.mixture
+        .cardinality());
+    // label the assertion after the field actually compared, so a failure
+    // message points at "offset" rather than a generic "dirichlet"
+    assertEquals("offset", state.offset, state2.offset);
+  }
+
+ /**
+ * Test the Mapper and Reducer using the Driver
+ * @throws Exception
+ */
+ public void testDriverMRIterations() throws Exception {
+ File f = new File("input");
+ for (File g : f.listFiles())
+ g.delete();
+ generateSamples(100, 0, 0, 0.5);
+ generateSamples(100, 2, 0, 0.2);
+ generateSamples(100, 0, 2, 0.3);
+ generateSamples(100, 2, 2, 1);
+ writePointsToFileWithPayload(sampleData, "input/data.txt", "");
+ // Now run the driver
+ DirichletDriver.runJob("input", "output",
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
+ 10, 1.0, 1);
+ // and inspect results
+ List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
+ JobConf conf = new JobConf(KMeansDriver.class);
+ conf.set(DirichletDriver.MODEL_FACTORY_KEY,
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+ conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
+ conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
+ for (int i = 0; i < 11; i++) {
+ conf.set(DirichletDriver.STATE_IN_KEY, "output/state-" + i);
+ clusters.add(DirichletMapper.getDirichletState(conf).clusters);
+ }
+ printResults(clusters, 0);
+ }
+
+  /**
+   * Print one line per entry of {@code clusters}, listing each cluster whose
+   * model count exceeds {@code significant} as
+   * "m&lt;index&gt;(&lt;totalCount&gt;)&lt;model&gt;, ", followed by a final blank line.
+   *
+   * @param clusters the cluster lists to print, e.g. one per persisted state
+   * @param significant clusters whose model count is not greater than this
+   *          are skipped
+   */
+  private void printResults(List<List<DirichletCluster<Vector>>> clusters,
+      int significant) {
+    int row = 0;
+    for (List<DirichletCluster<Vector>> r : clusters) {
+      System.out.print("sample[" + row++ + "]= ");
+      for (int k = 0; k < r.size(); k++) {
+        DirichletCluster<Vector> cluster = r.get(k);
+        Model<Vector> model = cluster.model;
+        if (model.count() > significant) {
+          // plain narrowing cast; boxing through new Double(...) is unnecessary
+          int total = (int) cluster.totalCount;
+          System.out.print("m" + k + "(" + total + ")" + model.toString()
+              + ", ");
+        }
+      }
+      System.out.println();
+    }
+    System.out.println();
+  }
+
+ /**
+ * Test the Mapper and Reducer using the Driver
+ * @throws Exception
+ */
+ public void testDriverMnRIterations() throws Exception {
+ File f = new File("input");
+ for (File g : f.listFiles())
+ g.delete();
+ generateSamples(500, 0, 0, 0.5);
+ writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 2, 0, 0.2);
+ writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 0, 2, 0.3);
+ writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 2, 2, 1);
+ writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+ // Now run the driver
+ DirichletDriver.runJob("input", "output",
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
+ 15, 1.0, 1);
+ // and inspect results
+ List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
+ JobConf conf = new JobConf(KMeansDriver.class);
+ conf.set(DirichletDriver.MODEL_FACTORY_KEY,
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+ conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
+ conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
+ for (int i = 0; i < 11; i++) {
+ conf.set(DirichletDriver.STATE_IN_KEY, "output/state-" + i);
+ clusters.add(DirichletMapper.getDirichletState(conf).clusters);
+ }
+ printResults(clusters, 0);
+ }
+
+ /**
+ * Test the Mapper and Reducer using the Driver
+ * @throws Exception
+ */
+ public void testDriverMnRnIterations() throws Exception {
+ File f = new File("input");
+ for (File g : f.listFiles())
+ g.delete();
+ generateSamples(500, 0, 0, 0.5);
+ writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 2, 0, 0.2);
+ writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 0, 2, 0.3);
+ writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 2, 2, 1);
+ writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+ // Now run the driver
+ DirichletDriver.runJob("input", "output",
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
+ 15, 1.0, 2);
+ // and inspect results
+ List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
+ JobConf conf = new JobConf(KMeansDriver.class);
+ conf.set(DirichletDriver.MODEL_FACTORY_KEY,
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+ conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
+ conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
+ for (int i = 0; i < 11; i++) {
+ conf.set(DirichletDriver.STATE_IN_KEY, "output/state-" + i);
+ clusters.add(DirichletMapper.getDirichletState(conf).clusters);
+ }
+ printResults(clusters, 0);
+ }
+
+  /**
+   * Append {@code num} random 2-d points to sampleData. Each point's x and y
+   * come from separate UncommonDistributions.rNorm calls with their own mean
+   * and standard deviation.
+   *
+   * @param num int number of samples to generate
+   * @param mx double x-value of the sample mean
+   * @param my double y-value of the sample mean
+   * @param sdx double x-standard deviation of the samples
+   * @param sdy double y-standard deviation of the samples
+   */
+  private void generateSamples(int num, double mx, double my, double sdx,
+      double sdy) {
+    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
+        + "] sd=[" + sdx + ", " + sdy + "]");
+    int remaining = num;
+    while (remaining-- > 0) {
+      // draw x before y, matching the left-to-right order of the two draws
+      double x = UncommonDistributions.rNorm(mx, sdx);
+      double y = UncommonDistributions.rNorm(my, sdy);
+      sampleData.add(new DenseVector(new double[] { x, y }));
+    }
+  }
+
+ /**
+ * Test the Mapper and Reducer using the Driver
+ * @throws Exception
+ */
+ public void testDriverMnRnIterationsAsymmetric() throws Exception {
+ File f = new File("input");
+ for (File g : f.listFiles())
+ g.delete();
+ generateSamples(500, 0, 0, 0.5, 1.0);
+ writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 2, 0, 0.2, 0.1);
+ writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 0, 2, 0.3, 0.5);
+ writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
+ sampleData = new ArrayList<Vector>();
+ generateSamples(500, 2, 2, 1, 0.5);
+ writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+ // Now run the driver
+ DirichletDriver
+ .runJob(
+ "input",
+ "output",
+ "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution",
+ 20, 15, 1.0, 2);
+ // and inspect results
+ List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
+ JobConf conf = new JobConf(KMeansDriver.class);
+ conf
+ .set(DirichletDriver.MODEL_FACTORY_KEY,
+ "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution");
+ conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
+ conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
+ for (int i = 0; i < 11; i++) {
+ conf.set(DirichletDriver.STATE_IN_KEY, "output/state-" + i);
+ clusters.add(DirichletMapper.getDirichletState(conf).clusters);
+ }
+ printResults(clusters, 0);
+ }
+
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,68 @@
+package org.apache.mahout.clustering.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.BasicStroke;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+class Display2dASNDirichlet extends DisplayDirichlet {
+  private static final long serialVersionUID = 1L;
+
+  /**
+   * Title the frame with the significance threshold. The implicit super()
+   * call to DisplayDirichlet() already runs initialize(). Calling it again
+   * would re-size and re-show the frame and register a second
+   * window-closing listener, so it is not repeated here.
+   */
+  public Display2dASNDirichlet() {
+    this
+        .setTitle("Dirichlet Process Clusters - 2-d Asymmetric Sampled Normal Distribution (>"
+            + (int) (significance * 100) + "% of population)");
+  }
+
+  /**
+   * Plot the sample data, then, for every sample in result, draw an ellipse
+   * (3 * sd wide/high, centered on the mean) around each significant
+   * AsymmetricSampledNormalModel. i counts down from result.size() - 1, so the
+   * last sample uses colors[0] with stroke width 3. Earlier samples use later
+   * palette entries, capped at the last color.
+   */
+  @Override
+  public void paint(Graphics g) {
+    super.plotSampleData(g);
+    Graphics2D g2 = (Graphics2D) g;
+
+    Vector dv = new DenseVector(2);
+    int i = result.size() - 1;
+    for (Model<Vector>[] models : result) {
+      g2.setStroke(new BasicStroke(i == 0 ? 3 : 1));
+      g2.setColor(colors[Math.min(colors.length - 1, i--)]);
+      for (Model<Vector> m : models) {
+        AsymmetricSampledNormalModel mm = (AsymmetricSampledNormalModel) m;
+        dv.assign(mm.sd.times(3));
+        if (isSignificant(mm)) {
+          plotEllipse(g2, mm.mean, dv);
+        }
+      }
+    }
+  }
+
+  /**
+   * Initialize UncommonDistributions with a fixed seed string, generate the
+   * 2-d samples, cluster them and display the result.
+   */
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generate2dSamples();
+    generateResults();
+    new Display2dASNDirichlet();
+  }
+
+  /** Cluster sampleData using an AsymmetricSampledNormalDistribution. */
+  static void generateResults() {
+    DisplayDirichlet.generateResults(new AsymmetricSampledNormalDistribution());
+  }
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,68 @@
+package org.apache.mahout.clustering.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.BasicStroke;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+class DisplayASNDirichlet extends DisplayDirichlet {
+  private static final long serialVersionUID = 1L;
+
+  /**
+   * Title the frame with the significance threshold. The implicit super()
+   * call to DisplayDirichlet() already runs initialize(). Calling it again
+   * would re-size and re-show the frame and register a second
+   * window-closing listener, so it is not repeated here.
+   */
+  public DisplayASNDirichlet() {
+    this
+        .setTitle("Dirichlet Process Clusters - Asymmetric Sampled Normal Distribution (>"
+            + (int) (significance * 100) + "% of population)");
+  }
+
+  /**
+   * Plot the sample data, then, for every sample in result, draw an ellipse
+   * (3 * sd wide/high, centered on the mean) around each significant
+   * AsymmetricSampledNormalModel. i counts down from result.size() - 1, so the
+   * last sample uses colors[0] with stroke width 3. Earlier samples use later
+   * palette entries, capped at the last color.
+   */
+  @Override
+  public void paint(Graphics g) {
+    super.plotSampleData(g);
+    Graphics2D g2 = (Graphics2D) g;
+
+    Vector dv = new DenseVector(2);
+    int i = result.size() - 1;
+    for (Model<Vector>[] models : result) {
+      g2.setStroke(new BasicStroke(i == 0 ? 3 : 1));
+      g2.setColor(colors[Math.min(colors.length - 1, i--)]);
+      for (Model<Vector> m : models) {
+        AsymmetricSampledNormalModel mm = (AsymmetricSampledNormalModel) m;
+        dv.assign(mm.sd.times(3));
+        if (isSignificant(mm)) {
+          plotEllipse(g2, mm.mean, dv);
+        }
+      }
+    }
+  }
+
+  /**
+   * Initialize UncommonDistributions with a fixed seed string, generate the
+   * symmetric samples, cluster them and display the result.
+   */
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generateSamples();
+    generateResults();
+    new DisplayASNDirichlet();
+  }
+
+  /** Cluster sampleData using an AsymmetricSampledNormalDistribution. */
+  static void generateResults() {
+    DisplayDirichlet.generateResults(new AsymmetricSampledNormalDistribution());
+  }
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,125 @@
+package org.apache.mahout.clustering.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.BasicStroke;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+class DisplayASNOutputState extends DisplayDirichlet {
+  private static final long serialVersionUID = 1L;
+
+  /**
+   * Title the frame with the significance threshold. The implicit super()
+   * call to DisplayDirichlet() already runs initialize(). Calling it again
+   * would re-size and re-show the frame and register a second
+   * window-closing listener, so it is not repeated here.
+   */
+  public DisplayASNOutputState() {
+    this.setTitle("Dirichlet Process Clusters - Map/Reduce Results (>"
+        + (int) (significance * 100) + "% of population)");
+  }
+
+  /**
+   * Plot the sample data, then, for every loaded state in result, draw an
+   * ellipse (3 * sd wide/high, centered on the mean) around each significant
+   * AsymmetricSampledNormalModel. i counts down from result.size() - 1, so the
+   * last state uses colors[0] with stroke width 3. Earlier states use later
+   * palette entries, capped at the last color.
+   */
+  @Override
+  public void paint(Graphics g) {
+    super.plotSampleData(g);
+    Graphics2D g2 = (Graphics2D) g;
+
+    Vector dv = new DenseVector(2);
+    int i = result.size() - 1;
+    for (Model<Vector>[] models : result) {
+      g2.setStroke(new BasicStroke(i == 0 ? 3 : 1));
+      g2.setColor(colors[Math.min(colors.length - 1, i--)]);
+      for (Model<Vector> m : models) {
+        AsymmetricSampledNormalModel mm = (AsymmetricSampledNormalModel) m;
+        dv.set(0, mm.sd.get(0) * 3);
+        dv.set(1, mm.sd.get(1) * 3);
+        if (isSignificant(mm)) {
+          plotEllipse(g2, mm.mean, dv);
+        }
+      }
+    }
+  }
+
+  /**
+   * Decode every line of the given file with DenseVector.decodeFormat and
+   * return the vectors in file order.
+   *
+   * @param fileName
+   *          the String name of the file
+   * @return a List containing one decoded Vector per line of the file
+   * @throws IOException
+   *           if the file cannot be opened or read
+   */
+  public static List<Vector> readFile(String fileName) throws IOException {
+    BufferedReader r = new BufferedReader(new FileReader(fileName));
+    try {
+      List<Vector> results = new ArrayList<Vector>();
+      String line;
+      while ((line = r.readLine()) != null) {
+        results.add(DenseVector.decodeFormat(line));
+      }
+      return results;
+    } finally {
+      r.close();
+    }
+  }
+
+  /**
+   * Load every file in the "input" directory into sampleData.
+   *
+   * @throws IOException if "input" is not a readable directory or a file
+   *           cannot be read
+   */
+  private static void getSamples() throws IOException {
+    File f = new File("input");
+    // listFiles() returns null (rather than throwing) for a missing directory
+    File[] inputFiles = f.listFiles();
+    if (inputFiles == null) {
+      throw new IOException("Not a readable directory: " + f.getCanonicalPath());
+    }
+    for (File g : inputFiles) {
+      sampleData.addAll(readFile(g.getCanonicalPath()));
+    }
+  }
+
+  /**
+   * Load the models of output/state-0, output/state-1, ... into result, in
+   * iteration order, stopping at the first missing index. File.listFiles()
+   * guarantees no ordering, and paint() relies on list order to assign colors
+   * and to draw the final state in bold.
+   *
+   * @throws IOException if "output" is not a directory or a state cannot be
+   *           read
+   */
+  private static void getResults() throws IOException {
+    result = new ArrayList<Model<Vector>[]>();
+    JobConf conf = new JobConf(DirichletDriver.class);
+    conf
+        .set(DirichletDriver.MODEL_FACTORY_KEY,
+            "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution");
+    conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
+    conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
+    File f = new File("output");
+    if (!f.isDirectory()) {
+      throw new IOException("Not a readable directory: " + f.getCanonicalPath());
+    }
+    for (int i = 0;; i++) {
+      File g = new File(f, "state-" + i);
+      if (!g.exists()) {
+        break;
+      }
+      conf.set(DirichletDriver.STATE_IN_KEY, g.getCanonicalPath());
+      DirichletState<Vector> dirichletState = DirichletMapper
+          .getDirichletState(conf);
+      result.add(dirichletState.getModels());
+    }
+  }
+
+  /**
+   * Initialize UncommonDistributions, load the samples and map/reduce states,
+   * and display them. If loading fails, no window is opened, since paint()
+   * would otherwise dereference a null result list.
+   */
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    try {
+      getSamples();
+      getResults();
+    } catch (IOException e) {
+      e.printStackTrace();
+      return;
+    }
+    new DisplayASNOutputState();
+  }
+
+  /**
+   * Cluster sampleData in memory with an AsymmetricSampledNormalDistribution.
+   * paint() casts every model to AsymmetricSampledNormalModel, so a
+   * NormalModelDistribution here would make paint() fail with a
+   * ClassCastException.
+   */
+  static void generateResults() {
+    DisplayDirichlet.generateResults(new AsymmetricSampledNormalDistribution());
+  }
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,231 @@
+package org.apache.mahout.clustering.dirichlet;
+
+import java.awt.Color;
+import java.awt.Frame;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.awt.Toolkit;
+import java.awt.event.WindowAdapter;
+import java.awt.event.WindowEvent;
+import java.awt.geom.AffineTransform;
+import java.awt.geom.Ellipse2D;
+import java.awt.geom.Rectangle2D;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.TimesFunction;
+import org.apache.mahout.matrix.Vector;
+
+class DisplayDirichlet extends Frame {
+  private static final long serialVersionUID = 1L;
+
+  /** screen resolution in pixels per inch, read in initialize() */
+  int res;
+
+  /** default scale = 72 pixels per inch; plot coordinates are multiplied by this */
+  int ds = 72;
+
+  /** screen size in inches; the plot is centered, spanning +/- size/2 per axis */
+  int size = 8;
+
+  /** the sample points that are plotted and clustered */
+  static List<Vector> sampleData = new ArrayList<Vector>();
+
+  /** one array of models per sample returned by DirichletClusterer.cluster() */
+  static List<Model<Vector>[]> result;
+
+  /** fraction of sampleData a model's count must exceed to be significant */
+  static double significance = 0.05;
+
+  /** {mx, my, sdx, sdy} of each generated batch of samples */
+  static List<Vector> sampleParams = new ArrayList<Vector>();
+
+  /** drawing palette */
+  static Color[] colors = { Color.red, Color.orange, Color.yellow, Color.green,
+      Color.blue, Color.magenta, Color.lightGray };
+
+  /*
+   * Licensed to the Apache Software Foundation (ASF) under one or more
+   * contributor license agreements. See the NOTICE file distributed with
+   * this work for additional information regarding copyright ownership.
+   * The ASF licenses this file to You under the Apache License, Version 2.0
+   * (the "License"); you may not use this file except in compliance with
+   * the License. You may obtain a copy of the License at
+   *
+   * http://www.apache.org/licenses/LICENSE-2.0
+   *
+   * Unless required by applicable law or agreed to in writing, software
+   * distributed under the License is distributed on an "AS IS" BASIS,
+   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   * See the License for the specific language governing permissions and
+   * limitations under the License.
+   */
+
+  /** Create the frame and show it (see initialize()). */
+  DisplayDirichlet() {
+    initialize();
+  }
+
+  /**
+   * Size the frame to size x size inches at the screen resolution, make it
+   * visible with a default title, and exit the JVM when the window is closed.
+   */
+  void initialize() {
+    //Get screen resolution
+    res = Toolkit.getDefaultToolkit().getScreenResolution();
+
+    //Set Frame size in inches
+    this.setSize(size * res, size * res);
+    this.setVisible(true);
+    this.setTitle("Dirichlet Process Sample Data");
+
+    //Window listener to terminate program.
+    this.addWindowListener(new WindowAdapter() {
+      public void windowClosing(WindowEvent e) {
+        System.exit(0);
+      }
+    });
+  }
+
+  /**
+   * Initialize UncommonDistributions with a fixed seed string, generate the
+   * sample data and display it.
+   */
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generateSamples();
+    new DisplayDirichlet();
+  }
+
+  /**
+   * Plot the sample data, then draw in red an ellipse (3 * sd wide/high) for
+   * the parameters of each generated batch of samples.
+   */
+  @Override
+  public void paint(Graphics g) {
+    Graphics2D g2 = (Graphics2D) g;
+    plotSampleData(g);
+    Vector v = new DenseVector(2);
+    Vector dv = new DenseVector(2);
+    g2.setColor(Color.RED);
+    for (Vector param : sampleParams) {
+      v.set(0, param.get(0));
+      v.set(1, param.get(1));
+      dv.set(0, param.get(2) * 3);
+      dv.set(1, param.get(3) * 3);
+      plotEllipse(g2, v, dv);
+    }
+  }
+
+  /**
+   * Scale the context from ds to the screen resolution, draw the axes as two
+   * black size/2 squares centered at (2, 2) and (-2, -2), and draw each
+   * sample point as a small dark-gray square.
+   */
+  void plotSampleData(Graphics g) {
+    Graphics2D g2 = (Graphics2D) g;
+    double sx = (double) res / ds;
+    g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
+
+    // plot the axes
+    g2.setColor(Color.BLACK);
+    Vector dv = new DenseVector(2).assign(size / 2);
+    plotRectangle(g2, new DenseVector(2).assign(2), dv);
+    plotRectangle(g2, new DenseVector(2).assign(-2), dv);
+
+    // plot the sample data
+    g2.setColor(Color.DARK_GRAY);
+    dv.assign(0.03);
+    for (Vector v : sampleData) {
+      plotRectangle(g2, v, dv);
+    }
+  }
+
+  /**
+   * Draw an axis-aligned rectangle on the graphics context.
+   * @param g2 a Graphics2D context
+   * @param v a Vector holding the rectangle's center
+   * @param dv a Vector holding the rectangle's full width and height
+   */
+  void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
+    g2.draw(screenBounds(v, dv));
+  }
+
+  /**
+   * Draw an axis-aligned ellipse on the graphics context.
+   * @param g2 a Graphics2D context
+   * @param v a Vector holding the ellipse's center
+   * @param dv a Vector holding the ellipse's full width and height
+   */
+  void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
+    Rectangle2D.Double bounds = screenBounds(v, dv);
+    g2.draw(new Ellipse2D.Double(bounds.x, bounds.y, bounds.width,
+        bounds.height));
+  }
+
+  /**
+   * Map a plot-space center and full width/height to a drawing-space
+   * bounding box. The y axis is flipped, the origin is moved to the middle of
+   * the size x size plot, and the result is scaled by ds.
+   */
+  private Rectangle2D.Double screenBounds(Vector v, Vector dv) {
+    int h = size / 2;
+    double[] flip = { 1, -1 };
+    Vector v2 = v.copy().assign(new DenseVector(flip), new TimesFunction());
+    // shift from the center to the top-left corner
+    v2 = v2.minus(dv.divide(2));
+    double x = v2.get(0) + h;
+    double y = v2.get(1) + h;
+    return new Rectangle2D.Double(x * ds, y * ds, dv.get(0) * ds, dv.get(1)
+        * ds);
+  }
+
+  /**
+   * Print one line per entry of results, listing each model whose count
+   * exceeds significant, followed by a final blank line.
+   */
+  private static void printModels(List<Model<Vector>[]> results, int significant) {
+    int row = 0;
+    for (Model<Vector>[] r : results) {
+      System.out.print("sample[" + row++ + "]= ");
+      for (int k = 0; k < r.length; k++) {
+        Model<Vector> model = r[k];
+        if (model.count() > significant) {
+          System.out.print("m" + k + model.toString() + ", ");
+        }
+      }
+      System.out.println();
+    }
+    System.out.println();
+  }
+
+  /** Generate the default three batches of samples (equal x/y sd). */
+  static void generateSamples() {
+    generateSamples(400, 1, 1, 3);
+    generateSamples(300, 1, 0, 0.5);
+    generateSamples(300, 0, 2, 0.1);
+  }
+
+  /** Generate the default three batches of samples with separate x/y sd. */
+  static void generate2dSamples() {
+    generate2dSamples(400, 1, 1, 3, 1);
+    generate2dSamples(300, 1, 0, 0.5, 1);
+    generate2dSamples(300, 0, 2, 0.1, 0.5);
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData, recording
+   * {mx, my, sd, sd} in sampleParams.
+   * @param num int number of samples to generate
+   * @param mx double x-value of the sample mean
+   * @param my double y-value of the sample mean
+   * @param sd double standard deviation of the samples
+   */
+  public static void generateSamples(int num, double mx, double my, double sd) {
+    double[] params = { mx, my, sd, sd };
+    sampleParams.add(new DenseVector(params));
+    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
+        + "] sd=" + sd);
+    for (int i = 0; i < num; i++) {
+      sampleData.add(new DenseVector(new double[] {
+          UncommonDistributions.rNorm(mx, sd),
+          UncommonDistributions.rNorm(my, sd) }));
+    }
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData, recording
+   * {mx, my, sdx, sdy} in sampleParams.
+   * @param num int number of samples to generate
+   * @param mx double x-value of the sample mean
+   * @param my double y-value of the sample mean
+   * @param sdx double x-value standard deviation of the samples
+   * @param sdy double y-value standard deviation of the samples
+   */
+  public static void generate2dSamples(int num, double mx, double my,
+      double sdx, double sdy) {
+    double[] params = { mx, my, sdx, sdy };
+    sampleParams.add(new DenseVector(params));
+    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
+        + "] sd=[" + sdx + ", " + sdy + "]");
+    for (int i = 0; i < num; i++) {
+      sampleData.add(new DenseVector(new double[] {
+          UncommonDistributions.rNorm(mx, sdx),
+          UncommonDistributions.rNorm(my, sdy) }));
+    }
+  }
+
+  /**
+   * Cluster sampleData with the given model distribution, store the samples
+   * in result, and print the models whose count exceeds 5.
+   */
+  static void generateResults(ModelDistribution<Vector> modelDist) {
+    DirichletClusterer<Vector> dc = new DirichletClusterer<Vector>(sampleData,
+        modelDist, 1.0, 10, 2, 2);
+    result = dc.cluster(20);
+    printModels(result, 5);
+  }
+
+  /** True when the model's count exceeds the significance fraction of sampleData. */
+  static boolean isSignificant(Model<Vector> model) {
+    return (((double) model.count() / sampleData.size()) > significance);
+  }
+
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,67 @@
+package org.apache.mahout.clustering.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.BasicStroke;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.NormalModel;
+import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+class DisplayNDirichlet extends DisplayDirichlet {
+  private static final long serialVersionUID = 1L;
+
+  /**
+   * Title the frame with the significance threshold. The implicit super()
+   * call to DisplayDirichlet() already runs initialize(). Calling it again
+   * would re-size and re-show the frame and register a second
+   * window-closing listener, so it is not repeated here.
+   */
+  public DisplayNDirichlet() {
+    this.setTitle("Dirichlet Process Clusters - Normal Distribution (>"
+        + (int) (significance * 100) + "% of population)");
+  }
+
+  /**
+   * Plot the sample data, then, for every sample in result, draw a circle
+   * (3 * sd in diameter, centered on the mean) around each significant
+   * NormalModel. i counts down from result.size() - 1, so the last sample
+   * uses colors[0] with stroke width 3. Earlier samples use later palette
+   * entries, capped at the last color.
+   */
+  @Override
+  public void paint(Graphics g) {
+    super.plotSampleData(g);
+    Graphics2D g2 = (Graphics2D) g;
+
+    Vector dv = new DenseVector(2);
+    int i = result.size() - 1;
+    for (Model<Vector>[] models : result) {
+      g2.setStroke(new BasicStroke(i == 0 ? 3 : 1));
+      g2.setColor(colors[Math.min(colors.length - 1, i--)]);
+      for (Model<Vector> m : models) {
+        NormalModel mm = (NormalModel) m;
+        dv.assign(mm.sd * 3);
+        if (isSignificant(mm)) {
+          plotEllipse(g2, mm.mean, dv);
+        }
+      }
+    }
+  }
+
+  /**
+   * Initialize UncommonDistributions with a fixed seed string, generate the
+   * samples, cluster them and display the result.
+   */
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generateSamples();
+    generateResults();
+    new DisplayNDirichlet();
+  }
+
+  /** Cluster sampleData using a NormalModelDistribution. */
+  static void generateResults() {
+    DisplayDirichlet.generateResults(new NormalModelDistribution());
+  }
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,123 @@
+package org.apache.mahout.clustering.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.BasicStroke;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.NormalModel;
+import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+class DisplayOutputState extends DisplayDirichlet {
+ public DisplayOutputState() {
+ initialize();
+ this.setTitle("Dirichlet Process Clusters - Map/Reduce Results (>"
+ + (int) (significance * 100) + "% of population)");
+ }
+
+ private static final long serialVersionUID = 1L;
+
+ public void paint(Graphics g) {
+ super.plotSampleData(g);
+ Graphics2D g2 = (Graphics2D) g;
+
+ Vector dv = new DenseVector(2);
+ int i = result.size() - 1;
+ for (Model<Vector>[] models : result) {
+ g2.setStroke(new BasicStroke(i == 0 ? 3 : 1));
+ g2.setColor(colors[Math.min(colors.length - 1, i--)]);
+ for (Model<Vector> m : models) {
+ NormalModel mm = (NormalModel) m;
+ dv.assign(mm.sd * 3);
+ if (isSignificant(mm))
+ plotEllipse(g2, mm.mean, dv);
+ }
+ }
+ }
+
+ /**
+ * Return the contents of the given file as a String
+ *
+ * @param fileName
+ * the String name of the file
+ * @return the String contents of the file
+ * @throws IOException
+ * if there is an error
+ */
+ public static List<Vector> readFile(String fileName) throws IOException {
+ BufferedReader r = new BufferedReader(new FileReader(fileName));
+ try {
+ List<Vector> results = new ArrayList<Vector>();
+ String line;
+ while ((line = r.readLine()) != null)
+ results.add(DenseVector.decodeFormat(line));
+ return results;
+ } finally {
+ r.close();
+ }
+ }
+
+ private static void getSamples() throws IOException {
+ File f = new File("input");
+ for (File g : f.listFiles())
+ sampleData.addAll(readFile(g.getCanonicalPath()));
+ }
+
+ private static void getResults() throws IOException {
+ result = new ArrayList<Model<Vector>[]>();
+ JobConf conf = new JobConf(KMeansDriver.class);
+ conf.set(DirichletDriver.MODEL_FACTORY_KEY,
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+ conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
+ conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
+ File f = new File("output");
+ for (File g : f.listFiles()) {
+ conf.set(DirichletDriver.STATE_IN_KEY, g.getCanonicalPath());
+ DirichletState<Vector> dirichletState = DirichletMapper
+ .getDirichletState(conf);
+ result.add(dirichletState.getModels());
+ }
+ }
+
+ public static void main(String[] args) {
+ UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+ try {
+ getSamples();
+ getResults();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ new DisplayOutputState();
+ }
+
+ static void generateResults() {
+ DisplayDirichlet.generateResults(new NormalModelDistribution());
+ }
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java?rev=754797&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java Mon Mar 16 00:17:07 2009
@@ -0,0 +1,67 @@
+package org.apache.mahout.clustering.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.BasicStroke;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.NormalModel;
+import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+class DisplaySNDirichlet extends DisplayDirichlet {
+
+  private static final long serialVersionUID = 1L;
+
+  /**
+   * Initializes the display and sets a window title naming the Sampled Normal
+   * model family and showing {@code significance} as a whole-number percentage.
+   */
+  public DisplaySNDirichlet() {
+    initialize();
+    this.setTitle("Dirichlet Process Clusters - Sampled Normal Distribution (>"
+        + (int) (significance * 100) + "% of population)");
+  }
+
+  /**
+   * Plots the sample data, then outlines every significant model of every
+   * entry in {@code result}. The last entry of {@code result} is drawn with
+   * {@code colors[0]} and a 3-wide stroke; earlier entries use a 1-wide stroke
+   * and successively higher color indices, clamped to the last color.
+   *
+   * @param g the graphics context (cast to Graphics2D)
+   */
+  @Override
+  public void paint(Graphics g) {
+    super.plotSampleData(g);
+    Graphics2D g2 = (Graphics2D) g;
+
+    // size vector handed to plotEllipse; reused across models
+    Vector dv = new DenseVector(2);
+    int i = result.size() - 1;
+    for (Model<Vector>[] models : result) {
+      g2.setStroke(new BasicStroke(i == 0 ? 3 : 1));
+      g2.setColor(colors[Math.min(colors.length - 1, i--)]);
+      for (Model<Vector> m : models) {
+        NormalModel mm = (NormalModel) m;
+        if (isSignificant(mm)) {
+          // only size the ellipse for models that are actually drawn
+          dv.assign(mm.sd * 3);
+          plotEllipse(g2, mm.mean, dv);
+        }
+      }
+    }
+  }
+
+  /**
+   * Initializes UncommonDistributions with a fixed byte string, generates the
+   * samples and results, then opens the display.
+   *
+   * @param args ignored
+   */
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generateSamples();
+    generateResults();
+    new DisplaySNDirichlet();
+  }
+
+  /**
+   * Delegates to {@link DisplayDirichlet#generateResults} with a
+   * {@link SampledNormalDistribution}.
+   */
+  static void generateResults() {
+    DisplayDirichlet.generateResults(new SampledNormalDistribution());
+  }
+}