You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/04 14:29:27 UTC
[25/53] [abbrv] [partial] mahout git commit: end of day 6-2-2018
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
new file mode 100644
index 0000000..0e7ee96
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
@@ -0,0 +1,417 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.commandline;
+
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.clustering.kernel.TriangularKernelProfile;
+
+
+public final class DefaultOptionCreator {
+
+  public static final String CLUSTERING_OPTION = "clustering";
+
+  public static final String CLUSTERS_IN_OPTION = "clusters";
+
+  public static final String CONVERGENCE_DELTA_OPTION = "convergenceDelta";
+
+  public static final String DISTANCE_MEASURE_OPTION = "distanceMeasure";
+
+  public static final String EMIT_MOST_LIKELY_OPTION = "emitMostLikely";
+
+  public static final String INPUT_OPTION = "input";
+
+  public static final String MAX_ITERATIONS_OPTION = "maxIter";
+
+  public static final String MAX_REDUCERS_OPTION = "maxRed";
+
+  public static final String METHOD_OPTION = "method";
+
+  public static final String NUM_CLUSTERS_OPTION = "numClusters";
+
+  public static final String OUTPUT_OPTION = "output";
+
+  public static final String OVERWRITE_OPTION = "overwrite";
+
+  public static final String T1_OPTION = "t1";
+
+  public static final String T2_OPTION = "t2";
+
+  public static final String T3_OPTION = "t3";
+
+  public static final String T4_OPTION = "t4";
+
+  public static final String OUTLIER_THRESHOLD = "outlierThreshold";
+
+  public static final String CLUSTER_FILTER_OPTION = "clusterFilter";
+
+  public static final String THRESHOLD_OPTION = "threshold";
+
+  public static final String SEQUENTIAL_METHOD = "sequential";
+
+  public static final String MAPREDUCE_METHOD = "mapreduce";
+
+  public static final String KERNEL_PROFILE_OPTION = "kernelProfile";
+
+  public static final String ANALYZER_NAME_OPTION = "analyzerName";
+
+  public static final String RANDOM_SEED = "randomSeed";
+
+  /** Static factory holder; never instantiated. */
+  private DefaultOptionCreator() {}
+
+  /**
+   * Starts an argument builder for an option value that must be given exactly once
+   * (minimum and maximum are both set to 1). Callers finish it with {@code create()}.
+   *
+   * @param name the argument name
+   * @return the partially configured {@link ArgumentBuilder}
+   */
+  private static ArgumentBuilder singleValue(String name) {
+    return new ArgumentBuilder().withName(name).withMinimum(1).withMaximum(1);
+  }
+
+  /**
+   * Same as {@link #singleValue(String)} but additionally registers a default value.
+   *
+   * @param name the argument name
+   * @param defaultValue the default registered on the argument
+   * @return the partially configured {@link ArgumentBuilder}
+   */
+  private static ArgumentBuilder singleValue(String name, String defaultValue) {
+    return singleValue(name).withDefault(defaultValue);
+  }
+
+  /**
+   * Returns a default command line option for help. Used by all clustering jobs
+   * and many others
+   */
+  public static Option helpOption() {
+    return new DefaultOptionBuilder()
+        .withLongName("help")
+        .withDescription("Print out help")
+        .withShortName("h")
+        .create();
+  }
+
+  /**
+   * Returns a default command line option for input directory specification.
+   * Used by all clustering jobs plus others
+   */
+  public static DefaultOptionBuilder inputOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(INPUT_OPTION)
+        .withRequired(false)
+        .withShortName("i")
+        .withArgument(singleValue(INPUT_OPTION).create())
+        .withDescription("Path to job input directory.");
+  }
+
+  /**
+   * Returns a default command line option for clusters input directory
+   * specification. Used by FuzzyKmeans, Kmeans
+   */
+  public static DefaultOptionBuilder clustersInOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(CLUSTERS_IN_OPTION)
+        .withRequired(true)
+        .withArgument(singleValue(CLUSTERS_IN_OPTION).create())
+        .withDescription(
+            "The path to the initial clusters directory. Must be a SequenceFile of some type of Cluster")
+        .withShortName("c");
+  }
+
+  /**
+   * Returns a default command line option for output directory specification.
+   * Used by all clustering jobs plus others
+   */
+  public static DefaultOptionBuilder outputOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(OUTPUT_OPTION)
+        .withRequired(false)
+        .withShortName("o")
+        .withArgument(singleValue(OUTPUT_OPTION).create())
+        .withDescription("The directory pathname for output.");
+  }
+
+  /**
+   * Returns a default command line option for output directory overwriting.
+   * Used by all clustering jobs. The option is a flag: it takes no argument.
+   */
+  public static DefaultOptionBuilder overwriteOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(OVERWRITE_OPTION)
+        .withRequired(false)
+        .withDescription(
+            "If present, overwrite the output directory before running job")
+        .withShortName("ow");
+  }
+
+  /**
+   * Returns a default command line option for specification of distance measure
+   * class to use. Used by Canopy, FuzzyKmeans, Kmeans, MeanShift
+   */
+  public static DefaultOptionBuilder distanceMeasureOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(DISTANCE_MEASURE_OPTION)
+        .withRequired(false)
+        .withShortName("dm")
+        .withArgument(
+            singleValue(DISTANCE_MEASURE_OPTION, SquaredEuclideanDistanceMeasure.class.getName())
+                .create())
+        .withDescription(
+            "The classname of the DistanceMeasure. Default is SquaredEuclidean");
+  }
+
+  /**
+   * Returns a default command line option for specification of sequential or
+   * parallel operation. Used by Canopy, FuzzyKmeans, Kmeans, MeanShift,
+   * Dirichlet
+   */
+  public static DefaultOptionBuilder methodOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(METHOD_OPTION)
+        .withRequired(false)
+        .withShortName("xm")
+        .withArgument(singleValue(METHOD_OPTION, MAPREDUCE_METHOD).create())
+        .withDescription(
+            "The execution method to use: sequential or mapreduce. Default is mapreduce");
+  }
+
+  /**
+   * Returns a default command line option for specification of T1. Used by
+   * Canopy, MeanShift
+   */
+  public static DefaultOptionBuilder t1Option() {
+    return new DefaultOptionBuilder()
+        .withLongName(T1_OPTION)
+        .withRequired(true)
+        .withArgument(singleValue(T1_OPTION).create())
+        .withDescription("T1 threshold value")
+        .withShortName(T1_OPTION);
+  }
+
+  /**
+   * Returns a default command line option for specification of T2. Used by
+   * Canopy, MeanShift
+   */
+  public static DefaultOptionBuilder t2Option() {
+    return new DefaultOptionBuilder()
+        .withLongName(T2_OPTION)
+        .withRequired(true)
+        .withArgument(singleValue(T2_OPTION).create())
+        .withDescription("T2 threshold value")
+        .withShortName(T2_OPTION);
+  }
+
+  /**
+   * Returns a default command line option for specification of T3 (Reducer T1).
+   * Used by Canopy
+   */
+  public static DefaultOptionBuilder t3Option() {
+    return new DefaultOptionBuilder()
+        .withLongName(T3_OPTION)
+        .withRequired(false)
+        .withArgument(singleValue(T3_OPTION).create())
+        .withDescription("T3 (Reducer T1) threshold value")
+        .withShortName(T3_OPTION);
+  }
+
+  /**
+   * Returns a default command line option for specification of T4 (Reducer T2).
+   * Used by Canopy
+   */
+  public static DefaultOptionBuilder t4Option() {
+    return new DefaultOptionBuilder()
+        .withLongName(T4_OPTION)
+        .withRequired(false)
+        .withArgument(singleValue(T4_OPTION).create())
+        .withDescription("T4 (Reducer T2) threshold value")
+        .withShortName(T4_OPTION);
+  }
+
+  /**
+   * Two short names are registered on purpose-preserving grounds: "cf" and the
+   * full option name, exactly as the builder was configured before.
+   *
+   * @return a DefaultOptionBuilder for the clusterFilter option
+   */
+  public static DefaultOptionBuilder clusterFilterOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(CLUSTER_FILTER_OPTION)
+        .withShortName("cf")
+        .withRequired(false)
+        .withArgument(singleValue(CLUSTER_FILTER_OPTION).create())
+        .withDescription("Cluster filter suppresses small canopies from mapper")
+        .withShortName(CLUSTER_FILTER_OPTION);
+  }
+
+  /**
+   * Returns a default command line option for specification of max number of
+   * iterations. Used by Dirichlet, FuzzyKmeans, Kmeans, LDA
+   */
+  public static DefaultOptionBuilder maxIterationsOption() {
+    // default value used by LDA which overrides withRequired(false)
+    return new DefaultOptionBuilder()
+        .withLongName(MAX_ITERATIONS_OPTION)
+        .withRequired(true)
+        .withShortName("x")
+        .withArgument(singleValue(MAX_ITERATIONS_OPTION, "-1").create())
+        .withDescription("The maximum number of iterations.");
+  }
+
+  /**
+   * Returns a default command line option for specification of numbers of
+   * clusters to create. Used by Dirichlet, FuzzyKmeans, Kmeans
+   */
+  public static DefaultOptionBuilder numClustersOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(NUM_CLUSTERS_OPTION)
+        .withRequired(false)
+        .withArgument(singleValue("k").create())
+        .withDescription("The number of clusters to create")
+        .withShortName("k");
+  }
+
+  /**
+   * Returns a default command line option for supplying a random number generator seed.
+   * Unlike the other valued options, the argument is built without explicit
+   * minimum/maximum occurrence settings.
+   */
+  public static DefaultOptionBuilder useSetRandomSeedOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(RANDOM_SEED)
+        .withRequired(false)
+        .withArgument(new ArgumentBuilder().withName(RANDOM_SEED).create())
+        .withDescription("Seed to initialize Random Number Generator with")
+        .withShortName("rs");
+  }
+
+  /**
+   * Returns a default command line option for convergence delta specification.
+   * Used by FuzzyKmeans, Kmeans, MeanShift
+   */
+  public static DefaultOptionBuilder convergenceOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(CONVERGENCE_DELTA_OPTION)
+        .withRequired(false)
+        .withShortName("cd")
+        .withArgument(singleValue(CONVERGENCE_DELTA_OPTION, "0.5").create())
+        .withDescription("The convergence delta value. Default is 0.5");
+  }
+
+  /**
+   * Returns a default command line option for specifying the max number of
+   * reducers. Used by Dirichlet, FuzzyKmeans, Kmeans and LDA
+   *
+   * @deprecated
+   */
+  @Deprecated
+  public static DefaultOptionBuilder numReducersOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(MAX_REDUCERS_OPTION)
+        .withRequired(false)
+        .withShortName("r")
+        .withArgument(singleValue(MAX_REDUCERS_OPTION, "2").create())
+        .withDescription("The number of reduce tasks. Defaults to 2");
+  }
+
+  /**
+   * Returns a default command line option for clustering specification. Used by
+   * all clustering except LDA. The option is a flag: it takes no argument.
+   */
+  public static DefaultOptionBuilder clusteringOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(CLUSTERING_OPTION)
+        .withRequired(false)
+        .withDescription(
+            "If present, run clustering after the iterations have taken place")
+        .withShortName("cl");
+  }
+
+  /**
+   * Returns a default command line option for specifying a Lucene analyzer class.
+   * The argument defaults to the class name of {@link StandardAnalyzer}.
+   *
+   * @return {@link DefaultOptionBuilder}
+   */
+  public static DefaultOptionBuilder analyzerOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(ANALYZER_NAME_OPTION)
+        .withRequired(false)
+        .withDescription("If present, the name of a Lucene analyzer class to use")
+        .withArgument(singleValue(ANALYZER_NAME_OPTION, StandardAnalyzer.class.getName()).create())
+        .withShortName("an");
+  }
+
+  /**
+   * Returns a default command line option for specifying the emitMostLikely
+   * flag. Used by Dirichlet and FuzzyKmeans
+   */
+  public static DefaultOptionBuilder emitMostLikelyOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(EMIT_MOST_LIKELY_OPTION)
+        .withRequired(false)
+        .withShortName("e")
+        .withArgument(singleValue(EMIT_MOST_LIKELY_OPTION, "true").create())
+        .withDescription(
+            "True if clustering should emit the most likely point only, "
+                + "false for threshold clustering. Default is true");
+  }
+
+  /**
+   * Returns a default command line option for specifying the clustering
+   * threshold value. Used by Dirichlet and FuzzyKmeans
+   */
+  public static DefaultOptionBuilder thresholdOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(THRESHOLD_OPTION)
+        .withRequired(false)
+        .withShortName("t")
+        .withArgument(singleValue(THRESHOLD_OPTION, "0").create())
+        .withDescription(
+            "The pdf threshold used for cluster determination. Default is 0");
+  }
+
+  /**
+   * Returns a default command line option for specifying the kernel profile class name.
+   * The argument defaults to the class name of {@link TriangularKernelProfile}.
+   */
+  public static DefaultOptionBuilder kernelProfileOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(KERNEL_PROFILE_OPTION)
+        .withRequired(false)
+        .withShortName("kp")
+        .withArgument(
+            singleValue(KERNEL_PROFILE_OPTION, TriangularKernelProfile.class.getName()).create())
+        .withDescription(
+            "The classname of the IKernelProfile. Default is TriangularKernelProfile");
+  }
+
+  /**
+   * Returns a default command line option for specification of OUTLIER THRESHOLD value. Used for
+   * Cluster Classification.
+   */
+  public static DefaultOptionBuilder outlierThresholdOption() {
+    return new DefaultOptionBuilder()
+        .withLongName(OUTLIER_THRESHOLD)
+        .withRequired(false)
+        .withArgument(singleValue(OUTLIER_THRESHOLD).create())
+        .withDescription("Outlier threshold value")
+        .withShortName(OUTLIER_THRESHOLD);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java
new file mode 100644
index 0000000..61aa9a5
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * This class implements a "Chebyshev distance" metric by finding the maximum difference
+ * between each coordinate. Also 'chessboard distance' due to the moves a king can make.
+ */
+public class ChebyshevDistanceMeasure implements DistanceMeasure {
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ if (v1.size() != v2.size()) {
+ throw new CardinalityException(v1.size(), v2.size());
+ }
+ return v1.aggregate(v2, Functions.MAX_ABS, Functions.MINUS);
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
new file mode 100644
index 0000000..37265eb
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
@@ -0,0 +1,119 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+
+/**
+ * This class implements a cosine distance metric by dividing the dot product of two vectors by the product of their
+ * lengths. That gives the cosine of the angle between the two vectors. To convert this to a usable distance,
+ * 1-cos(angle) is what is actually returned.
+ */
+public class CosineDistanceMeasure implements DistanceMeasure {
+
+  @Override
+  public void configure(Configuration job) {
+    // nothing to do
+  }
+
+  @Override
+  public Collection<Parameter<?>> getParameters() {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public void createParameters(String prefix, Configuration jobConf) {
+    // nothing to do
+  }
+
+  /**
+   * Cosine distance between two points given as plain arrays.
+   * <p>
+   * The loop runs over {@code p1.length} indices and reads the same indices from {@code p2},
+   * so {@code p2} must be at least as long as {@code p1}; extra trailing entries of
+   * {@code p2} are not read.
+   *
+   * @param p1 first point
+   * @param p2 second point
+   * @return {@code 1 - cos(angle)}, or 0 when both the dot product and the length product are 0
+   */
+  public static double distance(double[] p1, double[] p2) {
+    double dotProduct = 0.0;
+    double lengthSquaredp1 = 0.0;
+    double lengthSquaredp2 = 0.0;
+    for (int i = 0; i < p1.length; i++) {
+      lengthSquaredp1 += p1[i] * p1[i];
+      lengthSquaredp2 += p2[i] * p2[i];
+      dotProduct += p1[i] * p2[i];
+    }
+    return toDistance(dotProduct, Math.sqrt(lengthSquaredp1) * Math.sqrt(lengthSquaredp2));
+  }
+
+  /**
+   * @param v1 first vector
+   * @param v2 second vector, which must have the same size as {@code v1}
+   * @return {@code 1 - cos(angle)}, or 0 when both the dot product and the length product are 0
+   * @throws CardinalityException if the two vectors differ in size
+   */
+  @Override
+  public double distance(Vector v1, Vector v2) {
+    if (v1.size() != v2.size()) {
+      throw new CardinalityException(v1.size(), v2.size());
+    }
+    double lengthSquaredv1 = v1.getLengthSquared();
+    double lengthSquaredv2 = v2.getLengthSquared();
+
+    double dotProduct = v2.dot(v1);
+    return toDistance(dotProduct, Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2));
+  }
+
+  /**
+   * Variant that takes the centroid's squared length from the caller instead of
+   * computing it from {@code centroid}.
+   *
+   * @param centroidLengthSquare squared length used for the centroid in the denominator
+   * @param centroid centroid vector
+   * @param v vector to compare with the centroid
+   * @return {@code 1 - cos(angle)}, or 0 when both the dot product and the length product are 0
+   */
+  @Override
+  public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+    double lengthSquaredv = v.getLengthSquared();
+
+    double dotProduct = v.dot(centroid);
+    return toDistance(dotProduct, Math.sqrt(centroidLengthSquare) * Math.sqrt(lengthSquaredv));
+  }
+
+  /**
+   * Shared tail of all three distance computations: turns a dot product and the product
+   * of the two vector lengths into the returned distance.
+   */
+  private static double toDistance(double dotProduct, double denominator) {
+    // correct for floating-point rounding errors
+    if (denominator < dotProduct) {
+      denominator = dotProduct;
+    }
+
+    // correct for zero-vector corner case
+    if (denominator == 0 && dotProduct == 0) {
+      return 0;
+    }
+
+    return 1.0 - dotProduct / denominator;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java
new file mode 100644
index 0000000..696e79c
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.parameters.Parametered;
+import org.apache.mahout.math.Vector;
+
+/** This interface is used for objects which can determine a distance metric between two points. */
+public interface DistanceMeasure extends Parametered {
+
+ /**
+ * Returns the distance metric applied to the arguments.
+ *
+ * @param v1
+ * a Vector defining a multidimensional point in some feature space
+ * @param v2
+ * a Vector defining a multidimensional point in some feature space
+ * @return a scalar double giving the distance between v1 and v2
+ */
+ double distance(Vector v1, Vector v2);
+
+ /**
+ * Optimized version of distance metric for sparse vectors. This distance computation requires operations
+ * proportional to the number of non-zero elements in the vector instead of the cardinality of the vector.
+ *
+ * @param centroidLengthSquare
+ * Square of the length of centroid
+ * @param centroid
+ * Centroid vector
+ * @param v
+ * Vector whose distance to the centroid is measured
+ * @return a scalar double giving the distance between centroid and v
+ */
+ double distance(double centroidLengthSquare, Vector centroid, Vector v);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java
new file mode 100644
index 0000000..665678d
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * This class implements a Euclidean distance metric by summing the square root of the squared differences
+ * between each coordinate.
+ * <p/>
+ * If you don't care about the true distance and only need the values for comparison, then the base class,
+ * {@link SquaredEuclideanDistanceMeasure}, will be faster since it doesn't do the actual square root of the
+ * squared differences.
+ */
+public class EuclideanDistanceMeasure extends SquaredEuclideanDistanceMeasure {
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ return Math.sqrt(super.distance(v1, v2));
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return Math.sqrt(super.distance(centroidLengthSquare, centroid, v));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java
new file mode 100644
index 0000000..17ee714
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java
@@ -0,0 +1,197 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.io.DataInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.parameters.ClassParameter;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.common.parameters.PathParameter;
+import org.apache.mahout.math.Algebra;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.SingularValueDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+//See http://en.wikipedia.org/wiki/Mahalanobis_distance for details
+public class MahalanobisDistanceMeasure implements DistanceMeasure {
+
+ private Matrix inverseCovarianceMatrix;
+ private Vector meanVector;
+
+ private ClassParameter vectorClass;
+ private ClassParameter matrixClass;
+ private List<Parameter<?>> parameters;
+ private Parameter<Path> inverseCovarianceFile;
+ private Parameter<Path> meanVectorFile;
+
+ /*public MahalanobisDistanceMeasure(Vector meanVector,Matrix inputMatrix, boolean inversionNeeded)
+ {
+ this.meanVector=meanVector;
+ if (inversionNeeded)
+ setCovarianceMatrix(inputMatrix);
+ else
+ setInverseCovarianceMatrix(inputMatrix);
+ }*/
+
+ @Override
+ public void configure(Configuration jobConf) {
+ if (parameters == null) {
+ ParameteredGeneralizations.configureParameters(this, jobConf);
+ }
+ try {
+ if (inverseCovarianceFile.get() != null) {
+ FileSystem fs = FileSystem.get(inverseCovarianceFile.get().toUri(), jobConf);
+ MatrixWritable inverseCovarianceMatrix =
+ ClassUtils.instantiateAs((Class<? extends MatrixWritable>) matrixClass.get(), MatrixWritable.class);
+ if (!fs.exists(inverseCovarianceFile.get())) {
+ throw new FileNotFoundException(inverseCovarianceFile.get().toString());
+ }
+ try (DataInputStream in = fs.open(inverseCovarianceFile.get())){
+ inverseCovarianceMatrix.readFields(in);
+ }
+ this.inverseCovarianceMatrix = inverseCovarianceMatrix.get();
+ Preconditions.checkArgument(this.inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized");
+ }
+
+ if (meanVectorFile.get() != null) {
+ FileSystem fs = FileSystem.get(meanVectorFile.get().toUri(), jobConf);
+ VectorWritable meanVector =
+ ClassUtils.instantiateAs((Class<? extends VectorWritable>) vectorClass.get(), VectorWritable.class);
+ if (!fs.exists(meanVectorFile.get())) {
+ throw new FileNotFoundException(meanVectorFile.get().toString());
+ }
+ try (DataInputStream in = fs.open(meanVectorFile.get())){
+ meanVector.readFields(in);
+ }
+ this.meanVector = meanVector.get();
+ Preconditions.checkArgument(this.meanVector != null, "meanVector not initialized");
+ }
+
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return parameters;
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ parameters = new ArrayList<>();
+ inverseCovarianceFile = new PathParameter(prefix, "inverseCovarianceFile", jobConf, null,
+ "Path on DFS to a file containing the inverse covariance matrix.");
+ parameters.add(inverseCovarianceFile);
+
+ matrixClass = new ClassParameter(prefix, "maxtrixClass", jobConf, DenseMatrix.class,
+ "Class<Matix> file specified in parameter inverseCovarianceFile has been serialized with.");
+ parameters.add(matrixClass);
+
+ meanVectorFile = new PathParameter(prefix, "meanVectorFile", jobConf, null,
+ "Path on DFS to a file containing the mean Vector.");
+ parameters.add(meanVectorFile);
+
+ vectorClass = new ClassParameter(prefix, "vectorClass", jobConf, DenseVector.class,
+ "Class file specified in parameter meanVectorFile has been serialized with.");
+ parameters.add(vectorClass);
+ }
+
+ /**
+ * @param v The vector to compute the distance to
+ * @return Mahalanobis distance of a multivariate vector
+ */
+ public double distance(Vector v) {
+ return Math.sqrt(v.minus(meanVector).dot(Algebra.mult(inverseCovarianceMatrix, v.minus(meanVector))));
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ if (v1.size() != v2.size()) {
+ throw new CardinalityException(v1.size(), v2.size());
+ }
+ return Math.sqrt(v1.minus(v2).dot(Algebra.mult(inverseCovarianceMatrix, v1.minus(v2))));
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+ public void setInverseCovarianceMatrix(Matrix inverseCovarianceMatrix) {
+ Preconditions.checkArgument(inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized");
+ this.inverseCovarianceMatrix = inverseCovarianceMatrix;
+ }
+
+
+ /**
+ * Computes the inverse covariance from the input covariance matrix given in input.
+ *
+ * @param m A covariance matrix.
+ * @throws IllegalArgumentException if <tt>eigen values equal to 0 found</tt>.
+ */
+ public void setCovarianceMatrix(Matrix m) {
+ if (m.numRows() != m.numCols()) {
+ throw new CardinalityException(m.numRows(), m.numCols());
+ }
+ // See http://www.mlahanas.de/Math/svd.htm for details,
+ // which specifically details the case of covariance matrix inversion
+ // Complexity: O(min(nm2,mn2))
+ SingularValueDecomposition svd = new SingularValueDecomposition(m);
+ Matrix sInv = svd.getS();
+ // Inverse Diagonal Elems
+ for (int i = 0; i < sInv.numRows(); i++) {
+ double diagElem = sInv.get(i, i);
+ if (diagElem > 0.0) {
+ sInv.set(i, i, 1 / diagElem);
+ } else {
+ throw new IllegalStateException("Eigen Value equals to 0 found.");
+ }
+ }
+ inverseCovarianceMatrix = svd.getU().times(sInv.times(svd.getU().transpose()));
+ Preconditions.checkArgument(inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized");
+ }
+
+ public Matrix getInverseCovarianceMatrix() {
+ return inverseCovarianceMatrix;
+ }
+
+ public void setMeanVector(Vector meanVector) {
+ Preconditions.checkArgument(meanVector != null, "meanVector not initialized");
+ this.meanVector = meanVector;
+ }
+
+ public Vector getMeanVector() {
+ return meanVector;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java
new file mode 100644
index 0000000..5c32fcf
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * This class implements a "manhattan distance" metric by summing the absolute values of the difference
+ * between each coordinate
+ */
+public class ManhattanDistanceMeasure implements DistanceMeasure {
+
+ public static double distance(double[] p1, double[] p2) {
+ double result = 0.0;
+ for (int i = 0; i < p1.length; i++) {
+ result += Math.abs(p2[i] - p1[i]);
+ }
+ return result;
+ }
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ if (v1.size() != v2.size()) {
+ throw new CardinalityException(v1.size(), v2.size());
+ }
+ return v1.aggregate(v2, Functions.PLUS, Functions.MINUS_ABS);
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java
new file mode 100644
index 0000000..c3a48cb
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.DoubleParameter;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Implement Minkowski distance, a real-valued generalization of the
+ * integral L(n) distances: Manhattan = L1, Euclidean = L2.
+ * For high numbers of dimensions, very high exponents give more useful distances.
+ *
+ * Note: Math.pow is clever about integer-valued doubles.
+ **/
+public class MinkowskiDistanceMeasure implements DistanceMeasure {
+
+ private static final double EXPONENT = 3.0;
+
+ private List<Parameter<?>> parameters;
+ private double exponent = EXPONENT;
+
+ public MinkowskiDistanceMeasure() {
+ }
+
+ public MinkowskiDistanceMeasure(double exponent) {
+ this.exponent = exponent;
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration conf) {
+ parameters = new ArrayList<>();
+ Parameter<?> param =
+ new DoubleParameter(prefix, "exponent", conf, EXPONENT, "Exponent for Fractional Lagrange distance");
+ parameters.add(param);
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return parameters;
+ }
+
+ @Override
+ public void configure(Configuration jobConf) {
+ if (parameters == null) {
+ ParameteredGeneralizations.configureParameters(this, jobConf);
+ }
+ }
+
+ public double getExponent() {
+ return exponent;
+ }
+
+ public void setExponent(double exponent) {
+ this.exponent = exponent;
+ }
+
+ /**
+ * Math.pow is clever about integer-valued doubles
+ */
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ return Math.pow(v1.aggregate(v2, Functions.PLUS, Functions.minusAbsPow(exponent)), 1.0 / exponent);
+ }
+
+ // TODO: how?
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO - can this use centroidLengthSquare somehow?
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java
new file mode 100644
index 0000000..66da121
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Like {@link EuclideanDistanceMeasure} but it does not take the square root.
+ * <p/>
+ * Thus, it is not actually the Euclidean Distance, but it is saves on computation when you only need the
+ * distance for comparison and don't care about the actual value as a distance.
+ */
+public class SquaredEuclideanDistanceMeasure implements DistanceMeasure {
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ return v2.getDistanceSquared(v1);
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return centroidLengthSquare - 2 * v.dot(centroid) + v.getLengthSquared();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java
new file mode 100644
index 0000000..cfeb119
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Tanimoto coefficient implementation.
+ *
+ * http://en.wikipedia.org/wiki/Jaccard_index
+ */
+public class TanimotoDistanceMeasure extends WeightedDistanceMeasure {
+
+ /**
+ * Calculates the distance between two vectors.
+ *
+ * The coefficient (a measure of similarity) is: T(a, b) = a.b / (|a|^2 + |b|^2 - a.b)
+ *
+ * The distance d(a,b) = 1 - T(a,b)
+ *
+ * @return 0 for perfect match, > 0 for greater distance
+ */
+ @Override
+ public double distance(Vector a, Vector b) {
+ double ab;
+ double denominator;
+ if (getWeights() != null) {
+ ab = a.times(b).aggregate(getWeights(), Functions.PLUS, Functions.MULT);
+ denominator = a.aggregate(getWeights(), Functions.PLUS, Functions.MULT_SQUARE_LEFT)
+ + b.aggregate(getWeights(), Functions.PLUS, Functions.MULT_SQUARE_LEFT)
+ - ab;
+ } else {
+ ab = b.dot(a); // b is SequentialAccess
+ denominator = a.getLengthSquared() + b.getLengthSquared() - ab;
+ }
+
+ if (denominator < ab) { // correct for fp round-off: distance >= 0
+ denominator = ab;
+ }
+ if (denominator > 0) {
+ // denominator == 0 only when dot(a,a) == dot(b,b) == dot(a,b) == 0
+ return 1.0 - ab / denominator;
+ } else {
+ return 0.0;
+ }
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java
new file mode 100644
index 0000000..1acbe86
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.io.DataInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.parameters.ClassParameter;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.common.parameters.PathParameter;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+ /**
+ * Abstract implementation of DistanceMeasure with support for weights.
+ * The weight vector can be set directly with {@link #setWeights(Vector)} or read from the file named by the
+ * "weightsFile" parameter during {@link #configure(Configuration)}.
+ */
+ public abstract class WeightedDistanceMeasure implements DistanceMeasure {
+
+ private List<Parameter<?>> parameters;
+ private Parameter<Path> weightsFile;
+ private ClassParameter vectorClass;
+ private Vector weights;
+
+ /**
+ * Creates the "weightsFile" and "vectorClass" parameters and registers both in the parameter list.
+ * The weights file has no default; the vector class defaults to {@link DenseVector}.
+ */
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ parameters = new ArrayList<>();
+ weightsFile = new PathParameter(prefix, "weightsFile", jobConf, null,
+ "Path on DFS to a file containing the weights.");
+ parameters.add(weightsFile);
+ vectorClass = new ClassParameter(prefix, "vectorClass", jobConf, DenseVector.class,
+ "Class<Vector> file specified in parameter weightsFile has been serialized with.");
+ parameters.add(vectorClass);
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return parameters;
+ }
+
+ /**
+ * Configures this measure and, when a weights file is set, loads the weight vector from it.
+ * A missing file and any other I/O failure are reported as an {@link IllegalStateException} wrapping the
+ * underlying {@link IOException}.
+ */
+ @Override
+ public void configure(Configuration jobConf) {
+ if (parameters == null) {
+ // presumably this ends up calling createParameters(), which is the only place weightsFile and
+ // vectorClass are assigned -- TODO confirm
+ ParameteredGeneralizations.configureParameters(this, jobConf);
+ }
+ try {
+ if (weightsFile.get() != null) {
+ FileSystem fs = FileSystem.get(weightsFile.get().toUri(), jobConf);
+ // NOTE(review): the default vectorClass is DenseVector.class, yet the value is cast to a
+ // VectorWritable subclass here -- confirm the default actually works with instantiateAs.
+ // The local "weights" shadows the field of the same name until the assignment below.
+ VectorWritable weights =
+ ClassUtils.instantiateAs((Class<? extends VectorWritable>) vectorClass.get(), VectorWritable.class);
+ if (!fs.exists(weightsFile.get())) {
+ // FileNotFoundException is an IOException, so it is wrapped by the catch below
+ throw new FileNotFoundException(weightsFile.get().toString());
+ }
+ try (DataInputStream in = fs.open(weightsFile.get())){
+ weights.readFields(in);
+ }
+ this.weights = weights.get();
+ }
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ /** @return the weight vector currently held by this measure. */
+ public Vector getWeights() {
+ return weights;
+ }
+
+ /** Replaces the weight vector held by this measure. */
+ public void setWeights(Vector weights) {
+ this.weights = weights;
+ }
+
+ }
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java
new file mode 100644
index 0000000..4c78d9f
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * This class implements a Euclidean distance metric by summing the square root of the squared differences
+ * between each coordinate, optionally adding weights.
+ */
+public class WeightedEuclideanDistanceMeasure extends WeightedDistanceMeasure {
+
+ @Override
+ public double distance(Vector p1, Vector p2) {
+ double result = 0;
+ Vector res = p2.minus(p1);
+ Vector theWeights = getWeights();
+ if (theWeights == null) {
+ for (Element elt : res.nonZeroes()) {
+ result += elt.get() * elt.get();
+ }
+ } else {
+ for (Element elt : res.nonZeroes()) {
+ result += elt.get() * elt.get() * theWeights.get(elt.index());
+ }
+ }
+ return Math.sqrt(result);
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java
new file mode 100644
index 0000000..2c280e2
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * This class implements a "Manhattan distance" metric by summing the absolute values of the difference
+ * between each coordinate, optionally with weights.
+ */
+public class WeightedManhattanDistanceMeasure extends WeightedDistanceMeasure {
+
+ @Override
+ public double distance(Vector p1, Vector p2) {
+ double result = 0;
+
+ Vector res = p2.minus(p1);
+ if (getWeights() == null) {
+ for (Element elt : res.nonZeroes()) {
+ result += Math.abs(elt.get());
+ }
+
+ } else {
+ for (Element elt : res.nonZeroes()) {
+ result += Math.abs(elt.get() * getWeights().get(elt.index()));
+ }
+ }
+
+ return result;
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java
new file mode 100644
index 0000000..73cc821
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Iterator;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+
+/**
+ * An iterator that copies the values in an underlying iterator by finding an appropriate copy constructor.
+ */
+public final class CopyConstructorIterator<T> extends ForwardingIterator<T> {
+
+ private final Iterator<T> delegate;
+ private Constructor<T> constructor;
+
+ public CopyConstructorIterator(Iterator<? extends T> copyFrom) {
+ this.delegate = Iterators.transform(
+ copyFrom,
+ new Function<T,T>() {
+ @Override
+ public T apply(T from) {
+ if (constructor == null) {
+ Class<T> elementClass = (Class<T>) from.getClass();
+ try {
+ constructor = elementClass.getConstructor(elementClass);
+ } catch (NoSuchMethodException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+ try {
+ return constructor.newInstance(from);
+ } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+ });
+ }
+
+ @Override
+ protected Iterator<T> delegate() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java
new file mode 100644
index 0000000..658c1f1
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import com.google.common.collect.AbstractIterator;
+
+/**
+ * Iterates over the integers from 0 through {@code to-1}.
+ */
+public final class CountingIterator extends AbstractIterator<Integer> {
+
+ private int count;
+ private final int to;
+
+ public CountingIterator(int to) {
+ this.to = to;
+ }
+
+ @Override
+ protected Integer computeNext() {
+ if (count < to) {
+ return count++;
+ } else {
+ return endOfData();
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java
new file mode 100644
index 0000000..cfc18d6
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.Charset;
+import java.util.Iterator;
+
+import com.google.common.base.Charsets;
+
+/**
+ * Iterable representing the lines of a text file. It can produce an {@link Iterator} over those lines. This
+ * assumes the text file's lines are delimited in a manner consistent with how {@link java.io.BufferedReader}
+ * defines lines.
+ *
+ */
+public final class FileLineIterable implements Iterable<String> {
+
+ private final InputStream is;
+ private final Charset encoding;
+ private final boolean skipFirstLine;
+ private final String origFilename;
+
+ /** Creates a {@link FileLineIterable} over a given file, assuming a UTF-8 encoding. */
+ public FileLineIterable(File file) throws IOException {
+ this(file, Charsets.UTF_8, false);
+ }
+
+ /** Creates a {@link FileLineIterable} over a given file, assuming a UTF-8 encoding. */
+ public FileLineIterable(File file, boolean skipFirstLine) throws IOException {
+ this(file, Charsets.UTF_8, skipFirstLine);
+ }
+
+ /** Creates a {@link FileLineIterable} over a given file, using the given encoding. */
+ public FileLineIterable(File file, Charset encoding, boolean skipFirstLine) throws IOException {
+ this(FileLineIterator.getFileInputStream(file), encoding, skipFirstLine);
+ }
+
+ public FileLineIterable(InputStream is) {
+ this(is, Charsets.UTF_8, false);
+ }
+
+ public FileLineIterable(InputStream is, boolean skipFirstLine) {
+ this(is, Charsets.UTF_8, skipFirstLine);
+ }
+
+ public FileLineIterable(InputStream is, Charset encoding, boolean skipFirstLine) {
+ this.is = is;
+ this.encoding = encoding;
+ this.skipFirstLine = skipFirstLine;
+ this.origFilename = "";
+ }
+
+ public FileLineIterable(InputStream is, Charset encoding, boolean skipFirstLine, String filename) {
+ this.is = is;
+ this.encoding = encoding;
+ this.skipFirstLine = skipFirstLine;
+ this.origFilename = filename;
+ }
+
+
+ @Override
+ public Iterator<String> iterator() {
+ try {
+ return new FileLineIterator(is, encoding, skipFirstLine, this.origFilename);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java
new file mode 100644
index 0000000..b7cc51e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.io.BufferedReader;
+import java.io.Closeable;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.Charset;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.ZipInputStream;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Iterates over the lines of a text file. This assumes the text file's lines are delimited in a manner
+ * consistent with how {@link BufferedReader} defines lines.
+ * <p/>
+ * This class will uncompress files that end in .zip or .gz accordingly, too.
+ */
+public final class FileLineIterator extends AbstractIterator<String> implements SkippingIterator<String>, Closeable {
+
+ private final BufferedReader reader;
+
+ private static final Logger log = LoggerFactory.getLogger(FileLineIterator.class);
+
+ /**
+ * Creates a {@link FileLineIterator} over a given file, assuming a UTF-8 encoding.
+ *
+ * @throws java.io.FileNotFoundException if the file does not exist
+ * @throws IOException
+ * if the file cannot be read
+ */
+ public FileLineIterator(File file) throws IOException {
+ this(file, Charsets.UTF_8, false);
+ }
+
+ /**
+ * Creates a {@link FileLineIterator} over a given file, assuming a UTF-8 encoding.
+ *
+ * @throws java.io.FileNotFoundException if the file does not exist
+ * @throws IOException if the file cannot be read
+ */
+ public FileLineIterator(File file, boolean skipFirstLine) throws IOException {
+ this(file, Charsets.UTF_8, skipFirstLine);
+ }
+
+ /**
+ * Creates a {@link FileLineIterator} over a given file, using the given encoding.
+ *
+ * @throws java.io.FileNotFoundException if the file does not exist
+ * @throws IOException if the file cannot be read
+ */
+ public FileLineIterator(File file, Charset encoding, boolean skipFirstLine) throws IOException {
+ this(getFileInputStream(file), encoding, skipFirstLine);
+ }
+
+ public FileLineIterator(InputStream is) throws IOException {
+ this(is, Charsets.UTF_8, false);
+ }
+
+ public FileLineIterator(InputStream is, boolean skipFirstLine) throws IOException {
+ this(is, Charsets.UTF_8, skipFirstLine);
+ }
+
+ public FileLineIterator(InputStream is, Charset encoding, boolean skipFirstLine) throws IOException {
+ reader = new BufferedReader(new InputStreamReader(is, encoding));
+ if (skipFirstLine) {
+ reader.readLine();
+ }
+ }
+
+ public FileLineIterator(InputStream is, Charset encoding, boolean skipFirstLine, String filename)
+ throws IOException {
+ InputStream compressedInputStream;
+
+ if ("gz".equalsIgnoreCase(Files.getFileExtension(filename.toLowerCase()))) {
+ compressedInputStream = new GZIPInputStream(is);
+ } else if ("zip".equalsIgnoreCase(Files.getFileExtension(filename.toLowerCase()))) {
+ compressedInputStream = new ZipInputStream(is);
+ } else {
+ compressedInputStream = is;
+ }
+
+ reader = new BufferedReader(new InputStreamReader(compressedInputStream, encoding));
+ if (skipFirstLine) {
+ reader.readLine();
+ }
+ }
+
+ static InputStream getFileInputStream(File file) throws IOException {
+ InputStream is = new FileInputStream(file);
+ String name = file.getName();
+ if ("gz".equalsIgnoreCase(Files.getFileExtension(name.toLowerCase()))) {
+ return new GZIPInputStream(is);
+ } else if ("zip".equalsIgnoreCase(Files.getFileExtension(name.toLowerCase()))) {
+ return new ZipInputStream(is);
+ } else {
+ return is;
+ }
+ }
+
+ @Override
+ protected String computeNext() {
+ String line;
+ try {
+ line = reader.readLine();
+ } catch (IOException ioe) {
+ try {
+ close();
+ } catch (IOException e) {
+ log.error(e.getMessage(), e);
+ }
+ throw new IllegalStateException(ioe);
+ }
+ return line == null ? endOfData() : line;
+ }
+
+
+ @Override
+ public void skip(int n) {
+ try {
+ for (int i = 0; i < n; i++) {
+ if (reader.readLine() == null) {
+ break;
+ }
+ }
+ } catch (IOException ioe) {
+ try {
+ close();
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ endOfData();
+ Closeables.close(reader, true);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java
new file mode 100644
index 0000000..1905654
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Sample a fixed number of elements from an Iterator. The results can appear in any order.
+ */
+public final class FixedSizeSamplingIterator<T> extends ForwardingIterator<T> {
+
+ private final Iterator<T> delegate;
+
+ public FixedSizeSamplingIterator(int size, Iterator<T> source) {
+ List<T> buf = Lists.newArrayListWithCapacity(size);
+ int sofar = 0;
+ Random random = RandomUtils.getRandom();
+ while (source.hasNext()) {
+ T v = source.next();
+ sofar++;
+ if (buf.size() < size) {
+ buf.add(v);
+ } else {
+ int position = random.nextInt(sofar);
+ if (position < buf.size()) {
+ buf.set(position, v);
+ }
+ }
+ }
+ delegate = buf.iterator();
+ }
+
+ @Override
+ protected Iterator<T> delegate() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java
new file mode 100644
index 0000000..425b44b
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+
+/**
+ * Wraps an {@link Iterable} whose {@link Iterable#iterator()} returns only some subset of the elements that
+ * it would, as determined by a iterator rate parameter.
+ */
+public final class SamplingIterable<T> implements Iterable<T> {
+
+ private final Iterable<? extends T> delegate;
+ private final double samplingRate;
+
+ public SamplingIterable(Iterable<? extends T> delegate, double samplingRate) {
+ this.delegate = delegate;
+ this.samplingRate = samplingRate;
+ }
+
+ @Override
+ public Iterator<T> iterator() {
+ return new SamplingIterator<>(delegate.iterator(), samplingRate);
+ }
+
+ public static <T> Iterable<T> maybeWrapIterable(Iterable<T> delegate, double samplingRate) {
+ return samplingRate >= 1.0 ? delegate : new SamplingIterable<>(delegate, samplingRate);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
new file mode 100644
index 0000000..2ba46fd
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import org.apache.commons.math3.distribution.PascalDistribution;
+import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+/**
+ * Wraps an {@link Iterator} and returns only some subset of the elements that it would, as determined by a
+ * iterator rate parameter.
+ */
+public final class SamplingIterator<T> extends AbstractIterator<T> {
+
+ private final PascalDistribution geometricDistribution;
+ private final Iterator<? extends T> delegate;
+
+ public SamplingIterator(Iterator<? extends T> delegate, double samplingRate) {
+ this(RandomUtils.getRandom(), delegate, samplingRate);
+ }
+
+ public SamplingIterator(RandomWrapper random, Iterator<? extends T> delegate, double samplingRate) {
+ Preconditions.checkNotNull(delegate);
+ Preconditions.checkArgument(samplingRate > 0.0 && samplingRate <= 1.0,
+ "Must be: 0.0 < samplingRate <= 1.0. But samplingRate = " + samplingRate);
+ // Geometric distribution is special case of negative binomial (aka Pascal) with r=1:
+ geometricDistribution = new PascalDistribution(random.getRandomGenerator(), 1, samplingRate);
+ this.delegate = delegate;
+ }
+
+ @Override
+ protected T computeNext() {
+ int toSkip = geometricDistribution.sample();
+ if (delegate instanceof SkippingIterator<?>) {
+ SkippingIterator<? extends T> skippingDelegate = (SkippingIterator<? extends T>) delegate;
+ skippingDelegate.skip(toSkip);
+ if (skippingDelegate.hasNext()) {
+ return skippingDelegate.next();
+ }
+ } else {
+ for (int i = 0; i < toSkip && delegate.hasNext(); i++) {
+ delegate.next();
+ }
+ if (delegate.hasNext()) {
+ return delegate.next();
+ }
+ }
+ return endOfData();
+ }
+
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java
new file mode 100644
index 0000000..c4ddf7b
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Sample a fixed number of elements from an Iterator. The results will appear in the original order at some
+ * cost in time and memory relative to a FixedSizeSampler.
+ */
+public class StableFixedSizeSamplingIterator<T> extends ForwardingIterator<T> {
+
+ private final Iterator<T> delegate;
+
+ public StableFixedSizeSamplingIterator(int size, Iterator<T> source) {
+ List<Pair<Integer,T>> buf = Lists.newArrayListWithCapacity(size);
+ int sofar = 0;
+ Random random = RandomUtils.getRandom();
+ while (source.hasNext()) {
+ T v = source.next();
+ sofar++;
+ if (buf.size() < size) {
+ buf.add(new Pair<>(sofar, v));
+ } else {
+ int position = random.nextInt(sofar);
+ if (position < buf.size()) {
+ buf.set(position, new Pair<>(sofar, v));
+ }
+ }
+ }
+
+ Collections.sort(buf);
+ delegate = Iterators.transform(buf.iterator(),
+ new Function<Pair<Integer,T>,T>() {
+ @Override
+ public T apply(Pair<Integer,T> from) {
+ return from.getSecond();
+ }
+ });
+ }
+
+ @Override
+ protected Iterator<T> delegate() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java
new file mode 100644
index 0000000..73b841e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.regex.Pattern;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+import org.apache.mahout.common.Pair;
+
+public class StringRecordIterator extends ForwardingIterator<Pair<List<String>,Long>> {
+
+ private static final Long ONE = 1L;
+
+ private final Pattern splitter;
+ private final Iterator<Pair<List<String>,Long>> delegate;
+
+ public StringRecordIterator(Iterable<String> stringIterator, String pattern) {
+ this.splitter = Pattern.compile(pattern);
+ delegate = Iterators.transform(
+ stringIterator.iterator(),
+ new Function<String,Pair<List<String>,Long>>() {
+ @Override
+ public Pair<List<String>,Long> apply(String from) {
+ String[] items = splitter.split(from);
+ return new Pair<>(Arrays.asList(items), ONE);
+ }
+ });
+ }
+
+ @Override
+ protected Iterator<Pair<List<String>,Long>> delegate() {
+ return delegate;
+ }
+
+}