You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/08/20 23:56:17 UTC
svn commit: r987647 [1/2] - in /mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/
core/src/main/java/org/apache/mahout/clustering/dirichlet/
core/src/main/java/org/apache/mahout/clustering/dirichlet/models/
core/src/main/java/org/apache/ma...
Author: jeastman
Date: Fri Aug 20 21:56:16 2010
New Revision: 987647
URL: http://svn.apache.org/viewvc?rev=987647&view=rev
Log:
MAHOUT-479: Added unit tests for VectorModelClassifier and ModelDistribution serialization, and tests ensuring that
GaussianClusterDistribution and DistanceMeasureClusterDistribution work with Dirichlet. Refactored the model
distribution arguments so that Java developers can provide fully configured model distributions instead of multiple
string parameters. Added a distance measure parameter to Dirichlet for use with DistanceMeasureClusterDistribution.
All unit tests pass.
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java
- copied, changed from r987240, mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
Removed:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java?rev=987647&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java Fri Aug 20 21:56:16 2010
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.lang.reflect.Type;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+
+/**
+ * Gson adapter that (de)serializes a {@link DistanceMeasure} polymorphically. The JSON form is an object whose
+ * "class" member holds the concrete class name and whose "model" member holds, as a string, the measure's own
+ * JSON as produced by a Gson instance that knows how to handle {@link Vector} fields.
+ */
+public class JsonDistanceMeasureAdapter implements JsonSerializer<DistanceMeasure>, JsonDeserializer<DistanceMeasure> {
+
+  @Override
+  public JsonElement serialize(DistanceMeasure src, Type typeOfSrc, JsonSerializationContext context) {
+    JsonObject obj = new JsonObject();
+    obj.add("class", new JsonPrimitive(src.getClass().getName()));
+    obj.add("model", new JsonPrimitive(buildGson().toJson(src)));
+    return obj;
+  }
+
+  /**
+   * Rebuild a DistanceMeasure from the {"class": ..., "model": ...} form written by {@link #serialize}.
+   * The concrete class is loaded through the current thread's context class loader.
+   *
+   * @throws com.google.gson.JsonParseException if a member is missing or the named class cannot be loaded
+   */
+  @Override
+  public DistanceMeasure deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) {
+    JsonObject obj = json.getAsJsonObject();
+    String klass = requiredString(obj, "class");
+    String model = requiredString(obj, "model");
+    Class<?> cl;
+    try {
+      cl = Thread.currentThread().getContextClassLoader().loadClass(klass);
+    } catch (ClassNotFoundException e) {
+      // Without the concrete class Gson cannot rebuild the measure; fail here with the cause attached
+      // instead of passing a null class on to Gson.
+      throw new com.google.gson.JsonParseException("Cannot load DistanceMeasure class " + klass, e);
+    }
+    return (DistanceMeasure) buildGson().fromJson(model, cl);
+  }
+
+  /** Create the Gson used for the wrapped measure's own state (Vector fields need a custom adapter). */
+  private static Gson buildGson() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    return builder.create();
+  }
+
+  /** Return the string value of member {@code name}, failing with a descriptive error when it is absent. */
+  private static String requiredString(JsonObject obj, String name) {
+    JsonElement element = obj.get(name);
+    if (element == null) {
+      throw new com.google.gson.JsonParseException("Missing \"" + name + "\" member in " + obj);
+    }
+    return element.getAsString();
+  }
+}
Copied: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java (from r987240, mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java&r1=987240&r2=987647&rev=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java Fri Aug 20 21:56:16 2010
@@ -14,11 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.mahout.clustering.dirichlet;
+package org.apache.mahout.clustering;
import java.lang.reflect.Type;
-import org.apache.mahout.clustering.Model;
import org.apache.mahout.math.JsonVectorAdapter;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java?rev=987647&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java Fri Aug 20 21:56:16 2010
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.lang.reflect.Type;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+
+/**
+ * Gson adapter that (de)serializes a {@link ModelDistribution} polymorphically. The JSON form is an object whose
+ * "class" member holds the concrete class name and whose "model" member holds, as a string, the distribution's
+ * own JSON as produced by a Gson instance that knows how to handle {@link Vector} and {@link DistanceMeasure}
+ * fields.
+ */
+public class JsonModelDistributionAdapter implements JsonSerializer<ModelDistribution<?>>, JsonDeserializer<ModelDistribution<?>> {
+
+  @Override
+  public JsonElement serialize(ModelDistribution<?> src, Type typeOfSrc, JsonSerializationContext context) {
+    JsonObject obj = new JsonObject();
+    obj.add("class", new JsonPrimitive(src.getClass().getName()));
+    obj.add("model", new JsonPrimitive(buildGson().toJson(src)));
+    return obj;
+  }
+
+  /**
+   * Rebuild a ModelDistribution from the {"class": ..., "model": ...} form written by {@link #serialize}.
+   * The concrete class is loaded through the current thread's context class loader.
+   *
+   * @throws com.google.gson.JsonParseException if a member is missing or the named class cannot be loaded
+   */
+  @Override
+  public ModelDistribution<?> deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) {
+    JsonObject obj = json.getAsJsonObject();
+    String klass = requiredString(obj, "class");
+    String model = requiredString(obj, "model");
+    Class<?> cl;
+    try {
+      cl = Thread.currentThread().getContextClassLoader().loadClass(klass);
+    } catch (ClassNotFoundException e) {
+      // Without the concrete class Gson cannot rebuild the distribution; fail here with the cause attached
+      // instead of passing a null class on to Gson.
+      throw new com.google.gson.JsonParseException("Cannot load ModelDistribution class " + klass, e);
+    }
+    return (ModelDistribution<?>) buildGson().fromJson(model, cl);
+  }
+
+  /** Create the Gson used for the wrapped distribution's own state (Vector and DistanceMeasure fields). */
+  private static Gson buildGson() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    builder.registerTypeAdapter(DistanceMeasure.class, new JsonDistanceMeasureAdapter());
+    return builder.create();
+  }
+
+  /** Return the string value of member {@code name}, failing with a descriptive error when it is absent. */
+  private static String requiredString(JsonObject obj, String name) {
+    JsonElement element = obj.get(name);
+    if (element == null) {
+      throw new com.google.gson.JsonParseException("Missing \"" + name + "\" member in " + obj);
+    }
+    return element.getAsString();
+  }
+}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java Fri Aug 20 21:56:16 2010
@@ -18,6 +18,7 @@
package org.apache.mahout.clustering;
+
/** A model distribution allows us to sample a model from its prior distribution. */
public interface ModelDistribution<O> {
@@ -39,4 +40,10 @@ public interface ModelDistribution<O> {
*/
Model<O>[] sampleFromPosterior(Model<O>[] posterior);
+ /**
+ * Return a JSON string representing the receiver. Needed to pass persistent state.
+ * @return a String
+ */
+ String asJsonString();
+
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java Fri Aug 20 21:56:16 2010
@@ -1,8 +1,11 @@
package org.apache.mahout.clustering;
+import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -20,11 +23,22 @@ public class VectorModelClassifier exten
@Override
public Vector classify(Vector instance) {
Vector pdfs = new DenseVector(models.size());
- int i = 0;
- for (Model<VectorWritable> model : models) {
- pdfs.set(i++, model.pdf(new VectorWritable(instance)));
+ if (models.get(0) instanceof SoftCluster) {
+ List<SoftCluster> clusters = new ArrayList<SoftCluster>();
+ List<Double> distances = new ArrayList<Double>();
+ for (Model<VectorWritable> model : models) {
+ SoftCluster sc = (SoftCluster) model;
+ clusters.add(sc);
+ distances.add(sc.getMeasure().distance(instance, sc.getCenter()));
+ }
+ return new FuzzyKMeansClusterer().computePi(clusters, distances);
+ } else {
+ int i = 0;
+ for (Model<VectorWritable> model : models) {
+ pdfs.set(i++, model.pdf(new VectorWritable(instance)));
+ }
+ return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
}
- return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
}
@Override
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Fri Aug 20 21:56:16 2010
@@ -36,13 +36,16 @@ import org.apache.hadoop.mapreduce.lib.i
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
import org.apache.mahout.clustering.kmeans.OutputLogFilter;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -53,7 +56,7 @@ public class DirichletDriver extends Abs
public static final String STATE_IN_KEY = "org.apache.mahout.clustering.dirichlet.stateIn";
- public static final String MODEL_FACTORY_KEY = "org.apache.mahout.clustering.dirichlet.modelFactory";
+ public static final String MODEL_DISTRIBUTION_KEY = "org.apache.mahout.clustering.dirichlet.modelFactory";
public static final String MODEL_PROTOTYPE_KEY = "org.apache.mahout.clustering.dirichlet.modelPrototype";
@@ -97,6 +100,7 @@ public class DirichletDriver extends Abs
"mp",
"The ModelDistribution prototype Vector class name. Defaults to RandomAccessSparseVector",
RandomAccessSparseVector.class.getName());
+ addOption(DefaultOptionCreator.distanceMeasureOption().withRequired(false).create());
addOption(DefaultOptionCreator.emitMostLikelyOption().create());
addOption(DefaultOptionCreator.thresholdOption().create());
addOption(DefaultOptionCreator.numReducersOption().create());
@@ -113,6 +117,7 @@ public class DirichletDriver extends Abs
}
String modelFactory = getOption(MODEL_DISTRIBUTION_CLASS_OPTION);
String modelPrototype = getOption(MODEL_PROTOTYPE_CLASS_OPTION);
+ String distanceMeasure = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
int numModels = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
int numReducers = Integer.parseInt(getOption(DefaultOptionCreator.MAX_REDUCERS_OPTION));
int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
@@ -121,11 +126,16 @@ public class DirichletDriver extends Abs
double alpha0 = Double.parseDouble(getOption(ALPHA_OPTION));
boolean runClustering = hasOption(DefaultOptionCreator.CLUSTERING_OPTION);
boolean runSequential = (getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(DefaultOptionCreator.SEQUENTIAL_METHOD));
+ int prototypeSize = readPrototypeSize(input);
+
+ AbstractVectorModelDistribution modelDistribution = createModelDistribution(modelFactory,
+ modelPrototype,
+ distanceMeasure,
+ prototypeSize);
job(input,
output,
- modelFactory,
- modelPrototype,
+ modelDistribution,
numModels,
maxIterations,
alpha0,
@@ -138,15 +148,48 @@ public class DirichletDriver extends Abs
}
/**
+ * Create an instance of AbstractVectorModelDistribution from the given command line arguments
+ * @param modelFactory
+ * @param modelPrototype
+ * @param distanceMeasure
+ * @param prototypeSize
+ * @return
+ * @throws ClassNotFoundException
+ * @throws InstantiationException
+ * @throws IllegalAccessException
+ * @throws NoSuchMethodException
+ * @throws InvocationTargetException
+ */
+ public static AbstractVectorModelDistribution createModelDistribution(String modelFactory,
+ String modelPrototype,
+ String distanceMeasure,
+ int prototypeSize) throws ClassNotFoundException,
+ InstantiationException, IllegalAccessException, NoSuchMethodException, InvocationTargetException {
+ ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+ Class<? extends AbstractVectorModelDistribution> cl = ccl.loadClass(modelFactory)
+ .asSubclass(AbstractVectorModelDistribution.class);
+ AbstractVectorModelDistribution modelDistribution = cl.newInstance();
+
+ Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
+ Constructor<? extends Vector> v = vcl.getConstructor(int.class);
+ modelDistribution.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
+
+ if (modelDistribution instanceof DistanceMeasureClusterDistribution) {
+ Class<? extends DistanceMeasure> measureCl = ccl.loadClass(distanceMeasure).asSubclass(DistanceMeasure.class);
+ DistanceMeasure measure = measureCl.newInstance();
+ ((DistanceMeasureClusterDistribution) modelDistribution).setMeasure(measure);
+ }
+ return modelDistribution;
+ }
+
+ /**
* Run the job using supplied arguments on a new driver instance (convenience)
*
* @param input
* the directory pathname for input points
* @param output
* the directory pathname for output points
- * @param modelFactory
- * the String ModelDistribution class name to use
- * @param modelPrototype
+ * @param modelDistribution
* the String class name of the model prototype
* @param numClusters
* the number of models
@@ -166,8 +209,7 @@ public class DirichletDriver extends Abs
*/
public static void runJob(Path input,
Path output,
- String modelFactory,
- String modelPrototype,
+ ModelDistribution<VectorWritable> modelDistribution,
int numClusters,
int maxIterations,
double alpha0,
@@ -180,8 +222,7 @@ public class DirichletDriver extends Abs
new DirichletDriver().job(input,
output,
- modelFactory,
- modelPrototype,
+ modelDistribution,
numClusters,
maxIterations,
alpha0,
@@ -196,37 +237,21 @@ public class DirichletDriver extends Abs
* Creates a DirichletState object from the given arguments. Note that the modelFactory is presumed to be a
* subclass of VectorModelDistribution that can be initialized with a concrete Vector prototype.
*
- * @param modelFactory
- * a String which is the class name of the model factory
- * @param modelPrototype
- * a String which is the class name of the Vector used to initialize the factory
- * @param prototypeSize
- * an int number of dimensions of the model prototype vector
- * @param numModels
- * an int number of models to be created
- * @param alpha0
- * the double alpha_0 argument to the algorithm
+ * @param modelDistribution the ModelDistribution
+ * @param numModels an int number of models to be created
+ * @param alpha0 the double alpha_0 argument to the algorithm
* @return an initialized DirichletState
*/
- static DirichletState createState(String modelFactory, String modelPrototype, int prototypeSize, int numModels, double alpha0)
+ static DirichletState createState(ModelDistribution<VectorWritable> modelDistribution, int numModels, double alpha0)
throws ClassNotFoundException, InstantiationException, IllegalAccessException, SecurityException, NoSuchMethodException,
IllegalArgumentException, InvocationTargetException {
-
- ClassLoader ccl = Thread.currentThread().getContextClassLoader();
- Class<? extends AbstractVectorModelDistribution> cl = ccl.loadClass(modelFactory)
- .asSubclass(AbstractVectorModelDistribution.class);
- AbstractVectorModelDistribution factory = cl.newInstance();
-
- Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
- Constructor<? extends Vector> v = vcl.getConstructor(int.class);
- factory.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
- return new DirichletState(factory, numModels, alpha0);
+ return new DirichletState(modelDistribution, numModels, alpha0);
}
/**
* Read the first input vector to determine the prototype size for the modelPrototype
*/
- private int readPrototypeSize(Path input) throws IOException, InstantiationException, IllegalAccessException {
+ public static int readPrototypeSize(Path input) throws IOException, InstantiationException, IllegalAccessException {
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(input.toUri(), conf);
FileStatus[] status = fs.listStatus(input, new OutputLogFilter());
@@ -248,22 +273,18 @@ public class DirichletDriver extends Abs
* Write initial state (prior distribution) to the output path directory
* @param output the output Path
* @param stateOut the state output Path
- * @param modelFactory the String class name of the modelFactory
- * @param modelPrototype the String class name of the modelPrototype
- * @param prototypeSize the int size of the modelPrototype vectors
+ * @param modelDistribution the ModelDistribution
* @param numModels the int number of models to generate
* @param alpha0 the double alpha_0 argument to the DirichletDistribution
*/
private void writeInitialState(Path output,
Path stateOut,
- String modelFactory,
- String modelPrototype,
- int prototypeSize,
+ ModelDistribution<VectorWritable> modelDistribution,
int numModels,
double alpha0) throws ClassNotFoundException, InstantiationException, IllegalAccessException,
IOException, SecurityException, NoSuchMethodException, InvocationTargetException {
- DirichletState state = createState(modelFactory, modelPrototype, prototypeSize, numModels, alpha0);
+ DirichletState state = createState(modelDistribution, numModels, alpha0);
writeState(output, stateOut, numModels, state);
}
@@ -281,39 +302,24 @@ public class DirichletDriver extends Abs
/**
* Run an iteration using supplied arguments
*
- * @param input
- * the directory pathname for input points
- * @param stateIn
- * the directory pathname for input state
- * @param stateOut
- * the directory pathname for output state
- * @param modelFactory
- * the class name of the model factory class
- * @param modelPrototype
- * the class name of the model prototype (a Vector implementation)
- * @param prototypeSize
- * the size of the model prototype vector
- * @param numClusters
- * the number of clusters
- * @param alpha0
- * alpha_0
- * @param numReducers
- * the number of Reducers desired
+ * @param input the directory pathname for input points
+ * @param stateIn the directory pathname for input state
+ * @param stateOut the directory pathname for output state
+ * @param modelDistribution the ModelDistribution
+ * @param numClusters the number of clusters
+ * @param alpha0 alpha_0
+ * @param numReducers the number of Reducers desired
*/
private void runIteration(Path input,
Path stateIn,
Path stateOut,
- String modelFactory,
- String modelPrototype,
- int prototypeSize,
+ ModelDistribution<VectorWritable> modelDistribution,
int numClusters,
double alpha0,
int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
Configuration conf = new Configuration();
conf.set(STATE_IN_KEY, stateIn.toString());
- conf.set(MODEL_FACTORY_KEY, modelFactory);
- conf.set(MODEL_PROTOTYPE_KEY, modelPrototype);
- conf.set(PROTOTYPE_SIZE_KEY, Integer.toString(prototypeSize));
+ conf.set(MODEL_DISTRIBUTION_KEY, modelDistribution.asJsonString());
conf.set(NUM_CLUSTERS_KEY, Integer.toString(numClusters));
conf.set(ALPHA_0_KEY, Double.toString(alpha0));
@@ -344,9 +350,7 @@ public class DirichletDriver extends Abs
* the directory Path for input points
* @param output
* the directory Path for output points
- * @param modelFactory
- * the String ModelDistribution class name to use
- * @param modelPrototype
+ * @param modelDistribution
* the String class name of the model's prototype vector
* @param numClusters
* the number of models to iterate over
@@ -366,8 +370,7 @@ public class DirichletDriver extends Abs
*/
public void job(Path input,
Path output,
- String modelFactory,
- String modelPrototype,
+ ModelDistribution<VectorWritable> modelDistribution,
int numClusters,
int maxIterations,
double alpha0,
@@ -379,8 +382,7 @@ public class DirichletDriver extends Abs
ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InterruptedException {
Path clustersOut = buildClusters(input,
output,
- modelFactory,
- modelPrototype,
+ modelDistribution,
numClusters,
maxIterations,
alpha0,
@@ -398,9 +400,7 @@ public class DirichletDriver extends Abs
* the directory Path for input points
* @param output
* the directory Path for output points
- * @param modelFactory
- * the String ModelDistribution class name to use
- * @param modelPrototype
+ * @param modelDistribution
* the String class name of the model's prototype vector
* @param numClusters
* the number of models to iterate over
@@ -415,8 +415,7 @@ public class DirichletDriver extends Abs
*/
public Path buildClusters(Path input,
Path output,
- String modelFactory,
- String modelPrototype,
+ ModelDistribution<VectorWritable> modelDistribution,
int numClusters,
int maxIterations,
double alpha0,
@@ -424,47 +423,24 @@ public class DirichletDriver extends Abs
boolean runSequential) throws IOException, InstantiationException, IllegalAccessException,
ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InterruptedException {
Path clustersIn = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
-
- int protoSize = readPrototypeSize(input);
-
- writeInitialState(output, clustersIn, modelFactory, modelPrototype, protoSize, numClusters, alpha0);
+ writeInitialState(output, clustersIn, modelDistribution, numClusters, alpha0);
if (runSequential) {
- clustersIn = buildClustersSeq(input,
- output,
- modelFactory,
- modelPrototype,
- numClusters,
- maxIterations,
- alpha0,
- numReducers,
- clustersIn,
- protoSize);
+ clustersIn = buildClustersSeq(input, output, modelDistribution, numClusters, maxIterations, alpha0, numReducers, clustersIn);
} else {
- clustersIn = buildClustersMR(input,
- output,
- modelFactory,
- modelPrototype,
- numClusters,
- maxIterations,
- alpha0,
- numReducers,
- clustersIn,
- protoSize);
+ clustersIn = buildClustersMR(input, output, modelDistribution, numClusters, maxIterations, alpha0, numReducers, clustersIn);
}
return clustersIn;
}
private Path buildClustersSeq(Path input,
Path output,
- String modelFactory,
- String modelPrototype,
+ ModelDistribution<VectorWritable> modelDistribution,
int numClusters,
int maxIterations,
double alpha0,
int numReducers,
- Path clustersIn,
- int protoSize) throws IOException, ClassNotFoundException, InstantiationException,
+ Path clustersIn) throws IOException, ClassNotFoundException, InstantiationException,
IllegalAccessException, NoSuchMethodException, InvocationTargetException {
for (int iteration = 1; iteration <= maxIterations; iteration++) {
log.info("Iteration {}", iteration);
@@ -472,10 +448,8 @@ public class DirichletDriver extends Abs
Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
DirichletState state = DirichletMapper.loadState(new Configuration(),
clustersIn.toString(),
- modelFactory,
- modelPrototype,
+ modelDistribution,
alpha0,
- protoSize,
numClusters);
Cluster[] newModels = (Cluster[]) state.getModelFactory().sampleFromPosterior(state.getModels());
DirichletClusterer clusterer = new DirichletClusterer(state);
@@ -506,19 +480,17 @@ public class DirichletDriver extends Abs
private Path buildClustersMR(Path input,
Path output,
- String modelFactory,
- String modelPrototype,
+ ModelDistribution<VectorWritable> modelDistribution,
int numClusters,
int maxIterations,
double alpha0,
int numReducers,
- Path clustersIn,
- int protoSize) throws IOException, InterruptedException, ClassNotFoundException {
+ Path clustersIn) throws IOException, InterruptedException, ClassNotFoundException {
for (int iteration = 1; iteration <= maxIterations; iteration++) {
log.info("Iteration {}", iteration);
// point the output to a new directory per iteration
Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
- runIteration(input, clustersIn, clustersOut, modelFactory, modelPrototype, protoSize, numClusters, alpha0, numReducers);
+ runIteration(input, clustersIn, clustersOut, modelDistribution, numClusters, alpha0, numReducers);
// now point the input to the old output directory
clustersIn = clustersOut;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java Fri Aug 20 21:56:16 2010
@@ -29,17 +29,22 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.JsonModelDistributionAdapter;
+import org.apache.mahout.clustering.ModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
import org.apache.mahout.clustering.kmeans.OutputLogFilter;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+
public class DirichletMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
private DirichletClusterer clusterer;
@Override
- protected void map(WritableComparable<?> key, VectorWritable v, Context context)
- throws IOException, InterruptedException {
+ protected void map(WritableComparable<?> key, VectorWritable v, Context context) throws IOException, InterruptedException {
int k = clusterer.assignToModel(v);
context.write(new Text(String.valueOf(k)), v);
}
@@ -72,23 +77,19 @@ public class DirichletMapper extends Map
this.clusterer = new DirichletClusterer(state);
}
- public static DirichletState getDirichletState(Configuration conf) throws NoSuchMethodException,
- InvocationTargetException {
+ public static DirichletState getDirichletState(Configuration conf) throws NoSuchMethodException, InvocationTargetException {
String statePath = conf.get(DirichletDriver.STATE_IN_KEY);
- String modelFactory = conf.get(DirichletDriver.MODEL_FACTORY_KEY);
- String modelPrototype = conf.get(DirichletDriver.MODEL_PROTOTYPE_KEY);
- String prototypeSize = conf.get(DirichletDriver.PROTOTYPE_SIZE_KEY);
+ String json = conf.get(DirichletDriver.MODEL_DISTRIBUTION_KEY);
+ GsonBuilder builder = new GsonBuilder();
+ builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
+ Gson gson = builder.create();
+ ModelDistribution<VectorWritable> modelDistribution = gson.fromJson(json,
+ AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
String numClusters = conf.get(DirichletDriver.NUM_CLUSTERS_KEY);
String alpha0 = conf.get(DirichletDriver.ALPHA_0_KEY);
try {
- return loadState(conf,
- statePath,
- modelFactory,
- modelPrototype,
- Double.parseDouble(alpha0),
- Integer.parseInt(prototypeSize),
- Integer.parseInt(numClusters));
+ return loadState(conf, statePath, modelDistribution, Double.parseDouble(alpha0), Integer.parseInt(numClusters));
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
} catch (InstantiationException e) {
@@ -101,15 +102,12 @@ public class DirichletMapper extends Map
}
protected static DirichletState loadState(Configuration conf,
- String statePath,
- String modelFactory,
- String modelPrototype,
- double alpha,
- int pSize,
- int k)
- throws ClassNotFoundException, InstantiationException, IllegalAccessException,
+ String statePath,
+ ModelDistribution<VectorWritable> modelDistribution,
+ double alpha,
+ int k) throws ClassNotFoundException, InstantiationException, IllegalAccessException,
NoSuchMethodException, InvocationTargetException, IOException {
- DirichletState state = DirichletDriver.createState(modelFactory, modelPrototype, pSize, k, alpha);
+ DirichletState state = DirichletDriver.createState(modelDistribution, k, alpha);
Path path = new Path(statePath);
FileSystem fs = FileSystem.get(path.toUri(), conf);
FileStatus[] status = fs.listStatus(path, new OutputLogFilter());
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java Fri Aug 20 21:56:16 2010
@@ -16,29 +16,53 @@
package org.apache.mahout.clustering.dirichlet.models;
+import java.lang.reflect.Type;
+
+import org.apache.mahout.clustering.JsonDistanceMeasureAdapter;
+import org.apache.mahout.clustering.JsonModelDistributionAdapter;
import org.apache.mahout.clustering.ModelDistribution;
+import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.VectorWritable;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
+
public abstract class AbstractVectorModelDistribution implements ModelDistribution<VectorWritable> {
+ public static final Type MODEL_DISTRIBUTION_TYPE = new TypeToken<ModelDistribution<VectorWritable>>() {
+ }.getType();
+
// a prototype instance used for creating prior model distributions using like(). It
// should be of the class and cardinality desired for the particular application.
private VectorWritable modelPrototype;
protected AbstractVectorModelDistribution() {
}
-
+
protected AbstractVectorModelDistribution(VectorWritable modelPrototype) {
this.modelPrototype = modelPrototype;
}
+ /**
+ * Serializes this distribution to JSON with Gson, registering the ModelDistribution and DistanceMeasure adapters.
+ */
+ @Override
+ public String asJsonString() {
+ GsonBuilder builder = new GsonBuilder();
+ builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
+ builder.registerTypeAdapter(DistanceMeasure.class, new JsonDistanceMeasureAdapter());
+ Gson gson = builder.create();
+ return gson.toJson(this, MODEL_DISTRIBUTION_TYPE);
+ }
+
/**
* @return the modelPrototype
*/
public VectorWritable getModelPrototype() {
return modelPrototype;
}
-
+
/**
* @param modelPrototype
* the modelPrototype to set
@@ -46,5 +70,5 @@ public abstract class AbstractVectorMode
public void setModelPrototype(VectorWritable modelPrototype) {
this.modelPrototype = modelPrototype;
}
-
+
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Fri Aug 20 21:56:16 2010
@@ -24,8 +24,8 @@ import java.lang.reflect.Type;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.JsonModelAdapter;
import org.apache.mahout.clustering.Model;
-import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.SquareRootFunction;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java Fri Aug 20 21:56:16 2010
@@ -19,24 +19,30 @@ package org.apache.mahout.clustering.dir
import org.apache.mahout.clustering.DistanceMeasureCluster;
import org.apache.mahout.clustering.Model;
+import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
/**
* An implementation of the ModelDistribution interface suitable for testing the DirichletCluster algorithm.
- * Uses a Normal Distribution to sample the prior model values. Model values have a vector standard deviation,
- * allowing assymetrical regions to be covered by a model.
+ * Models use a DistanceMeasure to calculate pdf values.
*/
public class DistanceMeasureClusterDistribution extends AbstractVectorModelDistribution {
- ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
+ DistanceMeasure measure = new ManhattanDistanceMeasure(); // default so no-arg-constructed instances are never left without a measure
public DistanceMeasureClusterDistribution() {
}
public DistanceMeasureClusterDistribution(VectorWritable modelPrototype) {
super(modelPrototype);
+ this.measure = new ManhattanDistanceMeasure();
+ }
+
+ public DistanceMeasureClusterDistribution(VectorWritable modelPrototype, DistanceMeasure measure) {
+ super(modelPrototype);
+ this.measure = measure;
}
@Override
@@ -59,4 +65,12 @@ public class DistanceMeasureClusterDistr
return result;
}
+ /** @return the DistanceMeasure stored in this distribution */
+ public DistanceMeasure getMeasure() { return measure; }
+
+ /** @param measure the DistanceMeasure to store in this distribution */
+ public void setMeasure(DistanceMeasure measure) {
+ this.measure = measure;
+ }
+
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java Fri Aug 20 21:56:16 2010
@@ -23,8 +23,8 @@ import java.lang.reflect.Type;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.JsonModelAdapter;
import org.apache.mahout.clustering.Model;
-import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.Vector;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Fri Aug 20 21:56:16 2010
@@ -25,8 +25,8 @@ import java.util.Locale;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.JsonModelAdapter;
import org.apache.mahout.clustering.Model;
-import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.SquareRootFunction;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java Fri Aug 20 21:56:16 2010
@@ -62,6 +62,10 @@ public class FuzzyKMeansClusterer {
this.configure(conf);
}
+ public FuzzyKMeansClusterer() {
+ // Unconfigured instance: unlike the Configuration-taking constructor, configure(conf) is not called here.
+ }
+
/**
* This is the reference k-means implementation. Given its inputs it iterates over the points and clusters
* until their centers converge or until the maximum number of iterations is exceeded.
@@ -222,11 +226,7 @@ public class FuzzyKMeansClusterer {
clusterDistanceList.add(getMeasure().distance(cluster.getCenter(), point.get()));
}
// calculate point pdf for all clusters
- Vector pi = new DenseVector(clusters.size());
- for (int i = 0; i < clusters.size(); i++) {
- double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
- pi.set(i, probWeight);
- }
+ Vector pi = computePi(clusters, clusterDistanceList);
if (emitMostLikely) {
emitMostLikelyCluster(point.get(), clusters, pi, context);
} else {
@@ -234,6 +234,15 @@ public class FuzzyKMeansClusterer {
}
}
+ /** Returns the point's probability weight for each cluster, computed by computeProbWeight from its distances to all clusters. */
+ public Vector computePi(List<SoftCluster> clusters, List<Double> clusterDistanceList) {
+ Vector pi = new DenseVector(clusters.size());
+ for (int c = 0; c < clusters.size(); c++) {
+ pi.set(c, computeProbWeight(clusterDistanceList.get(c), clusterDistanceList));
+ }
+ return pi;
+ }
+
/**
* Emit the point to the cluster with the highest pdf
*/
@@ -302,12 +311,7 @@ public class FuzzyKMeansClusterer {
for (SoftCluster cluster : clusters) {
clusterDistanceList.add(getMeasure().distance(cluster.getCenter(), point.get()));
}
- // calculate point pdf for all clusters
- Vector pi = new DenseVector(clusters.size());
- for (int i = 0; i < clusters.size(); i++) {
- double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
- pi.set(i, probWeight);
- }
+ Vector pi = computePi(clusters, clusterDistanceList);
if (emitMostLikely) {
emitMostLikelyCluster(point.get(), clusters, pi, writer);
} else {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java Fri Aug 20 21:56:16 2010
@@ -17,9 +17,11 @@
package org.apache.mahout.clustering.fuzzykmeans;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.mahout.clustering.kmeans.Cluster;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
public class SoftCluster extends Cluster {
@@ -46,4 +48,13 @@ public class SoftCluster extends Cluster
public String getIdentifier() {
return (isConverged() ? "SV-" : "SC-") + getId();
}
+
+ /**
+ * Unsupported: a SoftCluster pdf cannot be calculated out of context (see FuzzyKMeansClusterer); always throws NotImplementedException.
+ */
+ @Override
+ public double pdf(VectorWritable vw) {
+ // SoftCluster pdf cannot be calculated out of context. See FuzzyKMeansClusterer
+ throw new NotImplementedException();
+ }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Fri Aug 20 21:56:16 2010
@@ -22,10 +22,12 @@ import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.mahout.clustering.kmeans.Cluster;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.JsonVectorAdapter;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.list.IntArrayList;
import com.google.gson.Gson;
@@ -156,4 +158,13 @@ public class MeanShiftCanopy extends Clu
return (isConverged() ? "MSV-" : "MSC-") + getId();
}
+ /**
+ * Unsupported: canopy membership is explicit, so no pdf exists for an arbitrary point; always throws NotImplementedException.
+ */
+ @Override
+ public double pdf(VectorWritable vw) {
+ // MSCanopy membership is explicit via membership in boundPoints. Can't compute pdf for an arbitrary point
+ throw new NotImplementedException();
+ }
+
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java Fri Aug 20 21:56:16 2010
@@ -22,7 +22,6 @@ import java.lang.reflect.Type;
import org.apache.mahout.clustering.canopy.Canopy;
import org.apache.mahout.clustering.dirichlet.DirichletCluster;
import org.apache.mahout.clustering.dirichlet.JsonClusterModelAdapter;
-import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
import org.apache.mahout.clustering.dirichlet.models.L1Model;
import org.apache.mahout.clustering.dirichlet.models.NormalModel;
Added: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java?rev=987647&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java Fri Aug 20 21:56:16 2010
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
+import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+
+public class TestModelDistributionSerialization extends MahoutTestCase {
+
+ /**
+ * Creates a Gson able to (de)serialize ModelDistributions and the DistanceMeasures they hold.
+ */
+ private static Gson createGson() {
+ GsonBuilder builder = new GsonBuilder();
+ builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
+ builder.registerTypeAdapter(DistanceMeasure.class, new JsonDistanceMeasureAdapter());
+ return builder.create();
+ }
+
+ /** Serializes the distribution with asJsonString() and parses the result back. */
+ private static ModelDistribution<VectorWritable> roundTrip(ModelDistribution<VectorWritable> dist) {
+ return createGson().fromJson(dist.asJsonString(), AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
+ }
+
+ /** A GaussianClusterDistribution keeps its prototype class across a JSON round trip. */
+ public void testGaussianClusterDistribution() {
+ GaussianClusterDistribution dist = new GaussianClusterDistribution(new VectorWritable(new DenseVector(2)));
+ GaussianClusterDistribution dist1 = (GaussianClusterDistribution) roundTrip(dist);
+ assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+ }
+
+ /** The default DistanceMeasureClusterDistribution keeps its prototype and measure classes. */
+ public void testDMClusterDistribution() {
+ DistanceMeasureClusterDistribution dist = new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2)));
+ DistanceMeasureClusterDistribution dist1 = (DistanceMeasureClusterDistribution) roundTrip(dist);
+ assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+ assertEquals("measure", dist.getMeasure().getClass(), dist1.getMeasure().getClass());
+ }
+
+ /** An explicitly supplied DistanceMeasure survives the JSON round trip. */
+ public void testDMClusterDistribution2() {
+ DistanceMeasureClusterDistribution dist = new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2)),
+ new EuclideanDistanceMeasure());
+ DistanceMeasureClusterDistribution dist1 = (DistanceMeasureClusterDistribution) roundTrip(dist);
+ assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+ assertEquals("measure", dist.getMeasure().getClass(), dist1.getMeasure().getClass());
+ }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java?rev=987647&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java Fri Aug 20 21:56:16 2010
@@ -0,0 +1,101 @@
+package org.apache.mahout.clustering;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.dirichlet.models.GaussianCluster;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
+import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.clustering.meanshift.MeanShiftCanopy;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public class TestVectorModelClassifier extends MahoutTestCase {
+
+ public void testDMClusterClassification() {
+ List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0, measure));
+ models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
+ models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2, measure));
+ AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.107, 0.787, 0.107]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.867, 0.117, 0.016]", AbstractCluster.formatVector(pdf, null));
+ }
+
+ public void testCanopyClassification() {
+ List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ models.add(new Canopy(new DenseVector(2).assign(1), 0, measure));
+ models.add(new Canopy(new DenseVector(2), 1, measure));
+ models.add(new Canopy(new DenseVector(2).assign(-1), 2, measure));
+ AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.107, 0.787, 0.107]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.867, 0.117, 0.016]", AbstractCluster.formatVector(pdf, null));
+ }
+
+ public void testClusterClassification() {
+ List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ models.add(new Cluster(new DenseVector(2).assign(1), 0, measure));
+ models.add(new Cluster(new DenseVector(2), 1, measure));
+ models.add(new Cluster(new DenseVector(2).assign(-1), 2, measure));
+ AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.107, 0.787, 0.107]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.867, 0.117, 0.016]", AbstractCluster.formatVector(pdf, null));
+ }
+
+ public void testMSCanopyClassification() {
+ List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ models.add(new MeanShiftCanopy(new DenseVector(2).assign(1), 0, measure));
+ models.add(new MeanShiftCanopy(new DenseVector(2), 1, measure));
+ models.add(new MeanShiftCanopy(new DenseVector(2).assign(-1), 2, measure));
+ AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+ try {
+ classifier.classify(new DenseVector(2));
+ fail("Expected NotImplementedException");
+ } catch (NotImplementedException e) {
+ assertTrue(true);
+ }
+ }
+
+ public void testSoftClusterClassification() {
+ List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ models.add(new SoftCluster(new DenseVector(2).assign(1), 0, measure));
+ models.add(new SoftCluster(new DenseVector(2), 1, measure));
+ models.add(new SoftCluster(new DenseVector(2).assign(-1), 2, measure));
+ AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.000, 1.000, 0.000]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.735, 0.184, 0.082]", AbstractCluster.formatVector(pdf, null));
+ }
+
+ public void testGaussianClusterClassification() {
+ List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+ models.add(new GaussianCluster(new DenseVector(2).assign(1), new DenseVector(2).assign(1), 0));
+ models.add(new GaussianCluster(new DenseVector(2), new DenseVector(2).assign(1), 1));
+ models.add(new GaussianCluster(new DenseVector(2).assign(-1), new DenseVector(2).assign(1), 2));
+ AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.107, 0.787, 0.107]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.998, 0.002, 0.000]", AbstractCluster.formatVector(pdf, null));
+ }
+
+}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java Fri Aug 20 21:56:16 2010
@@ -21,8 +21,9 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.clustering.Cluster;
-import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
+import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
+import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution;
import org.apache.mahout.common.MahoutTestCase;
@@ -48,10 +49,8 @@ public class TestDirichletClustering ext
* @param sd double standard deviation of the samples
* @param card int cardinality of the generated sample vectors
*/
- private void generateSamples(int num, double mx, double my, double sd,
- int card) {
- System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
- + "] sd=" + sd);
+ private void generateSamples(int num, double mx, double my, double sd, int card) {
+ System.out.println("Generating " + num + " samples m=[" + mx + ", " + my + "] sd=" + sd);
for (int i = 0; i < num; i++) {
DenseVector v = new DenseVector(card);
for (int j = 0; j < card; j++) {
@@ -72,14 +71,13 @@ public class TestDirichletClustering ext
generateSamples(num, mx, my, sd, 2);
}
- private static void printResults(List<Cluster[]> result,
- int significant) {
+ private static void printResults(List<Cluster[]> result, int significant) {
int row = 0;
- for (Model<VectorWritable>[] r : result) {
+ for (Cluster[] r : result) {
System.out.print("sample[" + row++ + "]= ");
- for (Model<VectorWritable> model : r) {
+ for (Cluster model : r) {
if (model.count() > significant) {
- System.out.print(model.toString() + ", ");
+ System.out.print(model.asFormatString(null) + ", ");
}
}
System.out.println();
@@ -93,9 +91,12 @@ public class TestDirichletClustering ext
generateSamples(30, 1, 0, 0.1);
generateSamples(30, 0, 1, 0.1);
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new NormalModelDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
+ DirichletClusterer dc = new DirichletClusterer(sampleData,
+ new NormalModelDistribution(new VectorWritable(new DenseVector(2))),
+ 1.0,
+ 10,
+ 1,
+ 0);
List<Cluster[]> result = dc.cluster(30);
printResults(result, 2);
assertNotNull(result);
@@ -107,9 +108,12 @@ public class TestDirichletClustering ext
generateSamples(30, 1, 0, 0.1);
generateSamples(30, 0, 1, 0.1);
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new SampledNormalDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
+ DirichletClusterer dc = new DirichletClusterer(sampleData,
+ new SampledNormalDistribution(new VectorWritable(new DenseVector(2))),
+ 1.0,
+ 10,
+ 1,
+ 0);
List<Cluster[]> result = dc.cluster(30);
printResults(result, 2);
assertNotNull(result);
@@ -121,107 +125,29 @@ public class TestDirichletClustering ext
generateSamples(30, 1, 0, 0.1);
generateSamples(30, 0, 1, 0.1);
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
+ DirichletClusterer dc = new DirichletClusterer(sampleData,
+ new AsymmetricSampledNormalDistribution(new VectorWritable(new DenseVector(2))),
+ 1.0,
+ 10,
+ 1,
+ 0);
List<Cluster[]> result = dc.cluster(30);
printResults(result, 2);
assertNotNull(result);
}
- public void testDirichletCluster1000() {
- System.out.println("testDirichletCluster1000");
- generateSamples(400, 1, 1, 3);
- generateSamples(300, 1, 0, 0.1);
- generateSamples(300, 0, 1, 0.1);
-
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new NormalModelDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
- List<Cluster[]> result = dc.cluster(30);
- printResults(result, 20);
- assertNotNull(result);
- }
-
- public void testDirichletCluster1000s() {
- System.out.println("testDirichletCluster1000s");
- generateSamples(400, 1, 1, 3);
- generateSamples(300, 1, 0, 0.1);
- generateSamples(300, 0, 1, 0.1);
-
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new SampledNormalDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
- List<Cluster[]> result = dc.cluster(30);
- printResults(result, 20);
- assertNotNull(result);
- }
-
- public void testDirichletCluster1000as() {
- System.out.println("testDirichletCluster1000as");
- generateSamples(400, 1, 1, 3);
- generateSamples(300, 1, 0, 0.1);
- generateSamples(300, 0, 1, 0.1);
-
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
- List<Cluster[]> result = dc.cluster(30);
- printResults(result, 20);
- assertNotNull(result);
- }
-
- public void testDirichletCluster10000() {
- System.out.println("testDirichletCluster10000");
- generateSamples(4000, 1, 1, 3);
- generateSamples(3000, 1, 0, 0.1);
- generateSamples(3000, 0, 1, 0.1);
-
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new NormalModelDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
- List<Cluster[]> result = dc.cluster(30);
- printResults(result, 200);
- assertNotNull(result);
- }
-
- public void testDirichletCluster10000as() {
- System.out.println("testDirichletCluster10000as");
- generateSamples(4000, 1, 1, 3);
- generateSamples(3000, 1, 0, 0.1);
- generateSamples(3000, 0, 1, 0.1);
-
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
- List<Cluster[]> result = dc.cluster(30);
- printResults(result, 200);
- assertNotNull(result);
- }
-
- public void testDirichletCluster10000s() {
- System.out.println("testDirichletCluster10000s");
- generateSamples(4000, 1, 1, 3);
- generateSamples(3000, 1, 0, 0.1);
- generateSamples(3000, 0, 1, 0.1);
-
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new SampledNormalDistribution(new VectorWritable(
- new DenseVector(2))), 1.0, 10, 1, 0);
- List<Cluster[]> result = dc.cluster(30);
- printResults(result, 200);
- assertNotNull(result);
- }
-
public void testDirichletCluster100C3() {
System.out.println("testDirichletCluster100");
generateSamples(40, 1, 1, 3, 3);
generateSamples(30, 1, 0, 0.1, 3);
generateSamples(30, 0, 1, 0.1, 3);
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new NormalModelDistribution(new VectorWritable(
- new DenseVector(3))), 1.0, 10, 1, 0);
+ DirichletClusterer dc = new DirichletClusterer(sampleData,
+ new NormalModelDistribution(new VectorWritable(new DenseVector(3))),
+ 1.0,
+ 10,
+ 1,
+ 0);
List<Cluster[]> result = dc.cluster(30);
printResults(result, 2);
assertNotNull(result);
@@ -233,9 +159,12 @@ public class TestDirichletClustering ext
generateSamples(30, 1, 0, 0.1, 3);
generateSamples(30, 0, 1, 0.1, 3);
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new SampledNormalDistribution(new VectorWritable(
- new DenseVector(3))), 1.0, 10, 1, 0);
+ DirichletClusterer dc = new DirichletClusterer(sampleData,
+ new SampledNormalDistribution(new VectorWritable(new DenseVector(3))),
+ 1.0,
+ 10,
+ 1,
+ 0);
List<Cluster[]> result = dc.cluster(30);
printResults(result, 2);
assertNotNull(result);
@@ -247,9 +176,46 @@ public class TestDirichletClustering ext
generateSamples(30, 1, 0, 0.1, 3);
generateSamples(30, 0, 1, 0.1, 3);
- DirichletClusterer dc = new DirichletClusterer(
- sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable(
- new DenseVector(3))), 1.0, 10, 1, 0);
+ DirichletClusterer dc = new DirichletClusterer(sampleData,
+ new AsymmetricSampledNormalDistribution(new VectorWritable(new DenseVector(3))),
+ 1.0,
+ 10,
+ 1,
+ 0);
+ List<Cluster[]> result = dc.cluster(30);
+ printResults(result, 2);
+ assertNotNull(result);
+ }
+
+ public void testDirichletGaussianCluster100() {
+ System.out.println("testDirichletGaussianCluster100");
+ generateSamples(40, 1, 1, 3);
+ generateSamples(30, 1, 0, 0.1);
+ generateSamples(30, 0, 1, 0.1);
+
+ DirichletClusterer dc = new DirichletClusterer(sampleData,
+ new GaussianClusterDistribution(new VectorWritable(new DenseVector(2))),
+ 1.0,
+ 10,
+ 1,
+ 0);
+ List<Cluster[]> result = dc.cluster(30);
+ printResults(result, 2);
+ assertNotNull(result);
+ }
+
+ public void testDirichletDMCluster100() {
+ System.out.println("testDirichletDMCluster100");
+ generateSamples(40, 1, 1, 3);
+ generateSamples(30, 1, 0, 0.1);
+ generateSamples(30, 0, 1, 0.1);
+
+ DirichletClusterer dc = new DirichletClusterer(sampleData,
+ new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2))),
+ 1.0,
+ 10,
+ 1,
+ 0);
List<Cluster[]> result = dc.cluster(30);
printResults(result, 2);
assertNotNull(result);