You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:08:02 UTC
[31/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
new file mode 100644
index 0000000..ebb0614
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Ordering;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Queue;
+import java.util.Set;
+
+/**
+ * Uses sample data to reverse engineer a feature-hashed model.
+ *
+ * The result gives approximate weights for features and interactions
+ * in the original space.
+ *
+ * The idea is that the hashed encoders have the option of having a trace dictionary. This
+ * tells us where each feature is hashed to, or each feature/value combination in the case
+ * of word-like values. Using this dictionary, we can put values into a synthetic feature
+ * vector in just the locations specified by a single feature or interaction. Then we can
+ * push this through a linear part of a model to see the contribution of that input. For
+ * any generalized linear model like logistic regression, there is a linear part of the
+ * model that allows this.
+ *
+ * What the ModelDissector does is to accept a trace dictionary and a model in an update
+ * method. It figures out the weights for the elements in the trace dictionary and stashes
+ * them. Then in a summary method, the biggest weights are returned. This update/flush
+ * style is used so that the trace dictionary doesn't have to grow to enormous levels,
+ * but instead can be cleared between updates.
+ */
+public class ModelDissector {
+  // raw (no-link) model output recorded for each probed variable name
+  private final Map<String,Vector> weightMap;
+
+  public ModelDissector() {
+    weightMap = Maps.newHashMap();
+  }
+
+  /**
+   * Probes a model to determine the effect of a particular variable. This is done
+   * with the aid of a trace dictionary which has recorded the locations in the feature
+   * vector that are modified by various variable values. We can set these locations to
+   * 1 and then look at the resulting score. This tells us the weight the model places
+   * on that variable.
+   *
+   * Variables that were already probed by an earlier call are skipped, so only the first
+   * set of locations seen for a given name is ever used.
+   *
+   * @param features A feature vector to use (destructively); on normal return it is all zeros
+   * @param traceDictionary A trace dictionary containing variables and what locations
+   * in the feature vector are affected by them
+   * @param learner The model that we are probing to find weights on features
+   */
+  public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
+    // zero out feature vector
+    features.assign(0);
+    for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) {
+      // get a feature and locations where it is stored in the feature vector
+      String key = entry.getKey();
+      Set<Integer> value = entry.getValue();
+
+      // if we haven't looked at this feature yet
+      if (!weightMap.containsKey(key)) {
+        // put probe values in the feature vector
+        for (Integer where : value) {
+          features.set(where, 1);
+        }
+
+        // see what the model says
+        Vector v = learner.classifyNoLink(features);
+        weightMap.put(key, v);
+
+        // and zero out those locations again so the next probe starts clean
+        for (Integer where : value) {
+          features.set(where, 0);
+        }
+      }
+    }
+  }
+
+  /**
+   * Returns the n most important features with their
+   * weights, most important category and the top few
+   * categories that they affect. A feature's importance is the absolute value
+   * of its largest-magnitude category weight.
+   * @param n How many results to return.
+   * @return A list of the top variables, most important first.
+   */
+  public List<Weight> summary(int n) {
+    // min-heap on importance: polling discards the least important entry once we exceed n
+    Queue<Weight> pq = new PriorityQueue<Weight>();
+    for (Map.Entry<String, Vector> entry : weightMap.entrySet()) {
+      pq.add(new Weight(entry.getKey(), entry.getValue()));
+      while (pq.size() > n) {
+        pq.poll();
+      }
+    }
+    List<Weight> r = Lists.newArrayList(pq);
+    Collections.sort(r, Ordering.natural().reverse());
+    return r;
+  }
+
+  /**
+   * One (category index, weight) pair. Ordered by the magnitude of the weight; ties on
+   * magnitude are broken in favour of the lower index (the lower index compares greater).
+   */
+  private static final class Category implements Comparable<Category> {
+    private final int index;
+    private final double weight;
+
+    private Category(int index, double weight) {
+      this.index = index;
+      this.weight = weight;
+    }
+
+    @Override
+    public int compareTo(Category o) {
+      int r = Double.compare(Math.abs(weight), Math.abs(o.weight));
+      if (r == 0) {
+        if (o.index < index) {
+          return -1;
+        }
+        if (o.index > index) {
+          return 1;
+        }
+        return 0;
+      }
+      return r;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (!(o instanceof Category)) {
+        return false;
+      }
+      Category other = (Category) o;
+      // Double.compare rather than == so equality agrees with hashCode (assuming
+      // RandomUtils.hashDouble hashes the bit pattern, as Double.hashCode does):
+      // 0.0 and -0.0 are then distinct and NaN equals itself.
+      return index == other.index && Double.compare(weight, other.weight) == 0;
+    }
+
+    @Override
+    public int hashCode() {
+      return RandomUtils.hashDouble(weight) ^ index;
+    }
+
+  }
+
+  /**
+   * Summary of the model's response to one feature: the strongest category weights it
+   * produces, ordered by decreasing magnitude.
+   */
+  public static class Weight implements Comparable<Weight> {
+    private final String feature;
+    // signed weight with the largest magnitude, and the category index it belongs to
+    private final double value;
+    private final int maxIndex;
+    // the strongest categories, largest magnitude first
+    private final List<Category> categories;
+
+    /** Keeps the 3 strongest categories of {@code weights}. */
+    public Weight(String feature, Vector weights) {
+      this(feature, weights, 3);
+    }
+
+    /**
+     * @param feature name of the probed feature
+     * @param weights model output for the feature, one entry per category
+     * @param n how many of the strongest categories to keep
+     * @throws IllegalArgumentException if {@code n < 1} or {@code weights} has no elements
+     */
+    public Weight(String feature, Vector weights, int n) {
+      if (n < 1) {
+        throw new IllegalArgumentException("Number of categories to keep must be positive, got " + n);
+      }
+      this.feature = feature;
+      // pick out the weight with the largest abs value, but don't forget the sign
+      Queue<Category> biggest = new PriorityQueue<Category>(n + 1, Ordering.natural());
+      for (Vector.Element element : weights.all()) {
+        biggest.add(new Category(element.index(), element.get()));
+        while (biggest.size() > n) {
+          biggest.poll();
+        }
+      }
+      if (biggest.isEmpty()) {
+        throw new IllegalArgumentException("Cannot summarize an empty weight vector for feature " + feature);
+      }
+      categories = Lists.newArrayList(biggest);
+      Collections.sort(categories, Ordering.natural().reverse());
+      value = categories.get(0).weight;
+      maxIndex = categories.get(0).index;
+    }
+
+    @Override
+    public int compareTo(Weight other) {
+      int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
+      if (r == 0) {
+        return feature.compareTo(other.feature);
+      }
+      return r;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (!(o instanceof Weight)) {
+        return false;
+      }
+      Weight other = (Weight) o;
+      // Double.compare keeps equals consistent with the bit-pattern based hashCode
+      return feature.equals(other.feature)
+          && Double.compare(value, other.value) == 0
+          && maxIndex == other.maxIndex
+          && categories.equals(other.categories);
+    }
+
+    @Override
+    public int hashCode() {
+      return feature.hashCode() ^ RandomUtils.hashDouble(value) ^ maxIndex ^ categories.hashCode();
+    }
+
+    public String getFeature() {
+      return feature;
+    }
+
+    /** Returns the signed weight with the largest magnitude. */
+    public double getWeight() {
+      return value;
+    }
+
+    /** Returns the signed weight of the n-th strongest category (0 = strongest). */
+    public double getWeight(int n) {
+      return categories.get(n).weight;
+    }
+
+    /** Returns the index of the n-th strongest category (0 = strongest), widened to a double. */
+    public double getCategory(int n) {
+      return categories.get(n).index;
+    }
+
+    /** Returns the index of the category with the largest-magnitude weight. */
+    public int getMaxImpact() {
+      return maxIndex;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
new file mode 100644
index 0000000..f0150e9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Provides the ability to store SGD model-related objects as binary files.
+ *
+ * Objects are written with {@link PolymorphicWritable}: the concrete class name followed by
+ * the object's own {@link Writable} encoding. Use {@link #readBinary(InputStream, Class)} to
+ * load them again.
+ */
+public final class ModelSerializer {
+
+  // static class ... don't instantiate
+  private ModelSerializer() {
+  }
+
+  /** Writes {@code model} to the file at {@code path}, truncating any existing file. */
+  public static void writeBinary(String path, CrossFoldLearner model) throws IOException {
+    writeWritable(path, model);
+  }
+
+  /** Writes {@code model} to the file at {@code path}, truncating any existing file. */
+  public static void writeBinary(String path, OnlineLogisticRegression model) throws IOException {
+    writeWritable(path, model);
+  }
+
+  /** Writes {@code model} to the file at {@code path}, truncating any existing file. */
+  public static void writeBinary(String path, AdaptiveLogisticRegression model) throws IOException {
+    writeWritable(path, model);
+  }
+
+  /**
+   * Reads an object previously written by one of the {@code writeBinary} methods.
+   *
+   * @param in stream positioned at the start of the serialized object; always closed by this method
+   * @param clazz expected (super)type of the stored object
+   * @return the deserialized object
+   * @throws IOException if reading or closing the stream fails
+   */
+  public static <T extends Writable> T readBinary(InputStream in, Class<T> clazz) throws IOException {
+    // buffer the caller's stream: DataInput pulls the payload a few bytes at a time
+    DataInput dataIn = new DataInputStream(new BufferedInputStream(in));
+    try {
+      return PolymorphicWritable.read(dataIn, clazz);
+    } finally {
+      Closeables.close(in, false);
+    }
+  }
+
+  /**
+   * Shared implementation of the {@code writeBinary} overloads. The file stream is buffered
+   * because Writable encodings issue many small writes (one per int/double), each of which
+   * would otherwise go straight to the file.
+   */
+  private static void writeWritable(String path, Writable model) throws IOException {
+    DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(path)));
+    try {
+      PolymorphicWritable.write(out, model);
+    } finally {
+      // closing flushes the buffer; a failure here is propagated rather than swallowed
+      Closeables.close(out, false);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
new file mode 100644
index 0000000..7a9ca83
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * On-line logistic regression learner that extends {@link AbstractOnlineLogisticRegression}
+ * with a specific set of learning rate annealing schedules.
+ *
+ * The global learning rate at step t is
+ * {@code mu0 * decayFactor^t * (t + stepOffset)^forgettingExponent}; each feature j is
+ * additionally scaled by {@code sqrt(perTermAnnealingOffset / updateCounts[j])}.
+ */
+public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression implements Writable {
+  public static final int WRITABLE_VERSION = 1;
+
+  // exponential annealing (decayFactor^steps): initial learning rate and per-step decay factor
+  private double mu0 = 1;
+  private double decayFactor = 1 - 1.0e-3;
+
+  // power-law annealing (steps + stepOffset)^forgettingExponent;
+  // an exponent of -1 weights all examples evenly, 0 leaves only the exponential annealing
+  private int stepOffset = 10;
+  private double forgettingExponent = -0.5;
+
+  // offset for the per-term annealing in perTermLearningRate; also the initial update count
+  private int perTermAnnealingOffset = 20;
+
+  /**
+   * Creates an unconfigured learner for deserialization; {@link #readFields(DataInput)} fills
+   * in its state. Not intended for normal use.
+   */
+  public OnlineLogisticRegression() {
+  }
+
+  /**
+   * Creates a learner with a fresh coefficient matrix.
+   *
+   * @param numCategories number of target categories; the coefficient matrix has one row fewer
+   * @param numFeatures   length of the feature vectors
+   * @param prior         regularization prior
+   */
+  public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+    this.numCategories = numCategories;
+    this.prior = prior;
+
+    beta = new DenseMatrix(numCategories - 1, numFeatures);
+    updateCounts = new DenseVector(numFeatures).assign(perTermAnnealingOffset);
+    updateSteps = new DenseVector(numFeatures);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param alpha New value of decayFactor, the exponential decay rate for the learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression alpha(double alpha) {
+    decayFactor = alpha;
+    return this;
+  }
+
+  /** Same as the inherited option, overridden only to narrow the return type for chaining. */
+  @Override
+  public OnlineLogisticRegression lambda(double lambda) {
+    super.lambda(lambda);
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of the initial learning rate (mu0).
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression learningRate(double learningRate) {
+    mu0 = learningRate;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param stepOffset New offset added to the step count in the power-law annealing term.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression stepOffset(int stepOffset) {
+    this.stepOffset = stepOffset;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option for the exponent of the power-law annealing term.
+   * Positive values are negated, so {@code 0.5} and {@code -0.5} are equivalent.
+   *
+   * @param decayExponent New annealing exponent.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression decayExponent(double decayExponent) {
+    forgettingExponent = decayExponent > 0 ? -decayExponent : decayExponent;
+    return this;
+  }
+
+
+  /** Per-feature rate multiplier: sqrt(perTermAnnealingOffset / updateCounts[j]). */
+  @Override
+  public double perTermLearningRate(int j) {
+    double fraction = perTermAnnealingOffset / updateCounts.get(j);
+    return Math.sqrt(fraction);
+  }
+
+  /** Global rate: mu0 * decayFactor^step * (step + stepOffset)^forgettingExponent. */
+  @Override
+  public double currentLearningRate() {
+    double exponentialPart = Math.pow(decayFactor, getStep());
+    double powerLawPart = Math.pow(getStep() + stepOffset, forgettingExponent);
+    return mu0 * exponentialPart * powerLawPart;
+  }
+
+  /** Copies the inherited state via the superclass, then this class's annealing settings. */
+  public void copyFrom(OnlineLogisticRegression other) {
+    super.copyFrom(other);
+    mu0 = other.mu0;
+    decayFactor = other.decayFactor;
+    stepOffset = other.stepOffset;
+    forgettingExponent = other.forgettingExponent;
+    perTermAnnealingOffset = other.perTermAnnealingOffset;
+  }
+
+  /** Calls {@link #close()} on this learner, then returns a new learner populated by copyFrom. */
+  public OnlineLogisticRegression copy() {
+    close();
+    OnlineLogisticRegression duplicate = new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
+    duplicate.copyFrom(this);
+    return duplicate;
+  }
+
+  /**
+   * Serialized layout: version, mu0, lambda, decayFactor, stepOffset, step, forgettingExponent,
+   * perTermAnnealingOffset, numCategories, beta, prior (polymorphic), updateCounts, updateSteps.
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(mu0);
+    out.writeDouble(getLambda());
+    out.writeDouble(decayFactor);
+    out.writeInt(stepOffset);
+    out.writeInt(step);
+    out.writeDouble(forgettingExponent);
+    out.writeInt(perTermAnnealingOffset);
+    out.writeInt(numCategories);
+    MatrixWritable.writeMatrix(out, beta);
+    PolymorphicWritable.write(out, prior);
+    VectorWritable.writeVector(out, updateCounts);
+    VectorWritable.writeVector(out, updateSteps);
+  }
+
+  /**
+   * Reads the layout produced by {@link #write(DataOutput)}.
+   *
+   * @throws IOException if the stored version is not {@link #WRITABLE_VERSION} or reading fails
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int storedVersion = in.readInt();
+    if (storedVersion != WRITABLE_VERSION) {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + storedVersion);
+    }
+    mu0 = in.readDouble();
+    lambda(in.readDouble());
+    decayFactor = in.readDouble();
+    stepOffset = in.readInt();
+    step = in.readInt();
+    forgettingExponent = in.readDouble();
+    perTermAnnealingOffset = in.readInt();
+    numCategories = in.readInt();
+    beta = MatrixWritable.readMatrix(in);
+    prior = PolymorphicWritable.read(in, PriorFunction.class);
+
+    updateCounts = VectorWritable.readVector(in);
+    updateSteps = VectorWritable.readVector(in);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
new file mode 100644
index 0000000..c51361c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Online passive aggressive learner that tries to minimize the label ranking hinge loss.
+ * Implements a multi-class linear classifier minimizing rank loss,
+ * based on "Online Passive-Aggressive Algorithms" by Crammer et al., 2006.
+ * Note: it's better to use classifyNoLink because the loss function is based
+ * on ensuring that the score of the good label is larger than the next
+ * highest label by some margin. The conversion to probability is just done
+ * by exponentiating and dividing by the sum and is empirical at best.
+ * Your features should be pre-normalized in some sensible range, for example,
+ * by subtracting the mean and dividing by the standard deviation, if they are very
+ * different in magnitude from each other.
+ */
+public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);
+
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm
+  private double learningRate = 0.1;
+
+  // loss statistics; train() logs their average and resets them periodically
+  private int lossCount = 0;
+  private double lossSum = 0;
+
+  // coefficients for the classification: a numCategories x numFeatures matrix
+  private Matrix weights;
+
+  // number of categories we are classifying.
+  private int numCategories;
+
+  /**
+   * Creates an empty learner for deserialization; {@link #readFields(DataInput)} fills in its
+   * state. PolymorphicWritable.read instantiates classes reflectively by name, which needs a
+   * no-argument constructor. Not intended for normal use: the learner has no weights until
+   * readFields is called.
+   */
+  public PassiveAggressive() {
+  }
+
+  public PassiveAggressive(int numCategories, int numFeatures) {
+    this.numCategories = numCategories;
+    // a new DenseMatrix is already zero-filled
+    weights = new DenseMatrix(numCategories, numFeatures);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public PassiveAggressive learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Copies the configuration and weights of {@code other} into this learner. The weight matrix
+   * is cloned, so training either learner afterwards does not affect the other.
+   */
+  public void copyFrom(PassiveAggressive other) {
+    learningRate = other.learningRate;
+    numCategories = other.numCategories;
+    weights = other.weights.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numCategories;
+  }
+
+  /**
+   * Softmax of the raw scores with category 0 dropped, following the classify convention of
+   * returning numCategories - 1 values.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Convert to probabilities by exponentiation; subtracting the max avoids overflow.
+    double max = result.maxValue();
+    result.assign(Functions.minus(max)).assign(Functions.EXP);
+    result = result.divide(result.norm(1));
+
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  /** Raw linear score of {@code instance} for every category (one entry per weight row). */
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    Vector result = new DenseVector(weights.numRows());
+    for (int i = 0; i < weights.numRows(); i++) {
+      result.setQuick(i, weights.viewRow(i).dot(instance));
+    }
+    return result;
+  }
+
+  /**
+   * Probability of category 1 for a two-category model: the softmax of the first two raw
+   * scores, exp(s1) / (exp(s0) + exp(s1)).
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    double score0 = weights.viewRow(0).dot(instance);
+    double score1 = weights.viewRow(1).dot(instance);
+    // algebraically equal to exp(s1) / (exp(s0) + exp(s1)), but large scores can no longer
+    // overflow both exponentials to Infinity and produce NaN
+    return 1.0 / (1.0 + Math.exp(score0 - score1));
+  }
+
+  public int numFeatures() {
+    return weights.numCols();
+  }
+
+  /** Returns an independent copy of this learner (weights are cloned). */
+  public PassiveAggressive copy() {
+    close();
+    PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures());
+    r.copyFrom(this);
+    return r;
+  }
+
+  /** Serialized layout: version, learningRate, numCategories, weights. */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(learningRate);
+    out.writeInt(numCategories);
+    MatrixWritable.writeMatrix(out, weights);
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version == WRITABLE_VERSION) {
+      learningRate = in.readDouble();
+      numCategories = in.readInt();
+      weights = MatrixWritable.readMatrix(in);
+    } else {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+  }
+
+  @Override
+  public void close() {
+    // This is an online classifier, nothing to do.
+  }
+
+  /**
+   * One passive-aggressive update: if the true category does not beat the strongest other
+   * category by a margin of 1, move the true category's weights towards {@code instance} and
+   * the competitor's away from it.
+   */
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    if (lossCount > 1000) {
+      log.info("Avg. Loss = {}", lossSum / lossCount);
+      lossCount = 0;
+      lossSum = 0;
+    }
+    Vector result = classifyNoLink(instance);
+    double myScore = result.get(actual);
+    // Find the highest score that is not actual.
+    int otherIndex = result.maxValueIndex();
+    double otherValue = result.get(otherIndex);
+    if (otherIndex == actual) {
+      result.setQuick(otherIndex, Double.NEGATIVE_INFINITY);
+      otherIndex = result.maxValueIndex();
+      otherValue = result.get(otherIndex);
+    }
+    // hinge loss on the margin between the true category and its strongest competitor
+    double loss = 1.0 - myScore + otherValue;
+    lossCount += 1;
+    if (loss >= 0) {
+      lossSum += loss;
+      // step size: loss over the squared norm of the instance, softened by the learning rate
+      double tau = loss / (instance.dot(instance) + 0.5 / learningRate);
+      weights.viewRow(actual).assign(instance, Functions.plusMult(tau));
+      weights.viewRow(otherIndex).assign(instance, Functions.plusMult(-tau));
+    }
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
new file mode 100644
index 0000000..90062a6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.ClassUtils;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Serialization helpers that prefix a {@link Writable}'s own encoding with its concrete class
+ * name, so the matching implementation can be re-created when the data is read back.
+ */
+public final class PolymorphicWritable {
+
+  // utility class: no instances
+  private PolymorphicWritable() {
+  }
+
+  /**
+   * Writes the runtime class name of {@code value} (via writeUTF), then the value's own
+   * {@link Writable#write} output.
+   */
+  public static <T extends Writable> void write(DataOutput dataOutput, T value) throws IOException {
+    String typeName = value.getClass().getName();
+    dataOutput.writeUTF(typeName);
+    value.write(dataOutput);
+  }
+
+  /**
+   * Reads a class name written by {@link #write}, creates an instance of it through
+   * {@code ClassUtils.instantiateAs} (passing {@code clazz} as the expected type), and fills
+   * it in with {@link Writable#readFields}.
+   */
+  public static <T extends Writable> T read(DataInput dataInput, Class<? extends T> clazz) throws IOException {
+    T instance = ClassUtils.instantiateAs(dataInput.readUTF(), clazz);
+    instance.readFields(dataInput);
+    return instance;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
new file mode 100644
index 0000000..857f061
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Regularization prior for the SGD learners. A prior trades the complexity of the learned
+ * model off against how closely it fits the training data, and different priors encode
+ * different notions of complexity. For large sparse problems such as text classification,
+ * the L1 prior, which favors sparse models, is a common choice.
+ */
+public interface PriorFunction extends Writable {
+  /**
+   * Computes the log of the prior probability of a single coefficient value.
+   *
+   * @param betaIJ the coefficient value
+   * @return the log-probability of {@code betaIJ} under this prior
+   */
+  double logP(double betaIJ);
+
+  /**
+   * Regularizes one coefficient, returning its value after the prior has been applied.
+   *
+   * @param oldValue     the coefficient before regularization
+   * @param generations  the number of generations to apply
+   * @param learningRate the learning rate, with lambda already multiplied in
+   * @return the regularized coefficient
+   */
+  double age(double oldValue, double generations, double learningRate);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
new file mode 100644
index 0000000..b52cb8c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * Uses the difference between this instance and recent history to get a
+ * gradient that optimizes ranking performance. Essentially this is the
+ * same as directly optimizing AUC. It isn't expected that this would
+ * be used alone, but rather that a MixedGradient would use it and a
+ * DefaultGradient together to combine both ranking and log-likelihood
+ * goals.
+ *
+ * {@link #apply} supports binary targets (0 and 1) only. For each target value the most
+ * recent {@code window} instances are remembered. They are stored by reference, so callers
+ * must not mutate an instance after passing it in.
+ */
+public class RankingGradient implements Gradient {
+
+  private static final Gradient BASIC = new DefaultGradient();
+
+  // maximum number of recent instances remembered per target value
+  private final int window;
+
+  // history.get(k): most recent instances seen with target value k, oldest first
+  private final List<Deque<Vector>> history = Lists.newArrayList();
+
+  public RankingGradient(int window) {
+    this.window = window;
+  }
+
+  /**
+   * Records {@code instance} for {@code actual}, then returns the average of the basic
+   * gradients computed on {@code instance - other} for every remembered instance
+   * {@code other} of the opposite target value. If no instance of the opposite target has
+   * been seen yet, the basic gradient of {@code instance} itself is returned instead.
+   *
+   * @throws IllegalArgumentException if {@code actual} is neither 0 nor 1
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    if (actual != 0 && actual != 1) {
+      throw new IllegalArgumentException("Ranking gradient requires a binary target (0 or 1), got " + actual);
+    }
+    addToHistory(actual, instance);
+
+    int otherTarget = 1 - actual;
+    if (history.size() <= otherTarget || history.get(otherTarget).isEmpty()) {
+      // nothing to rank against yet; an ordinary gradient beats returning null
+      return BASIC.apply(groupKey, actual, instance, classifier);
+    }
+
+    // now compute average gradient versus saved vectors from the other side
+    Deque<Vector> otherSide = history.get(otherTarget);
+    double scale = 1.0 / otherSide.size();
+
+    Vector r = null;
+    for (Vector other : otherSide) {
+      Vector g = BASIC.apply(groupKey, actual, instance.minus(other), classifier);
+      if (r == null) {
+        // every term, including the first, carries weight 1/n
+        r = g.assign(Functions.mult(scale));
+      } else {
+        r.assign(g, Functions.plusMult(scale));
+      }
+    }
+    return r;
+  }
+
+  /**
+   * Remembers {@code instance} as an example of target value {@code actual}, discarding the
+   * oldest instances for that target so that at most {@code window} are kept.
+   */
+  public void addToHistory(int actual, Vector instance) {
+    while (history.size() <= actual) {
+      history.add(new ArrayDeque<Vector>(window));
+    }
+    // save this instance
+    Deque<Vector> ourSide = history.get(actual);
+    ourSide.add(instance);
+    // trim to the `window` most recent instances
+    while (ourSide.size() > window) {
+      ourSide.pollFirst();
+    }
+  }
+
+  public Gradient getBaseGradient() {
+    return BASIC;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java
new file mode 100644
index 0000000..fbc825d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * A record factory understands how to convert a line of data into fields and then into a vector.
+ */
+public interface RecordFactory {
+ void defineTargetCategories(List<String> values);
+
+ RecordFactory maxTargetValue(int max);
+
+ boolean usesFirstLineAsSchema();
+
+ int processLine(String line, Vector featureVector);
+
+ Iterable<String> getPredictors();
+
+ Map<String, Set<Integer>> getTraceDictionary();
+
+ RecordFactory includeBiasTerm(boolean useBias);
+
+ List<String> getTargetCategories();
+
+ void firstLine(String line);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
new file mode 100644
index 0000000..0a7b6a7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.math3.special.Gamma;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Provides a t-distribution as a prior.
+ */
+public class TPrior implements PriorFunction {
+ private double df;
+
+ public TPrior(double df) {
+ this.df = df;
+ }
+
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ for (int i = 0; i < generations; i++) {
+ oldValue -= learningRate * oldValue * (df + 1.0) / (df + oldValue * oldValue);
+ }
+ return oldValue;
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return Gamma.logGamma((df + 1.0) / 2.0)
+ - Math.log(df * Math.PI)
+ - Gamma.logGamma(df / 2.0)
+ - (df + 1.0) / 2.0 * Math.log1p(betaIJ * betaIJ);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(df);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ df = in.readDouble();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
new file mode 100644
index 0000000..23c812f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * A uniform prior. This is an improper prior that corresponds to no regularization at all.
+ */
+public class UniformPrior implements PriorFunction {
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ return oldValue;
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return 0;
+ }
+
+ @Override
+ public void write(DataOutput dataOutput) throws IOException {
+ // nothing to write
+ }
+
+ @Override
+ public void readFields(DataInput dataInput) throws IOException {
+ // stateless class is trivial to read
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
new file mode 100644
index 0000000..c2ad966
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
@@ -0,0 +1,23 @@
+/**
+ * <p>Implements a variety of on-line logistic regression classifiers using SGD-based algorithms.
+ * SGD stands for Stochastic Gradient Descent and refers to a class of learning algorithms
+ * that make it relatively easy to build high speed on-line learning algorithms for a variety
+ * of problems, notably including supervised learning for classification.</p>
+ *
+ * <p>The primary class of interest in this package is
+ * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} which contains a
+ * number (typically 5) of sub-learners, each of which is given a different portion of the
+ * training data. Each of these sub-learners can then be evaluated on the data it was not
+ * trained on. This allows fully incremental learning while still getting cross-validated
+ * performance estimates.</p>
+ *
+ * <p>The CrossFoldLearner implements {@link org.apache.mahout.classifier.OnlineLearner}
+ * and thus expects to be fed input in the form
+ * of a target variable and a feature vector. The target variable is simply an integer in the
+ * half-open interval [0..numCategories) where numCategories is defined when the CrossFoldLearner
+ * is constructed. The creation of feature vectors is facilitated by the classes that inherit
+ * from {@link org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder}.
+ * These classes currently implement a form of feature hashing with
+ * multiple probes to limit feature ambiguity.</p>
+ */
+package org.apache.mahout.classifier.sgd;
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java b/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
new file mode 100644
index 0000000..cc05beb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
@@ -0,0 +1,391 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+import org.codehaus.jackson.map.ObjectMapper;
+
+public abstract class AbstractCluster implements Cluster {
+
+ // cluster persistent state
+ private int id;
+
+ private long numObservations;
+
+ private long totalObservations;
+
+ private Vector center;
+
+ private Vector radius;
+
+ // the observation statistics
+ private double s0;
+
+ private Vector s1;
+
+ private Vector s2;
+
+ private static final ObjectMapper jxn = new ObjectMapper();
+
+ protected AbstractCluster() {}
+
+ protected AbstractCluster(Vector point, int id2) {
+ this.numObservations = (long) 0;
+ this.totalObservations = (long) 0;
+ this.center = point.clone();
+ this.radius = center.like();
+ this.s0 = (double) 0;
+ this.s1 = center.like();
+ this.s2 = center.like();
+ this.id = id2;
+ }
+
+ protected AbstractCluster(Vector center2, Vector radius2, int id2) {
+ this.numObservations = (long) 0;
+ this.totalObservations = (long) 0;
+ this.center = new RandomAccessSparseVector(center2);
+ this.radius = new RandomAccessSparseVector(radius2);
+ this.s0 = (double) 0;
+ this.s1 = center.like();
+ this.s2 = center.like();
+ this.id = id2;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(id);
+ out.writeLong(getNumObservations());
+ out.writeLong(getTotalObservations());
+ VectorWritable.writeVector(out, getCenter());
+ VectorWritable.writeVector(out, getRadius());
+ out.writeDouble(s0);
+ VectorWritable.writeVector(out, s1);
+ VectorWritable.writeVector(out, s2);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.id = in.readInt();
+ this.setNumObservations(in.readLong());
+ this.setTotalObservations(in.readLong());
+ this.setCenter(VectorWritable.readVector(in));
+ this.setRadius(VectorWritable.readVector(in));
+ this.setS0(in.readDouble());
+ this.setS1(VectorWritable.readVector(in));
+ this.setS2(VectorWritable.readVector(in));
+ }
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public int getId() {
+ return id;
+ }
+
+ /**
+ * @param id
+ * the id to set
+ */
+ protected void setId(int id) {
+ this.id = id;
+ }
+
+ @Override
+ public long getNumObservations() {
+ return numObservations;
+ }
+
+ /**
+ * @param l
+ * the numPoints to set
+ */
+ protected void setNumObservations(long l) {
+ this.numObservations = l;
+ }
+
+ @Override
+ public long getTotalObservations() {
+ return totalObservations;
+ }
+
+ protected void setTotalObservations(long totalPoints) {
+ this.totalObservations = totalPoints;
+ }
+
+ @Override
+ public Vector getCenter() {
+ return center;
+ }
+
+ /**
+ * @param center
+ * the center to set
+ */
+ protected void setCenter(Vector center) {
+ this.center = center;
+ }
+
+ @Override
+ public Vector getRadius() {
+ return radius;
+ }
+
+ /**
+ * @param radius
+ * the radius to set
+ */
+ protected void setRadius(Vector radius) {
+ this.radius = radius;
+ }
+
+ /**
+ * @return the s0
+ */
+ protected double getS0() {
+ return s0;
+ }
+
+ protected void setS0(double s0) {
+ this.s0 = s0;
+ }
+
+ /**
+ * @return the s1
+ */
+ protected Vector getS1() {
+ return s1;
+ }
+
+ protected void setS1(Vector s1) {
+ this.s1 = s1;
+ }
+
+ /**
+ * @return the s2
+ */
+ protected Vector getS2() {
+ return s2;
+ }
+
+ protected void setS2(Vector s2) {
+ this.s2 = s2;
+ }
+
+ @Override
+ public void observe(Model<VectorWritable> x) {
+ AbstractCluster cl = (AbstractCluster) x;
+ setS0(getS0() + cl.getS0());
+ setS1(getS1().plus(cl.getS1()));
+ setS2(getS2().plus(cl.getS2()));
+ }
+
+ @Override
+ public void observe(VectorWritable x) {
+ observe(x.get());
+ }
+
+ @Override
+ public void observe(VectorWritable x, double weight) {
+ observe(x.get(), weight);
+ }
+
+ public void observe(Vector x, double weight) {
+ if (weight == 1.0) {
+ observe(x);
+ } else {
+ setS0(getS0() + weight);
+ Vector weightedX = x.times(weight);
+ if (getS1() == null) {
+ setS1(weightedX);
+ } else {
+ getS1().assign(weightedX, Functions.PLUS);
+ }
+ Vector x2 = x.times(x).times(weight);
+ if (getS2() == null) {
+ setS2(x2);
+ } else {
+ getS2().assign(x2, Functions.PLUS);
+ }
+ }
+ }
+
+ public void observe(Vector x) {
+ setS0(getS0() + 1);
+ if (getS1() == null) {
+ setS1(x.clone());
+ } else {
+ getS1().assign(x, Functions.PLUS);
+ }
+ Vector x2 = x.times(x);
+ if (getS2() == null) {
+ setS2(x2);
+ } else {
+ getS2().assign(x2, Functions.PLUS);
+ }
+ }
+
+
+ @Override
+ public void computeParameters() {
+ if (getS0() == 0) {
+ return;
+ }
+ setNumObservations((long) getS0());
+ setTotalObservations(getTotalObservations() + getNumObservations());
+ setCenter(getS1().divide(getS0()));
+ // compute the component stds
+ if (getS0() > 1) {
+ setRadius(getS2().times(getS0()).minus(getS1().times(getS1())).assign(new SquareRootFunction()).divide(getS0()));
+ }
+ setS0(0);
+ setS1(center.like());
+ setS2(center.like());
+ }
+
+ @Override
+ public String asFormatString(String[] bindings) {
+ String fmtString = "";
+ try {
+ fmtString = jxn.writeValueAsString(asJson(bindings));
+ } catch (IOException e) {
+ log.error("Error writing JSON as String.", e);
+ }
+ return fmtString;
+ }
+
+ public Map<String,Object> asJson(String[] bindings) {
+ Map<String,Object> dict = new HashMap<>();
+ dict.put("identifier", getIdentifier());
+ dict.put("n", getNumObservations());
+ if (getCenter() != null) {
+ try {
+ dict.put("c", formatVectorAsJson(getCenter(), bindings));
+ } catch (IOException e) {
+ log.error("IOException: ", e);
+ }
+ }
+ if (getRadius() != null) {
+ try {
+ dict.put("r", formatVectorAsJson(getRadius(), bindings));
+ } catch (IOException e) {
+ log.error("IOException: ", e);
+ }
+ }
+ return dict;
+ }
+
+ public abstract String getIdentifier();
+
+ /**
+ * Compute the centroid by averaging the pointTotals
+ *
+ * @return the new centroid
+ */
+ public Vector computeCentroid() {
+ return getS0() == 0 ? getCenter() : getS1().divide(getS0());
+ }
+
+ /**
+ * Return a human-readable formatted string representation of the vector, not
+ * intended to be complete nor usable as an input/output representation
+ */
+ public static String formatVector(Vector v, String[] bindings) {
+ String fmtString = "";
+ try {
+ fmtString = jxn.writeValueAsString(formatVectorAsJson(v, bindings));
+ } catch (IOException e) {
+ log.error("Error writing JSON as String.", e);
+ }
+ return fmtString;
+ }
+
+ /**
+ * Create a List of HashMaps containing vector terms and weights
+ *
+ * @return List<Object>
+ */
+ public static List<Object> formatVectorAsJson(Vector v, String[] bindings) throws IOException {
+
+ boolean hasBindings = bindings != null;
+ boolean isSparse = !v.isDense() && v.getNumNondefaultElements() != v.size();
+
+ // we assume sequential access in the output
+ Vector provider = v.isSequentialAccess() ? v : new SequentialAccessSparseVector(v);
+
+ List<Object> terms = Lists.newLinkedList();
+ String term = "";
+
+ for (Element elem : provider.nonZeroes()) {
+
+ if (hasBindings && bindings.length >= elem.index() + 1 && bindings[elem.index()] != null) {
+ term = bindings[elem.index()];
+ } else if (hasBindings || isSparse) {
+ term = String.valueOf(elem.index());
+ }
+
+ Map<String, Object> term_entry = Maps.newHashMap();
+ double roundedWeight = (double) Math.round(elem.get() * 1000) / 1000;
+ if (hasBindings || isSparse) {
+ term_entry.put(term, roundedWeight);
+ terms.add(term_entry);
+ } else {
+ terms.add(roundedWeight);
+ }
+ }
+
+ return terms;
+ }
+
+ @Override
+ public boolean isConverged() {
+ // Convergence has no meaning yet, perhaps in subclasses
+ return false;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/Cluster.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/Cluster.java b/mr/src/main/java/org/apache/mahout/clustering/Cluster.java
new file mode 100644
index 0000000..07d6927
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/Cluster.java
@@ -0,0 +1,90 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.common.parameters.Parametered;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.util.Map;
+
+/**
+ * Implementations of this interface have a printable representation and certain
+ * attributes that are common across all clustering implementations
+ *
+ */
+public interface Cluster extends Model<VectorWritable>, Parametered {
+
+ // default directory for initial clusters to prime iterative clustering
+ // algorithms
+ String INITIAL_CLUSTERS_DIR = "clusters-0";
+
+ // default directory for output of clusters per iteration
+ String CLUSTERS_DIR = "clusters-";
+
+ // default suffix for output of clusters for final iteration
+ String FINAL_ITERATION_SUFFIX = "-final";
+
+ /**
+ * Get the id of the Cluster
+ *
+ * @return a unique integer
+ */
+ int getId();
+
+ /**
+ * Get the "center" of the Cluster as a Vector
+ *
+ * @return a Vector
+ */
+ Vector getCenter();
+
+ /**
+ * Get the "radius" of the Cluster as a Vector. Usually the radius is the
+ * standard deviation expressed as a Vector of size equal to the center. Some
+ * clusters may return zero values if not appropriate.
+ *
+ * @return aVector
+ */
+ Vector getRadius();
+
+ /**
+ * Produce a custom, human-friendly, printable representation of the Cluster.
+ *
+ * @param bindings
+ * an optional String[] containing labels used to format the primary
+ * Vector/s of this implementation.
+ * @return a String
+ */
+ String asFormatString(String[] bindings);
+
+ /**
+ * Produce a JSON representation of the Cluster.
+ *
+ * @param bindings
+ * an optional String[] containing labels used to format the primary
+ * Vector/s of this implementation.
+ * @return a Map
+ */
+ Map<String,Object> asJson(String[] bindings);
+
+ /**
+ * @return if the receiver has converged, or false if that has no meaning for
+ * the implementation
+ */
+ boolean isConverged();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java b/mr/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
new file mode 100644
index 0000000..421ffcf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
@@ -0,0 +1,305 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.util.List;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.apache.mahout.math.neighborhood.Searcher;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+public final class ClusteringUtils {
+ private ClusteringUtils() {
+ }
+
+ /**
+ * Computes the summaries for the distances in each cluster.
+ * @param datapoints iterable of datapoints.
+ * @param centroids iterable of Centroids.
+ * @return a list of OnlineSummarizers where the i-th element is the summarizer corresponding to the cluster whose
+ * index is i.
+ */
+ public static List<OnlineSummarizer> summarizeClusterDistances(Iterable<? extends Vector> datapoints,
+ Iterable<? extends Vector> centroids,
+ DistanceMeasure distanceMeasure) {
+ UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1);
+ searcher.addAll(centroids);
+ List<OnlineSummarizer> summarizers = Lists.newArrayList();
+ if (searcher.size() == 0) {
+ return summarizers;
+ }
+ for (int i = 0; i < searcher.size(); ++i) {
+ summarizers.add(new OnlineSummarizer());
+ }
+ for (Vector v : datapoints) {
+ Centroid closest = (Centroid)searcher.search(v, 1).get(0).getValue();
+ OnlineSummarizer summarizer = summarizers.get(closest.getIndex());
+ summarizer.add(distanceMeasure.distance(v, closest));
+ }
+ return summarizers;
+ }
+
+ /**
+ * Adds up the distances from each point to its closest cluster and returns the sum.
+ * @param datapoints iterable of datapoints.
+ * @param centroids iterable of Centroids.
+ * @return the total cost described above.
+ */
+ public static double totalClusterCost(Iterable<? extends Vector> datapoints, Iterable<? extends Vector> centroids) {
+ DistanceMeasure distanceMeasure = new EuclideanDistanceMeasure();
+ UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1);
+ searcher.addAll(centroids);
+ return totalClusterCost(datapoints, searcher);
+ }
+
+ /**
+ * Adds up the distances from each point to its closest cluster and returns the sum.
+ * @param datapoints iterable of datapoints.
+ * @param centroids searcher of Centroids.
+ * @return the total cost described above.
+ */
+ public static double totalClusterCost(Iterable<? extends Vector> datapoints, Searcher centroids) {
+ double totalCost = 0;
+ for (Vector vector : datapoints) {
+ totalCost += centroids.searchFirst(vector, false).getWeight();
+ }
+ return totalCost;
+ }
+
+ /**
+ * Estimates the distance cutoff. In StreamingKMeans, the distance between two vectors divided
+ * by this value is used as a probability threshold when deciding whether to form a new cluster
+ * or not.
+ * Small values (comparable to the minimum distance between two points) are preferred as they
+ * guarantee with high likelihood that all but very close points are put in separate clusters
+ * initially. The clusters themselves are actually collapsed periodically when their number goes
+ * over the maximum number of clusters and the distanceCutoff is increased.
+ * So, the returned value is only an initial estimate.
+ * @param data the datapoints whose distance is to be estimated.
+ * @param distanceMeasure the distance measure used to compute the distance between two points.
+ * @return the minimum distance between the first sampleLimit points
+ * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans#clusterInternal(Iterable, boolean)
+ */
+ public static double estimateDistanceCutoff(List<? extends Vector> data, DistanceMeasure distanceMeasure) {
+ BruteSearch searcher = new BruteSearch(distanceMeasure);
+ searcher.addAll(data);
+ double minDistance = Double.POSITIVE_INFINITY;
+ for (Vector vector : data) {
+ double closest = searcher.searchFirst(vector, true).getWeight();
+ if (minDistance > 0 && closest < minDistance) {
+ minDistance = closest;
+ }
+ searcher.add(vector);
+ }
+ return minDistance;
+ }
+
+ public static <T extends Vector> double estimateDistanceCutoff(
+ Iterable<T> data, DistanceMeasure distanceMeasure, int sampleLimit) {
+ return estimateDistanceCutoff(Lists.newArrayList(Iterables.limit(data, sampleLimit)), distanceMeasure);
+ }
+
+ /**
+ * Computes the Davies-Bouldin Index for a given clustering.
+ * See http://en.wikipedia.org/wiki/Clustering_algorithm#Internal_evaluation
+ * @param centroids list of centroids
+ * @param distanceMeasure distance measure for inter-cluster distances
+ * @param clusterDistanceSummaries summaries of the clusters; See summarizeClusterDistances
+ * @return the Davies-Bouldin Index
+ */
+ public static double daviesBouldinIndex(List<? extends Vector> centroids, DistanceMeasure distanceMeasure,
+ List<OnlineSummarizer> clusterDistanceSummaries) {
+ Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(),
+ "Number of centroids and cluster summaries differ.");
+ int n = centroids.size();
+ double totalDBIndex = 0;
+ // The inner loop shouldn't be reduced for j = i + 1 to n because the computation of the Davies-Bouldin
+ // index is not really symmetric.
+ // For a given cluster i, we look for a cluster j that maximizes the ratio of the sum of average distances
+ // from points in cluster i to its center and and points in cluster j to its center to the distance between
+ // cluster i and cluster j.
+ // The maximization is the key issue, as the cluster that maximizes this ratio might be j for i but is NOT
+ // NECESSARILY i for j.
+ for (int i = 0; i < n; ++i) {
+ double averageDistanceI = clusterDistanceSummaries.get(i).getMean();
+ double maxDBIndex = 0;
+ for (int j = 0; j < n; ++j) {
+ if (i != j) {
+ double dbIndex = (averageDistanceI + clusterDistanceSummaries.get(j).getMean())
+ / distanceMeasure.distance(centroids.get(i), centroids.get(j));
+ if (dbIndex > maxDBIndex) {
+ maxDBIndex = dbIndex;
+ }
+ }
+ }
+ totalDBIndex += maxDBIndex;
+ }
+ return totalDBIndex / n;
+ }
+
+ /**
+ * Computes the Dunn Index of a given clustering. See http://en.wikipedia.org/wiki/Dunn_index
+ * @param centroids list of centroids
+ * @param distanceMeasure distance measure to compute inter-centroid distance with
+ * @param clusterDistanceSummaries summaries of the clusters; See summarizeClusterDistances
+ * @return the Dunn Index
+ */
+ public static double dunnIndex(List<? extends Vector> centroids, DistanceMeasure distanceMeasure,
+ List<OnlineSummarizer> clusterDistanceSummaries) {
+ Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(),
+ "Number of centroids and cluster summaries differ.");
+ int n = centroids.size();
+ // Intra-cluster distances will come from the OnlineSummarizer, and will be the median distance (noting that
+ // the median for just one value is that value).
+ // A variety of metrics can be used for the intra-cluster distance including max distance between two points,
+ // mean distance, etc. Median distance was chosen as this is more robust to outliers and characterizes the
+ // distribution of distances (from a point to the center) better.
+ double maxIntraClusterDistance = 0;
+ for (OnlineSummarizer summarizer : clusterDistanceSummaries) {
+ if (summarizer.getCount() > 0) {
+ double intraClusterDistance;
+ if (summarizer.getCount() == 1) {
+ intraClusterDistance = summarizer.getMean();
+ } else {
+ intraClusterDistance = summarizer.getMedian();
+ }
+ if (maxIntraClusterDistance < intraClusterDistance) {
+ maxIntraClusterDistance = intraClusterDistance;
+ }
+ }
+ }
+ double minDunnIndex = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < n; ++i) {
+ // Distances are symmetric, so d(i, j) = d(j, i).
+ for (int j = i + 1; j < n; ++j) {
+ double dunnIndex = distanceMeasure.distance(centroids.get(i), centroids.get(j));
+ if (minDunnIndex > dunnIndex) {
+ minDunnIndex = dunnIndex;
+ }
+ }
+ }
+ return minDunnIndex / maxIntraClusterDistance;
+ }
+
+ public static double choose2(double n) {
+ return n * (n - 1) / 2;
+ }
+
+ /**
+ * Creates a confusion matrix by searching for the closest cluster of both the row clustering and column clustering
+ * of a point and adding its weight to that cell of the matrix.
+ * It doesn't matter which clustering is the row clustering and which is the column clustering. If they're
+ * interchanged, the resulting matrix is the transpose of the original one.
+ * @param rowCentroids clustering one
+ * @param columnCentroids clustering two
+ * @param datapoints datapoints whose closest cluster we need to find
+ * @param distanceMeasure distance measure to use
+ * @return the confusion matrix
+ */
+ public static Matrix getConfusionMatrix(List<? extends Vector> rowCentroids, List<? extends Vector> columnCentroids,
+ Iterable<? extends Vector> datapoints, DistanceMeasure distanceMeasure) {
+ Searcher rowSearcher = new BruteSearch(distanceMeasure);
+ rowSearcher.addAll(rowCentroids);
+ Searcher columnSearcher = new BruteSearch(distanceMeasure);
+ columnSearcher.addAll(columnCentroids);
+
+ int numRows = rowCentroids.size();
+ int numCols = columnCentroids.size();
+ Matrix confusionMatrix = new DenseMatrix(numRows, numCols);
+
+ for (Vector vector : datapoints) {
+ WeightedThing<Vector> closestRowCentroid = rowSearcher.search(vector, 1).get(0);
+ WeightedThing<Vector> closestColumnCentroid = columnSearcher.search(vector, 1).get(0);
+ int row = ((Centroid) closestRowCentroid.getValue()).getIndex();
+ int column = ((Centroid) closestColumnCentroid.getValue()).getIndex();
+ double vectorWeight;
+ if (vector instanceof WeightedVector) {
+ vectorWeight = ((WeightedVector) vector).getWeight();
+ } else {
+ vectorWeight = 1;
+ }
+ confusionMatrix.set(row, column, confusionMatrix.get(row, column) + vectorWeight);
+ }
+
+ return confusionMatrix;
+ }
+
+ /**
+ * Computes the Adjusted Rand Index for a given confusion matrix.
+ * @param confusionMatrix confusion matrix; not to be confused with the more restrictive ConfusionMatrix class
+ * @return the Adjusted Rand Index
+ */
+ public static double getAdjustedRandIndex(Matrix confusionMatrix) {
+ int numRows = confusionMatrix.numRows();
+ int numCols = confusionMatrix.numCols();
+ double rowChoiceSum = 0;
+ double columnChoiceSum = 0;
+ double totalChoiceSum = 0;
+ double total = 0;
+ for (int i = 0; i < numRows; ++i) {
+ double rowSum = 0;
+ for (int j = 0; j < numCols; ++j) {
+ rowSum += confusionMatrix.get(i, j);
+ totalChoiceSum += choose2(confusionMatrix.get(i, j));
+ }
+ total += rowSum;
+ rowChoiceSum += choose2(rowSum);
+ }
+ for (int j = 0; j < numCols; ++j) {
+ double columnSum = 0;
+ for (int i = 0; i < numRows; ++i) {
+ columnSum += confusionMatrix.get(i, j);
+ }
+ columnChoiceSum += choose2(columnSum);
+ }
+ double rowColumnChoiceSumDivTotal = rowChoiceSum * columnChoiceSum / choose2(total);
+ return (totalChoiceSum - rowColumnChoiceSumDivTotal)
+ / ((rowChoiceSum + columnChoiceSum) / 2 - rowColumnChoiceSumDivTotal);
+ }
+
+ /**
+ * Computes the total weight of the points in the given Vector iterable.
+ * @param data iterable of points
+ * @return total weight
+ */
+ public static double totalWeight(Iterable<? extends Vector> data) {
+ double sum = 0;
+ for (Vector row : data) {
+ Preconditions.checkNotNull(row);
+ if (row instanceof WeightedVector) {
+ sum += ((WeightedVector)row).getWeight();
+ } else {
+ sum++;
+ }
+ }
+ return sum;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
new file mode 100644
index 0000000..c25e039
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+
+public interface GaussianAccumulator {
+
+  /**
+   * Accumulates a single weighted observation.
+   *
+   * @param x the observed vector
+   * @param weight the weight of this observation (typically 1.0)
+   */
+  void observe(Vector x, double weight);
+
+  /**
+   * Derives the mean, variance and standard deviation from the observations accumulated so far.
+   * Implementations that keep these statistics current on every observation may do nothing here.
+   */
+  void compute();
+
+  /**
+   * @return the accumulated observation weight, which equals the number of observations when
+   *         every weight is 1.0
+   */
+  double getN();
+
+  /**
+   * @return the mean of the observations
+   */
+  Vector getMean();
+
+  /**
+   * @return the per-element variance of the observations
+   */
+  Vector getVariance();
+
+  /**
+   * @return the per-element standard deviation of the observations
+   */
+  Vector getStd();
+
+  /**
+   * @return the average of the elements of the standard-deviation vector
+   */
+  double getAverageStd();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/Model.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/Model.java b/mr/src/main/java/org/apache/mahout/clustering/Model.java
new file mode 100644
index 0000000..79dab30
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/Model.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * A model is a probability distribution over observed data points that allows the probability of
+ * any data point to be computed. Every model has a persistent representation, which is why this
+ * interface extends {@link Writable}.
+ *
+ * @param <O> the type of the observations described by this model
+ */
+public interface Model<O> extends Writable {
+
+  /**
+   * Returns the probability that the given observation is described by this model.
+   *
+   * @param x an observation
+   * @return the probability that {@code x} is in the receiver
+   */
+  double pdf(O x);
+
+  /**
+   * Observes the given observation, retaining information about it.
+   *
+   * @param x an observation
+   */
+  void observe(O x);
+
+  /**
+   * Observes the given observation with a weight, retaining information about it.
+   *
+   * @param x an observation
+   * @param weight a weighting factor applied to {@code x}
+   */
+  void observe(O x, double weight);
+
+  /**
+   * Observes another model, retaining information about its observations.
+   *
+   * @param x a {@code Model<O>} whose observations should be retained
+   */
+  void observe(Model<O> x);
+
+  /**
+   * Computes a new set of posterior parameters based upon the observations that have been made
+   * since this model was created.
+   */
+  void computeParameters();
+
+  /**
+   * Returns the number of observations this model has seen since its parameters were last
+   * computed.
+   *
+   * @return the observation count since the last {@link #computeParameters()}
+   */
+  long getNumObservations();
+
+  /**
+   * Returns the number of observations this model has seen over its lifetime.
+   *
+   * @return the lifetime observation count
+   */
+  long getTotalObservations();
+
+  /**
+   * @return a model sampled from this model's posterior
+   */
+  Model<VectorWritable> sampleFromPosterior();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java b/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
new file mode 100644
index 0000000..d77bf40
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+/**
+ * A model distribution allows models to be sampled from its prior distribution, and new models to
+ * be sampled from the posterior given a set of models that have made observations.
+ *
+ * @param <O> the type of the observations described by the sampled models
+ */
+public interface ModelDistribution<O> {
+
+  /**
+   * Samples models from the prior.
+   *
+   * @param howMany the number of models to return
+   * @return an array of models representing what is known a priori
+   */
+  Model<O>[] sampleFromPrior(int howMany);
+
+  /**
+   * Samples models from the posterior.
+   *
+   * @param posterior the models after observations have been made
+   * @return an array of models representing what is known a posteriori
+   */
+  Model<O>[] sampleFromPosterior(Model<O>[] posterior);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
new file mode 100644
index 0000000..b76e00f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.SquareRootFunction;
+
+/**
+ * An online Gaussian statistics accumulator using the weighted incremental algorithm that Knuth
+ * attributes to Welford, which is described as numerically stable.
+ * See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ */
+public class OnlineGaussianAccumulator implements GaussianAccumulator {
+
+  /** Sum of all observation weights seen so far. */
+  private double sumWeight;
+  /** Running weighted mean; null until the first observation. */
+  private Vector mean;
+  /** Running weighted sum of squared deviations (the "S" of the algorithm below). */
+  private Vector s;
+  /** s / (sumWeight - 1), refreshed after every observation. */
+  private Vector variance;
+
+  /*
+   * Weighted incremental algorithm, from the Wikipedia page cited above:
+   *
+   *   mean = 0; S = 0; sumweight = 0
+   *   for x, weight in dataWeightPairs:
+   *     temp = weight + sumweight
+   *     Q = x - mean
+   *     R = Q * weight / temp
+   *     S = S + sumweight * Q * R
+   *     mean = mean + R
+   *     sumweight = temp
+   *   Variance = S / (sumweight - 1)   # omit the -1 if the sample is the whole population
+   */
+  @Override
+  public void observe(Vector x, double weight) {
+    double newSumWeight = weight + sumWeight;
+    Vector deviation;
+    if (mean != null) {
+      deviation = x.minus(mean);
+    } else {
+      // First observation: the mean starts as x.like() and the deviation is a copy of x.
+      mean = x.like();
+      deviation = x.clone();
+    }
+    Vector meanShift = deviation.times(weight).divide(newSumWeight);
+    // sumWeight still holds the pre-update total here, as the algorithm requires.
+    Vector sIncrement = deviation.times(sumWeight).times(meanShift);
+    s = s == null ? sIncrement : s.plus(sIncrement);
+    mean = mean.plus(meanShift);
+    sumWeight = newSumWeight;
+    variance = s.divide(sumWeight - 1); // sample variance; drop the -1 for a population variance
+  }
+
+  @Override
+  public void compute() {
+    // The statistics are already current after every observe(); there is nothing to do.
+  }
+
+  @Override
+  public double getN() {
+    return sumWeight;
+  }
+
+  @Override
+  public Vector getMean() {
+    return mean;
+  }
+
+  @Override
+  public Vector getVariance() {
+    return variance;
+  }
+
+  @Override
+  public Vector getStd() {
+    return variance.clone().assign(new SquareRootFunction());
+  }
+
+  @Override
+  public double getAverageStd() {
+    if (sumWeight != 0.0) {
+      Vector std = getStd();
+      return std.zSum() / std.size();
+    }
+    return 0.0;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
new file mode 100644
index 0000000..138e830
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+
+/**
+ * An online Gaussian accumulator that uses a running power sums approach as reported
+ * on http://en.wikipedia.org/wiki/Standard_deviation
+ * Observing is cheap, but the approach suffers from overflow, underflow and roundoff error.
+ * The mean and standard deviation are only available after {@link #compute()} has been called.
+ */
+public class RunningSumsGaussianAccumulator implements GaussianAccumulator {
+
+  /** Total observation weight (zeroth power sum). */
+  private double s0;
+  /** Weighted sum of the observations (first power sum); null until the first observation. */
+  private Vector s1;
+  /** Weighted sum of the element-wise squared observations (second power sum); null until the first observation. */
+  private Vector s2;
+  /** Mean from the last compute() with a non-zero total weight; null before that. */
+  private Vector mean;
+  /** Population standard deviation from the last compute() with a non-zero total weight; null before that. */
+  private Vector std;
+
+  @Override
+  public double getN() {
+    return s0;
+  }
+
+  @Override
+  public Vector getMean() {
+    return mean;
+  }
+
+  @Override
+  public Vector getStd() {
+    return std;
+  }
+
+  @Override
+  public double getAverageStd() {
+    if (s0 == 0.0) {
+      return 0.0;
+    } else {
+      return std.zSum() / std.size();
+    }
+  }
+
+  @Override
+  public Vector getVariance() {
+    // std stays null until compute() has run with a non-zero total weight; report that as null
+    // (no statistics yet) rather than failing with a NullPointerException.
+    return std == null ? null : std.times(std);
+  }
+
+  @Override
+  public void observe(Vector x, double weight) {
+    s0 += weight;
+    Vector weightedX = x.times(weight);
+    if (s1 == null) {
+      s1 = weightedX;
+    } else {
+      s1.assign(weightedX, Functions.PLUS);
+    }
+    Vector x2 = x.times(x).times(weight);
+    if (s2 == null) {
+      s2 = x2;
+    } else {
+      s2.assign(x2, Functions.PLUS);
+    }
+  }
+
+  @Override
+  public void compute() {
+    if (s0 != 0.0) {
+      mean = s1.divide(s0);
+      // s0 * s2 - s1 * s1 is non-negative in exact arithmetic (for non-negative weights), but
+      // cancellation can leave tiny negative elements, e.g. when the observations are all equal.
+      // Clamp them to zero so the square root yields 0 instead of NaN.
+      Vector scaledVariance = s2.times(s0).minus(s1.times(s1))
+          .assign(Functions.bindArg2(Functions.MAX, 0.0));
+      std = scaledVariance.assign(new SquareRootFunction()).divide(s0);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java b/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java
new file mode 100644
index 0000000..ef43e1b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java
@@ -0,0 +1,136 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.commons.math3.distribution.NormalDistribution;
+import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+/**
+ * Static samplers for several probability distributions, all drawing from one shared
+ * {@link RandomWrapper}.
+ */
+public final class UncommonDistributions {
+
+  private static final RandomWrapper RANDOM = RandomUtils.getRandom();
+
+  private UncommonDistributions() {}
+
+  // =============== start of BSD licensed code. See LICENSE.txt
+  /**
+   * Returns a sample from the Gamma distribution with shape {@code k} and rate {@code lambda}.
+   * Uniformly fast for all k > 0. (Reference: Non-Uniform Random Variate Generation, Devroye
+   * http://cgm.cs.mcgill.ca/~luc/rnbookindex.html) Uses Cheng's rejection algorithm (GB) for
+   * k >= 1, rejection from the Weibull distribution for 0 < k < 1.
+   *
+   * @param k the shape parameter
+   * @param lambda the rate parameter; the unit-rate sample is divided by it
+   * @return a double sample
+   */
+  public static double rGamma(double k, double lambda) {
+    if (k >= 1.0) {
+      // Cheng's algorithm (GB)
+      double b = k - Math.log(4.0);
+      double c = k + Math.sqrt(2.0 * k - 1.0);
+      double lam = Math.sqrt(2.0 * k - 1.0);
+      double cheng = 1.0 + Math.log(4.5);
+      while (true) {
+        double u = RANDOM.nextDouble();
+        double v = RANDOM.nextDouble();
+        double y = 1.0 / lam * Math.log(v / (1.0 - v));
+        double x = k * Math.exp(y);
+        double z = u * v * v;
+        double r = b + c * y - x;
+        if (r >= 4.5 * z - cheng || r >= Math.log(z)) {
+          return x / lambda;
+        }
+      }
+    }
+    // Rejection from the Weibull distribution
+    double c = 1.0 / k;
+    double d = (1.0 - k) * Math.pow(k, k / (1.0 - k));
+    while (true) {
+      double u = RANDOM.nextDouble();
+      double v = RANDOM.nextDouble();
+      double z = -Math.log(u);
+      double e = -Math.log(v);
+      double x = Math.pow(z, c);
+      if (z + e >= d + x) {
+        return x / lambda;
+      }
+    }
+  }
+
+  // ============= end of BSD licensed code
+
+  /**
+   * Returns a random sample from a beta distribution with the given shapes, computed as
+   * {@code g1 / (g1 + g2)} where {@code g1} and {@code g2} are unit-rate Gamma samples with shapes
+   * {@code shape1} and {@code shape2}.
+   *
+   * @param shape1 the first shape parameter
+   * @param shape2 the second shape parameter
+   * @return a double sample
+   */
+  public static double rBeta(double shape1, double shape2) {
+    double gam1 = rGamma(shape1, 1.0);
+    double gam2 = rGamma(shape2, 1.0);
+    return gam1 / (gam1 + gam2);
+  }
+
+  /**
+   * Returns a random value from a normal distribution with the given mean and standard deviation.
+   *
+   * @param mean the mean of the distribution
+   * @param sd the standard deviation of the distribution
+   * @return a double sample
+   */
+  public static double rNorm(double mean, double sd) {
+    RealDistribution dist = new NormalDistribution(RANDOM.getRandomGenerator(),
+                                                   mean,
+                                                   sd,
+                                                   NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
+    return dist.sample();
+  }
+
+  /**
+   * Returns an integer sampled from the binomial distribution with {@code n} trials and success
+   * probability {@code p}. Takes time proportional to np + 1. (Reference: Non-Uniform Random Variate
+   * Generation, Devroye http://cgm.cs.mcgill.ca/~luc/rnbookindex.html) Second waiting-time algorithm.
+   *
+   * @param n the number of trials; must be non-negative
+   * @param p the success probability; values >= 1 yield {@code n}, values <= 0 yield 0
+   * @return the number of successes, between 0 and {@code n}
+   * @throws IllegalArgumentException if {@code n} is negative
+   */
+  public static int rBinomial(int n, double p) {
+    if (n < 0) {
+      // With n < 0 every waiting-time increment is negative, so the loop below would never end.
+      throw new IllegalArgumentException("n must be non-negative: " + n);
+    }
+    if (p >= 1.0) {
+      return n; // needed to avoid infinite loops and negative results
+    }
+    double q = -Math.log1p(-p);
+    double sum = 0.0;
+    int x = 0;
+    while (sum <= q) {
+      double e = -Math.log(RANDOM.nextDouble());
+      // When x reaches n this divides by zero, giving +Infinity, which ends the loop with x = n + 1.
+      sum += e / (n - x);
+      x++;
+    }
+    return x == 0 ? 0 : x - 1;
+  }
+
+}