You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2009/04/16 00:38:45 UTC

svn commit: r765403 - in /lucene/mahout/trunk: core/ core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/ core/src/main/java/org/apache/mahout/clustering/meanshift/ examples/ examples/src/main/java/org/apache/mahout/clustering/dirichlet/ exampl...

Author: jeastman
Date: Wed Apr 15 22:38:44 2009
New Revision: 765403

URL: http://svn.apache.org/viewvc?rev=765403&view=rev
Log:
Added examples of MeanShift and Fuzzy K-Means operating on Dirichlet sample data

Added:
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java
Modified:
    lucene/mahout/trunk/core/   (props changed)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
    lucene/mahout/trunk/examples/   (props changed)
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java

Propchange: lucene/mahout/trunk/core/
------------------------------------------------------------------------------
--- svn:ignore (original)
+++ svn:ignore Wed Apr 15 22:38:44 2009
@@ -6,3 +6,6 @@
 test
 target
 *.iml
+.settings
+.classpath
+.project

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java?rev=765403&r1=765402&r2=765403&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java Wed Apr 15 22:38:44 2009
@@ -26,6 +26,7 @@
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.mahout.matrix.AbstractVector;
 import org.apache.mahout.matrix.SparseVector;
+import org.apache.mahout.matrix.SquareRootFunction;
 import org.apache.mahout.matrix.Vector;
 import org.apache.mahout.utils.DistanceMeasure;
 
@@ -42,7 +43,8 @@
   private static double m = 2.0; // default value
 
   public static final double MINIMAL_VALUE = 0.0000000001; // using it for
-                                                            // adding
+
+  // adding
 
   // exception
   // this value to any
@@ -70,6 +72,13 @@
   // has the centroid converged with the center?
   private boolean converged = false;
 
+  // track membership parameters: running totals fed by observePoint(), read by std()
+  double s0 = 0; // accumulated point weights
+
+  Vector s1; // null until the first point has been observed
+
+  Vector s2; // null until the first point has been observed
+
   private static DistanceMeasure measure;
 
   private static double convergenceDelta = 0;
@@ -163,9 +172,9 @@
       double probWeight = computeProbWeight(clusterDistanceList.get(i),
           clusterDistanceList);
       Text key = new Text(clusters.get(i).getIdentifier()); // just output the
-                                                            // identifier,avoids
-                                                            // too much data
-                                                            // traffic
+      // identifier,avoids
+      // too much data
+      // traffic
       Text value = new Text(Double.toString(probWeight)
           + FuzzyKMeansDriver.MAPPER_VALUE_SEPARATOR + values.toString());
       output.collect(key, value);
@@ -203,8 +212,7 @@
           probWeight).append(' ');
     }
     output.collect(new Text(outputKey.trim()), new Text(outputValue.toString()
-        .trim()
-        + ']'));
+        .trim() + ']'));
   }
 
   /**
@@ -295,12 +303,47 @@
   }
 
   /**
+   * Observe the point, accumulating weighted variables for std() calculation:
+   * s0 (sum of weights), s1 (sum of weight*x) and s2 (sum of weight*x^2)
+   * @param point the Vector to observe
+   * @param ptProb the weight given to the point
+   */
+  private void observePoint(Vector point, double ptProb) {
+    s0 += ptProb;
+    Vector wtPt = point.times(ptProb);
+    // s1 is the running sum of weight*x, so the first observation has to be
+    // weighted exactly like every later one
+    s1 = (s1 == null) ? wtPt : s1.plus(wtPt);
+    // s2 is the running sum of weight*x^2, not (weight*x)^2: std() evaluates
+    // sqrt(s0*s2 - s1*s1)/s0, which is a weighted deviation only in that form
+    Vector wtPtSq = point.times(point).times(ptProb);
+    s2 = (s2 == null) ? wtPtSq : s2.plus(wtPtSq);
+  }
+
+  /**
+   * Compute a "standard deviation" value to use as the "radius" of the cluster for display purposes
+   * @return the sum of the deviation vector divided by its cardinality, or 0.33 if s0 is not positive
+   */
+  public double std() {
+    if (s0 > 0) {
+      Vector radical = s2.times(s0).minus(s1.times(s1));
+      // sqrt(r * r) is |r|; presumably guards the root below against negative values
+      radical = radical.times(radical).assign(new SquareRootFunction());
+      Vector stds = radical.assign(new SquareRootFunction()).divide(s0);
+      return stds.zSum() / stds.cardinality();
+    } else {
+      return 0.33;
+    }
+  }
+
+  /**
    * Add the point to the SoftCluster
    * 
    * @param point a point to add
    * @param ptProb
    */
   public void addPoint(Vector point, double ptProb) {
+    observePoint(point, ptProb);
     centroid = null;
     pointProbSum += ptProb;
     if (weightedPointTotal == null)

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=765403&r1=765402&r2=765403&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Wed Apr 15 22:38:44 2009
@@ -57,10 +57,10 @@
   private static int nextCanopyId = 0;
 
   // the T1 distance threshold
-  private static double t1;
+  static double t1;
 
   // the T2 distance threshold
-  private static double t2;
+  static double t2;
 
   // the distance measure
   private static DistanceMeasure measure;

Propchange: lucene/mahout/trunk/examples/
------------------------------------------------------------------------------
--- svn:ignore (original)
+++ svn:ignore Wed Apr 15 22:38:44 2009
@@ -6,3 +6,6 @@
 temp
 work
 *.iml
+.settings
+.classpath
+.project

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java?rev=765403&r1=765402&r2=765403&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java Wed Apr 15 22:38:44 2009
@@ -22,11 +22,11 @@
 public class DisplayDirichlet extends Frame {
   private static final long serialVersionUID = 1L;
 
-  int res; //screen resolution
+  protected int res; //screen resolution
 
-  int ds = 72; //default scale = 72 pixels per inch
+  protected int ds = 72; //default scale = 72 pixels per inch
 
-  int size = 8; // screen size in inches
+  protected int size = 8; // screen size in inches
 
   public static List<Vector> sampleData = new ArrayList<Vector>();
 

Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java?rev=765403&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java Wed Apr 15 22:38:44 2009
@@ -0,0 +1,188 @@
+package org.apache.mahout.clustering.fuzzykmeans;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.BasicStroke;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.dirichlet.DisplayDirichlet;
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
+import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.utils.DistanceMeasure;
+import org.apache.mahout.utils.ManhattanDistanceMeasure;
+
+/**
+ * Displays Fuzzy K-Means clusters computed in memory over the DisplayDirichlet
+ * sample data. Initial centers come from canopies; the clusters of every
+ * iteration are kept and drawn.
+ */
+class DisplayFuzzyKMeans extends DisplayDirichlet {
+  private static final long serialVersionUID = 1L;
+
+  // canopies built by main() to seed the first list of clusters
+  static List<Canopy> canopies;
+
+  // one list of clusters per iteration; index 0 holds the canopy-seeded clusters
+  static List<List<SoftCluster>> clusters;
+
+  // T1 distance threshold handed to populateCanopies by main()
+  static double t1 = 3.0;
+
+  // T2 distance threshold handed to populateCanopies by main()
+  static double t2 = 1.5;
+
+  public DisplayFuzzyKMeans() {
+    initialize();
+    this.setTitle("Fuzzy K-Means Clusters (> 5% of population)");
+  }
+
+  /**
+   * Plot the sample data, then one ellipse per cluster of every iteration. The
+   * last iteration is drawn with the widest stroke and with colors[0].
+   *
+   * @param g the Graphics to draw on
+   */
+  public void paint(Graphics g) {
+    super.plotSampleData(g);
+    Graphics2D g2 = (Graphics2D) g;
+    Vector radii = new DenseVector(2);
+    int remaining = clusters.size() - 1;
+    for (List<SoftCluster> generation : clusters) {
+      float width = (remaining == 0) ? 3 : 1;
+      g2.setStroke(new BasicStroke(width));
+      int colorIndex = Math.min(colors.length - 1, remaining);
+      remaining--;
+      g2.setColor(colors[colorIndex]);
+      // no population filter here: main() already filters the seeding canopies
+      for (SoftCluster cluster : generation) {
+        radii.assign(cluster.std() * 3);
+        plotEllipse(g2, cluster.getCenter(), radii);
+      }
+    }
+  }
+
+  /**
+   * Run the reference Fuzzy K-Means iterations, appending the clusters of each
+   * iteration to the static clusters list until an iteration reports
+   * convergence or numIter iterations have run.
+   *
+   * @param points the List<Vector> having the input points
+   * @param measure a DistanceMeasure to use
+   * @param threshold the value passed to SoftCluster.config (presumably the
+   *            convergence delta)
+   * @param numIter the maximum number of iterations
+   */
+  public static void referenceFuzzyKMeans(List<Vector> points,
+      DistanceMeasure measure, double threshold, int numIter) throws Exception {
+    SoftCluster.config(measure, threshold);
+    boolean converged = false;
+    for (int iter = 0; iter < numIter && !converged; iter++) {
+      List<SoftCluster> previous = clusters.get(iter);
+      // start the next iteration from fresh clusters at the previous centers
+      List<SoftCluster> next = new ArrayList<SoftCluster>();
+      for (SoftCluster c : previous) {
+        next.add(new SoftCluster(c.getCenter()));
+      }
+      clusters.add(next);
+      converged = iterateReference(points, clusters.get(iter + 1), measure);
+    }
+  }
+
+  /**
+   * Perform a single iteration over the points and clusters, assigning points
+   * to clusters and returning if the iterations are completed.
+   *
+   * @param points the List<Vector> having the input points
+   * @param clusterList the List<SoftCluster> clusters to update
+   * @param measure a DistanceMeasure to use
+   * @return true if every cluster reported convergence
+   */
+  public static boolean iterateReference(List<Vector> points,
+      List<SoftCluster> clusterList, DistanceMeasure measure) {
+    for (Vector point : points) {
+      // distance from this point to every cluster center, in clusterList order
+      List<Double> distances = new ArrayList<Double>();
+      for (SoftCluster cluster : clusterList) {
+        distances.add(measure.distance(point, cluster.getCenter()));
+      }
+      // add the point to every cluster, weighted by its probWeight raised to m
+      int index = 0;
+      for (SoftCluster cluster : clusterList) {
+        double probWeight = SoftCluster.computeProbWeight(
+            distances.get(index++), distances);
+        cluster.addPoint(point, Math.pow(probWeight, SoftCluster.getM()));
+      }
+    }
+    // query every cluster, even after one has reported non-convergence
+    boolean converged = true;
+    for (SoftCluster cluster : clusterList) {
+      converged = cluster.computeConvergence() && converged;
+    }
+    if (converged) {
+      return true;
+    }
+    // update the cluster centers
+    for (SoftCluster cluster : clusterList) {
+      cluster.recomputeCenter();
+    }
+    return false;
+  }
+
+  /**
+   * Iterate through the points, adding new canopies. Return the canopies.
+   * Points are removed from the given list as they are consumed.
+   *
+   * @param measure a DistanceMeasure to use
+   * @param points a list<Vector> defining the points to be clustered
+   * @param t1 the T1 distance threshold
+   * @param t2 the T2 distance threshold
+   * @return the List<Canopy> created
+   */
+  static List<Canopy> populateCanopies(DistanceMeasure measure,
+      List<Vector> points, double t1, double t2) {
+    Canopy.config(measure, t1, t2);
+    List<Canopy> result = new ArrayList<Canopy>();
+    // Reference implementation: take the first remaining point as the seed of a
+    // new canopy, put every other point closer than T1 into that canopy, drop
+    // from the list every point closer than T2, and repeat until the list is
+    // empty (T1 > T2).
+    while (!points.isEmpty()) {
+      Iterator<Vector> remaining = points.iterator();
+      Vector seed = remaining.next();
+      remaining.remove();
+      Canopy canopy = new Canopy(seed);
+      result.add(canopy);
+      while (remaining.hasNext()) {
+        Vector candidate = remaining.next();
+        double dist = measure.distance(seed, candidate);
+        if (dist < t1) {
+          canopy.addPoint(candidate);
+        }
+        if (dist < t2) {
+          remaining.remove();
+        }
+      }
+    }
+    return result;
+  }
+
+  /**
+   * Generate the samples, seed clusters from the larger canopies, run the
+   * reference implementation and open the display.
+   *
+   * @param args unused
+   */
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generateSamples();
+    // populateCanopies empties its input, so hand it a copy of the sample data
+    List<Vector> points = new ArrayList<Vector>(sampleData);
+    canopies = populateCanopies(new ManhattanDistanceMeasure(), points, t1, t2);
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+    // NOTE(review): this configures the k-means Cluster class; SoftCluster is
+    // configured inside referenceFuzzyKMeans - confirm this call is still needed
+    Cluster.config(measure, 0.001);
+    List<SoftCluster> seeds = new ArrayList<SoftCluster>();
+    clusters = new ArrayList<List<SoftCluster>>();
+    clusters.add(seeds);
+    // only canopies holding more than 5% of the samples seed a cluster
+    double minPoints = 0.05 * sampleData.size();
+    for (Canopy canopy : canopies) {
+      if (canopy.getNumPoints() > minPoints) {
+        seeds.add(new SoftCluster(canopy.getCenter()));
+      }
+    }
+    try {
+      referenceFuzzyKMeans(sampleData, measure, 0.001, 10);
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+    new DisplayFuzzyKMeans();
+  }
+}

Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java?rev=765403&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java Wed Apr 15 22:38:44 2009
@@ -0,0 +1,108 @@
+package org.apache.mahout.clustering.meanshift;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.Color;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.awt.geom.AffineTransform;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.mahout.clustering.dirichlet.DisplayDirichlet;
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
+import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.utils.EuclideanDistanceMeasure;
+
+/**
+ * Displays the canopies produced by an in-memory Mean Shift run over the
+ * DisplayDirichlet sample data.
+ */
+class DisplayMeanShift extends DisplayDirichlet {
+  private static final long serialVersionUID = 1L;
+
+  // the current canopies; replaced after every shift pass
+  private static List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
+
+  // the canopy centers recorded during every shift pass
+  private static List<List<Vector>> iterationCenters = new ArrayList<List<Vector>>();
+
+  public DisplayMeanShift() {
+    initialize();
+    this.setTitle("Canopy Clusters (> 1.5% of population)");
+  }
+
+  /**
+   * Plot the axes and the sample data, then, for every canopy bound to more
+   * than 1.5% of the samples, its bound points and two ellipses sized by the
+   * MeanShiftCanopy t1 and t2 values.
+   *
+   * @param g the Graphics to draw on
+   */
+  public void paint(Graphics g) {
+    Graphics2D g2 = (Graphics2D) g;
+    double scale = (double) res / ds;
+    g2.setTransform(AffineTransform.getScaleInstance(scale, scale));
+
+    // plot the axes
+    g2.setColor(Color.BLACK);
+    Vector extent = new DenseVector(2).assign(size / 2);
+    plotRectangle(g2, new DenseVector(2).assign(2), extent);
+    plotRectangle(g2, new DenseVector(2).assign(-2), extent);
+
+    // plot the sample data, reusing the same vector as the marker size
+    g2.setColor(Color.DARK_GRAY);
+    extent.assign(0.03);
+    for (Vector sample : sampleData) {
+      plotRectangle(g2, sample, extent);
+    }
+
+    // plot the canopies that are large enough
+    Vector t1Radius = new DenseVector(2).assign(MeanShiftCanopy.t1);
+    Vector t2Radius = new DenseVector(2).assign(MeanShiftCanopy.t2);
+    double minBound = 0.015 * sampleData.size();
+    int shown = 0;
+    for (MeanShiftCanopy canopy : canopies) {
+      if (canopy.getBoundPoints().size() <= minBound) {
+        continue;
+      }
+      // canopies beyond the palette size all share the last color
+      g2.setColor(colors[Math.min(shown, colors.length - 1)]);
+      shown++;
+      for (Vector bound : canopy.getBoundPoints()) {
+        plotRectangle(g2, bound, extent);
+      }
+      plotEllipse(g2, canopy.getCenter(), t1Radius);
+      plotEllipse(g2, canopy.getCenter(), t2Radius);
+    }
+  }
+
+  /**
+   * Merge every sample into the canopies, then repeatedly shift the canopies
+   * to their means and re-merge them until no canopy reports further movement.
+   */
+  public static void testReferenceImplementation() {
+    MeanShiftCanopy.config(new EuclideanDistanceMeasure(), 1.0, 0.05, 0.5);
+    // add all points to the canopies
+    for (Vector sample : sampleData) {
+      MeanShiftCanopy.mergeCanopy(new MeanShiftCanopy(sample), canopies);
+    }
+    boolean done;
+    do {
+      // shift canopies to their centroids
+      done = true;
+      List<Vector> centers = new ArrayList<Vector>();
+      List<MeanShiftCanopy> migrated = new ArrayList<MeanShiftCanopy>();
+      for (MeanShiftCanopy canopy : canopies) {
+        centers.add(canopy.getCenter());
+        // every canopy is shifted, even once one has reported it is not done
+        if (!canopy.shiftToMean()) {
+          done = false;
+        }
+        MeanShiftCanopy.mergeCanopy(canopy, migrated);
+      }
+      iterationCenters.add(centers);
+      canopies = migrated;
+    } while (!done);
+  }
+
+  /**
+   * Generate the samples, run the reference implementation, print the
+   * resulting canopies and open the display.
+   *
+   * @param args unused
+   */
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generateSamples();
+    testReferenceImplementation();
+    for (MeanShiftCanopy canopy : canopies) {
+      System.out.println(canopy.toString());
+    }
+    new DisplayMeanShift();
+  }
+
+  // delegates to DisplayDirichlet.generateResults with a NormalModelDistribution
+  static void generateResults() {
+    DisplayDirichlet.generateResults(new NormalModelDistribution());
+  }
+}