You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2009/03/19 20:21:45 UTC
svn commit: r756145 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/dirichlet/
core/src/main/java/org/apache/mahout/clustering/dirichlet/models/
examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/
Author: jeastman
Date: Thu Mar 19 19:21:44 2009
New Revision: 756145
URL: http://svn.apache.org/viewvc?rev=756145&view=rev
Log:
- new job and model classes for Dirichlet over the synthetic control data
- removed some unneeded annotations
Added:
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/README.txt
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelDistributionAdapter.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ModelDistribution.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java?rev=756145&r1=756144&r2=756145&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java Thu Mar 19 19:21:44 2009
@@ -57,7 +57,6 @@
return gson.toJson(this, typeOfModel);
}
- @SuppressWarnings("unchecked")
public static DirichletCluster<Vector> fromFormatString(String formatString) {
GsonBuilder builder = new GsonBuilder();
builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java?rev=756145&r1=756144&r2=756145&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java Thu Mar 19 19:21:44 2009
@@ -61,7 +61,6 @@
state = getDirichletState(job);
}
- @SuppressWarnings("unchecked")
public static DirichletState<Vector> getDirichletState(JobConf job) {
String statePath = job.get(DirichletDriver.STATE_IN_KEY);
String modelFactory = job.get(DirichletDriver.MODEL_FACTORY_KEY);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java?rev=756145&r1=756144&r2=756145&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java Thu Mar 19 19:21:44 2009
@@ -34,7 +34,6 @@
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
-@SuppressWarnings("unchecked")
public class JsonModelAdapter implements JsonSerializer<Model<?>>,
JsonDeserializer<Model<?>> {
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelDistributionAdapter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelDistributionAdapter.java?rev=756145&r1=756144&r2=756145&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelDistributionAdapter.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelDistributionAdapter.java Thu Mar 19 19:21:44 2009
@@ -30,7 +30,6 @@
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
-@SuppressWarnings("unchecked")
public class JsonModelDistributionAdapter implements
JsonSerializer<ModelDistribution<?>>, JsonDeserializer<ModelDistribution<?>> {
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ModelDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ModelDistribution.java?rev=756145&r1=756144&r2=756145&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ModelDistribution.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ModelDistribution.java Thu Mar 19 19:21:44 2009
@@ -1,6 +1,5 @@
package org.apache.mahout.clustering.dirichlet.models;
-
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
@@ -21,7 +20,7 @@
/**
* A model distribution allows us to sample a model from its prior distribution.
*/
-public interface ModelDistribution<O> {
+public interface ModelDistribution<Observation> {
/**
* Return a list of models sampled from the prior
@@ -29,7 +28,7 @@
* @param howMany the int number of models to return
* @return a Model<Observation>[] representing what is known apriori
*/
- Model<O>[] sampleFromPrior(int howMany);
+ Model<Observation>[] sampleFromPrior(int howMany);
/**
* Return a list of models sampled from the posterior
@@ -37,6 +36,6 @@
* @param posterior the Model<Observation>[] after observations
* @return a Model<Observation>[] representing what is known apriori
*/
- Model<O>[] sampleFromPosterior(Model<O>[] posterior);
+ Model<Observation>[] sampleFromPosterior(Model<Observation>[] posterior);
}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java?rev=756145&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java Thu Mar 19 19:21:44 2009
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.syntheticcontrol.dirichlet;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.clustering.dirichlet.DirichletCluster;
+import org.apache.mahout.clustering.dirichlet.DirichletDriver;
+import org.apache.mahout.clustering.dirichlet.DirichletJob;
+import org.apache.mahout.clustering.dirichlet.DirichletMapper;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.clustering.syntheticcontrol.canopy.InputDriver;
+import org.apache.mahout.matrix.Vector;
+
+public class Job {
+  private Job() {
+  }
+
+  /**
+   * Runs the job from seven command-line arguments (input, output, modelFactory, numClusters,
+   * maxIterations, alpha_0, numReducers); any other argument count runs the built-in defaults.
+   */
+  public static void main(String[] args) throws IOException,
+      ClassNotFoundException, InstantiationException, IllegalAccessException {
+    if (args.length == 7) {
+      String input = args[0];
+      String output = args[1];
+      String modelFactory = args[2];
+      int numClusters = Integer.parseInt(args[3]);
+      int maxIterations = Integer.parseInt(args[4]);
+      double alpha_0 = Double.parseDouble(args[5]);
+      int numReducers = Integer.parseInt(args[6]);
+      runJob(input, output, modelFactory, numClusters, maxIterations, alpha_0, numReducers);
+    } else {
+      runJob("testdata", "output",
+          "org.apache.mahout.clustering.syntheticcontrol.dirichlet.NormalScModelDistribution",
+          10, 5, 1.0, 1);
+    }
+  }
+
+  /**
+   * Run the job using supplied arguments, deleting the output directory if it
+   * exists beforehand
+   *
+   * @param input the directory pathname for input points
+   * @param output the directory pathname for output points
+   * @param modelFactory the ModelDistribution class name
+   * @param numModels the number of Models
+   * @param maxIterations the maximum number of iterations
+   * @param alpha_0 the alpha0 value for the DirichletDistribution
+   * @param numReducers the desired number of reducers
+   * @throws IOException if a file system operation or one of the drivers fails with it
+   * @throws IllegalAccessException declared here; presumably propagated from the drivers
+   * @throws InstantiationException declared here; presumably propagated from the drivers
+   * @throws ClassNotFoundException declared here; presumably propagated from the drivers
+   */
+  public static void runJob(String input, String output, String modelFactory,
+      int numModels, int maxIterations, double alpha_0, int numReducers)
+      throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {
+    // delete the output directory
+    JobConf conf = new JobConf(DirichletJob.class);
+    Path outPath = new Path(output);
+    FileSystem fs = FileSystem.get(conf);
+    if (fs.exists(outPath)) {
+      fs.delete(outPath, true);
+    }
+    fs.mkdirs(outPath);
+    InputDriver.runJob(input, output + "/data");
+    DirichletDriver.runJob(output + "/data", output + "/state", modelFactory,
+        numModels, maxIterations, alpha_0, numReducers);
+    printResults(output + "/state", modelFactory, maxIterations, numModels, alpha_0);
+  }
+
+  /**
+   * Loads the state stored at output + "/state-" + i for each iteration i and prints its clusters
+   * @param output the String directory holding the per-iteration state
+   * @param modelDistribution the String class name of the ModelDistribution
+   * @param numIterations the int number of Iterations
+   * @param numModels the int number of models
+   * @param alpha_0 the double alpha_0 value
+   */
+  public static void printResults(String output, String modelDistribution,
+      int numIterations, int numModels, double alpha_0) {
+    List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
+    // getDirichletState takes its inputs from the JobConf, so describe the run there
+    JobConf conf = new JobConf(DirichletJob.class);
+    conf.set(DirichletDriver.MODEL_FACTORY_KEY, modelDistribution);
+    conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(numModels));
+    conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(alpha_0));
+    for (int i = 0; i < numIterations; i++) {
+      conf.set(DirichletDriver.STATE_IN_KEY, output + "/state-" + i);
+      clusters.add(DirichletMapper.getDirichletState(conf).clusters);
+    }
+    printResults(clusters, 0);
+  }
+
+  /**
+   * Prints one "sample[row]= " line per iteration, listing each model whose count() exceeds significant
+   * @param clusters a List of Lists of DirichletClusters
+   * @param significant the minimum number of samples to enable printing a model
+   */
+  private static void printResults(List<List<DirichletCluster<Vector>>> clusters, int significant) {
+    int row = 0;
+    for (List<DirichletCluster<Vector>> r : clusters) {
+      System.out.print("sample[" + row++ + "]= ");
+      for (int k = 0; k < r.size(); k++) {
+        DirichletCluster<Vector> cluster = r.get(k);
+        Model<Vector> model = cluster.model;
+        if (model.count() > significant) {
+          int total = Double.valueOf(cluster.totalCount).intValue(); // truncated for display
+          System.out.print("m" + k + "(" + total + ")" + model.toString() + ", ");
+        }
+      }
+      System.out.println();
+    }
+    System.out.println();
+  }
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java?rev=756145&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java Thu Mar 19 19:21:44 2009
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.syntheticcontrol.dirichlet;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.matrix.Vector;
+
+public class NormalScModel implements Model<Vector> {
+
+  private static final double sqrt2pi = Math.sqrt(2.0 * Math.PI);
+  private static final int DIMENSIONS = 60; // vector size this model is hard-wired to
+  // floor for sd: unlike Double.MIN_VALUE, its square does not underflow to 0.0
+  private static final double MIN_SD = Math.sqrt(Double.MIN_NORMAL);
+
+  // the parameters
+  public Vector mean;
+  public double sd;
+
+  // the observation statistics, initialized by the first observation
+  int s0 = 0; // number of observations
+  Vector s1; // running sum of the observations
+  Vector s2; // running sum of x.times(x) over the observations
+
+  public NormalScModel() {
+  }
+
+  public NormalScModel(Vector mean, double sd) {
+    this.mean = mean;
+    this.sd = sd;
+    this.s0 = 0;
+    this.s1 = mean.like();
+    this.s2 = mean.like();
+  }
+
+  /**
+   * Return a new instance with the same parameters and a zero observation count
+   * @return a NormalScModel built from this model's mean and sd
+   */
+  NormalScModel sample() {
+    return new NormalScModel(mean, sd);
+  }
+
+  @Override
+  public void observe(Vector x) {
+    s0++;
+    s1 = (s1 == null) ? x : s1.plus(x); // sums stay null until first set, e.g. after the no-arg ctor
+    Vector xx = x.times(x);
+    s2 = (s2 == null) ? xx : s2.plus(xx);
+  }
+
+  @Override
+  public void computeParameters() {
+    if (s0 == 0) {
+      return; // nothing observed: leave the parameters untouched
+    }
+    mean = s1.divide(s0);
+    //TODO: is this the average of the 60 component stds??
+    double estimate = 0.0;
+    if (s0 > 1) {
+      // clamp at zero so a negative sum (e.g. from rounding) cannot make sqrt return NaN
+      double sum = Math.max(0.0, s2.times(s0).minus(s1.times(s1)).zSum());
+      estimate = Math.sqrt(sum / (DIMENSIONS * DIMENSIONS)) / s0;
+    }
+    // an sd whose square is 0.0 makes pdf() compute 0/0 (NaN) at x == mean, so floor it
+    sd = Math.max(estimate, MIN_SD);
+  }
+
+  @Override
+  // TODO: need to revisit this for reasonableness
+  public double pdf(Vector x) {
+    // NOTE(review): toString() sizes the vector with cardinality(); confirm size() is intended
+    assert x.size() == DIMENSIONS;
+    double sd2 = sd * sd;
+    double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
+    return Math.exp(exp) / (sd * sqrt2pi);
+  }
+
+  @Override
+  public int count() {
+    return s0;
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder buf = new StringBuilder("nm{n=").append(s0).append(" m=[");
+    if (mean != null) {
+      for (int i = 0; i < mean.cardinality(); i++) {
+        buf.append(String.format("%.2f", mean.get(i))).append(", ");
+      }
+    }
+    return buf.append("] sd=").append(String.format("%.2f", sd)).append('}').toString();
+  }
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java?rev=756145&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java Thu Mar 19 19:21:44 2009
@@ -0,0 +1,53 @@
+package org.apache.mahout.clustering.syntheticcontrol.dirichlet;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+/**
+ * A ModelDistribution implementation suitable for testing the DirichletCluster algorithm.
+ * Prior models are NormalScModels with sd 1 whose 60 mean components come from rNorm(30, 0.5).
+ */
+public class NormalScModelDistribution implements ModelDistribution<Vector> {
+
+  @Override
+  public Model<Vector>[] sampleFromPrior(int howMany) {
+    Model<Vector>[] models = new NormalScModel[howMany];
+    for (int m = 0; m < howMany; m++) {
+      DenseVector center = new DenseVector(60);
+      for (int d = 0; d < 60; d++) {
+        center.set(d, UncommonDistributions.rNorm(30, 0.5));
+      }
+      models[m] = new NormalScModel(center, 1);
+    }
+    return models;
+  }
+
+  @Override
+  public Model<Vector>[] sampleFromPosterior(Model<Vector>[] posterior) {
+    Model<Vector>[] samples = new NormalScModel[posterior.length];
+    for (int m = 0; m < posterior.length; m++) {
+      samples[m] = ((NormalScModel) posterior[m]).sample(); // every entry must be a NormalScModel
+    }
+    return samples;
+  }
+}
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/README.txt
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/README.txt?rev=756145&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/README.txt (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/README.txt Thu Mar 19 19:21:44 2009
@@ -0,0 +1,3 @@
+This ModelDistribution and Model are tentative and do not yet yield easily explainable results. Readers
+are encouraged to experiment with the prior model parameters and with the pdf and standard deviation
+calculations in order to obtain results that do not converge so rapidly on the m0 cluster.
\ No newline at end of file