You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:54:44 UTC
[16/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java
new file mode 100644
index 0000000..1490761
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.recommender;
+
+/**
+ * <p>
+ * A {@link Rescorer} simply assigns a new "score" to a thing like an ID of an item or user which a
+ * {@link Recommender} is considering returning as a top recommendation. It may be used to arbitrarily re-rank
+ * the results according to application-specific logic before returning recommendations. For example, an
+ * application may want to boost the score of items in a certain category just for one request.
+ * </p>
+ *
+ * <p>
+ * A {@link Rescorer} can also exclude a thing from consideration entirely by returning {@code true} from
+ * {@link #isFiltered(Object)}.
+ * </p>
+ */
+public interface Rescorer<T> {
+
+ /**
+ * Computes a replacement score for the given thing.
+ *
+ * @param thing thing to rescore
+ * @param originalScore score originally assigned to {@code thing}
+ * @return modified score, or {@link Double#NaN} to indicate that this should be excluded entirely
+ */
+ double rescore(T thing, double originalScore);
+
+ /**
+ * Decides whether the given thing is excluded from consideration entirely.
+ *
+ * @param thing
+ * the thing to filter
+ * @return {@code true} to exclude, {@code false} otherwise
+ */
+ boolean isFiltered(T thing);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java
new file mode 100644
index 0000000..b48593a
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.common.LongPair;
+
+/**
+ * <p>
+ * Interface implemented by "user-based" recommenders.
+ * </p>
+ */
+public interface UserBasedRecommender extends Recommender {
+
+ /**
+ * Finds the users most similar to the given user.
+ *
+ * @param userID ID of user for which to find most similar other users
+ * @param howMany desired number of most similar users to find
+ * @return IDs of users most similar to the given user
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ long[] mostSimilarUserIDs(long userID, int howMany) throws TasteException;
+
+ /**
+ * Finds the users most similar to the given user, with similarities adjusted by a {@link Rescorer}.
+ *
+ * @param userID ID of user for which to find most similar other users
+ * @param howMany desired number of most similar users to find
+ * @param rescorer
+ * {@link Rescorer} which can adjust user-user similarity estimates used to determine most similar
+ * users
+ * @return IDs of users most similar to the given user
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ long[] mostSimilarUserIDs(long userID, int howMany, Rescorer<LongPair> rescorer) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java
new file mode 100644
index 0000000..814610b
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Implementations of this interface define a notion of similarity between two items. Implementations should
+ * return values in the range -1.0 to 1.0, with 1.0 representing perfect similarity.
+ * </p>
+ *
+ * @see UserSimilarity
+ */
+public interface ItemSimilarity extends Refreshable {
+
+ /**
+ * <p>
+ * Returns the degree of similarity of two items, based on the preferences that users have expressed for
+ * the items.
+ * </p>
+ *
+ * @param itemID1 first item ID
+ * @param itemID2 second item ID
+ * @return similarity between the items, in [-1,1], or {@link Double#NaN} if similarity is unknown
+ * @throws org.apache.mahout.cf.taste.common.NoSuchItemException
+ * if either item is known to be non-existent in the data
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ double itemSimilarity(long itemID1, long itemID2) throws TasteException;
+
+ /**
+ * <p>A bulk-get version of {@link #itemSimilarity(long, long)}.</p>
+ *
+ * @param itemID1 first item ID
+ * @param itemID2s second item IDs to compute similarity with
+ * @return similarities between {@code itemID1} and each of {@code itemID2s}
+ * @throws org.apache.mahout.cf.taste.common.NoSuchItemException
+ * if any item is known to be non-existent in the data
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException;
+
+ /**
+ * @return IDs of all items similar to {@code itemID}, in no particular order
+ */
+ long[] allSimilarItemIDs(long itemID) throws TasteException;
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java
new file mode 100644
index 0000000..76bb328
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Implementations of this interface compute an inferred preference for a user and an item that the user has
+ * not expressed any preference for. This might be an average of other preference scores from that user, for
+ * example. This technique is sometimes called "default voting".
+ * </p>
+ */
+public interface PreferenceInferrer extends Refreshable {
+
+ /**
+ * <p>
+ * Infers the given user's preference value for an item.
+ * </p>
+ *
+ * @param userID
+ * ID of user to infer preference for
+ * @param itemID
+ * item ID to infer preference for
+ * @return inferred preference value of the user for the item
+ * @throws TasteException
+ * if an error occurs while inferring
+ */
+ float inferPreference(long userID, long itemID) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java
new file mode 100644
index 0000000..bd53c51
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Implementations of this interface define a notion of similarity between two users. Implementations should
+ * return values in the range -1.0 to 1.0, with 1.0 representing perfect similarity.
+ * </p>
+ *
+ * @see ItemSimilarity
+ */
+public interface UserSimilarity extends Refreshable {
+
+ /**
+ * <p>
+ * Returns the degree of similarity of two users, based on their preferences.
+ * </p>
+ *
+ * @param userID1 first user ID
+ * @param userID2 second user ID
+ * @return similarity between the users, in [-1,1], or {@link Double#NaN} if similarity is unknown
+ * @throws org.apache.mahout.cf.taste.common.NoSuchUserException
+ * if either user is known to be non-existent in the data
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ double userSimilarity(long userID1, long userID2) throws TasteException;
+
+ // TODO: consider adding a bulk userSimilarities() analogous to ItemSimilarity.itemSimilarities()
+
+ /**
+ * <p>
+ * Attaches a {@link PreferenceInferrer} to the {@link UserSimilarity} implementation.
+ * </p>
+ *
+ * @param inferrer the {@link PreferenceInferrer} to attach
+ */
+ void setPreferenceInferrer(PreferenceInferrer inferrer);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java
new file mode 100644
index 0000000..b934d0c
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+
+import java.io.IOException;
+
+public abstract class BatchItemSimilarities {
+
+ private final ItemBasedRecommender recommender;
+ private final int similarItemsPerItem;
+
+ /**
+ * @param recommender recommender to use
+ * @param similarItemsPerItem number of similar items to compute per item
+ */
+ protected BatchItemSimilarities(ItemBasedRecommender recommender, int similarItemsPerItem) {
+ this.recommender = recommender;
+ this.similarItemsPerItem = similarItemsPerItem;
+ }
+
+ protected ItemBasedRecommender getRecommender() {
+ return recommender;
+ }
+
+ protected int getSimilarItemsPerItem() {
+ return similarItemsPerItem;
+ }
+
+ /**
+ * @param degreeOfParallelism number of threads to use for the computation
+ * @param maxDurationInHours maximum duration of the computation, in hours
+ * @param writer {@link SimilarItemsWriter} used to persist the results
+ * @return the number of similarities precomputed
+ * @throws IOException if an I/O error occurs (for example while persisting results through {@code writer})
+ * @throws RuntimeException if the computation takes longer than maxDurationInHours
+ */
+ public abstract int computeItemSimilarities(int degreeOfParallelism, int maxDurationInHours,
+ SimilarItemsWriter writer) throws IOException;
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java
new file mode 100644
index 0000000..5d40051
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import com.google.common.primitives.Doubles;
+
+import java.util.Comparator;
+
+/**
+ * Modeling similarity towards another item
+ */
+public class SimilarItem {
+ /** Orders {@link SimilarItem}s by ascending similarity value. */
+ public static final Comparator<SimilarItem> COMPARE_BY_SIMILARITY = new Comparator<SimilarItem>() {
+ @Override
+ public int compare(SimilarItem s1, SimilarItem s2) {
+ return Doubles.compare(s1.similarity, s2.similarity);
+ }
+ };
+
+ private long itemID;
+ private double similarity;
+
+ public SimilarItem(long itemID, double similarity) {
+ set(itemID, similarity); // NOTE(review): set() is public and non-final, so a subclass override would run here
+ }
+ /** Overwrites both the item ID and the similarity held by this instance. */
+ public void set(long itemID, double similarity) {
+ this.itemID = itemID;
+ this.similarity = similarity;
+ }
+
+ public long getItemID() {
+ return itemID;
+ }
+
+ public double getSimilarity() {
+ return similarity;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java
new file mode 100644
index 0000000..057e996
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import com.google.common.collect.UnmodifiableIterator;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * Compact representation of all similar items for an item
+ */
+public class SimilarItems {
+
+ private final long itemID;
+ private final long[] similarItemIDs; // parallel to similarities: entry n of both arrays describes one similar item
+ private final double[] similarities;
+ /** Copies the item IDs and values of {@code similarItems} into two parallel primitive arrays. */
+ public SimilarItems(long itemID, List<RecommendedItem> similarItems) {
+ this.itemID = itemID;
+
+ int numSimilarItems = similarItems.size();
+ similarItemIDs = new long[numSimilarItems];
+ similarities = new double[numSimilarItems];
+
+ for (int n = 0; n < numSimilarItems; n++) {
+ similarItemIDs[n] = similarItems.get(n).getItemID();
+ similarities[n] = similarItems.get(n).getValue();
+ }
+ }
+
+ public long getItemID() {
+ return itemID;
+ }
+
+ public int numSimilarItems() {
+ return similarItemIDs.length;
+ }
+ /** Returns an {@link Iterable} whose iterators create a new {@link SimilarItem} for every element returned. */
+ public Iterable<SimilarItem> getSimilarItems() {
+ return new Iterable<SimilarItem>() {
+ @Override
+ public Iterator<SimilarItem> iterator() {
+ return new SimilarItemsIterator();
+ }
+ };
+ }
+
+ private class SimilarItemsIterator extends UnmodifiableIterator<SimilarItem> {
+
+ private int index = -1; // array position of the element most recently returned by next(); -1 before the first call
+
+ @Override
+ public boolean hasNext() {
+ return index < (similarItemIDs.length - 1);
+ }
+
+ @Override
+ public SimilarItem next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+ index++;
+ return new SimilarItem(similarItemIDs[index], similarities[index]);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java
new file mode 100644
index 0000000..35d6bfe
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+/**
+ * Used to persist the results of a batch item similarity computation
+ * conducted with a {@link BatchItemSimilarities} implementation
+ */
+public interface SimilarItemsWriter extends Closeable {
+ /** NOTE(review): presumably prepares the writer and must precede any {@link #add} call — confirm. */
+ void open() throws IOException;
+ /** Persists the similar items computed for a single item. */
+ void add(SimilarItems similarItems) throws IOException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
new file mode 100644
index 0000000..efd233f
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
@@ -0,0 +1,248 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Defines the interface for classifiers that take a vector as input. This is
+ * implemented as an abstract class so that it can implement a number of handy
+ * convenience methods related to classification of vectors.
+ *
+ * <p>
+ * A classifier takes an input vector and calculates the scores (usually
+ * probabilities) that the input vector belongs to one of {@code n}
+ * categories. In {@code AbstractVectorClassifier} each category is denoted
+ * by an integer {@code c} between {@code 0} and {@code n-1}
+ * (inclusive).
+ *
+ * <p>
+ * New users should start by looking at {@link #classifyFull} (not {@link #classify}).
+ *
+ */
+public abstract class AbstractVectorClassifier {
+
+ /** Minimum allowable log likelihood value. */
+ public static final double MIN_LOG_LIKELIHOOD = -100.0;
+
+ /**
+ * Returns the number of categories that a target variable can be assigned to.
+ * A vector classifier will encode its output as an integer from
+ * {@code 0} to {@code numCategories()-1} (inclusive).
+ *
+ * @return The number of categories.
+ */
+ public abstract int numCategories();
+
+ /**
+ * Compute and return a vector containing {@code n-1} scores, where
+ * {@code n} is equal to {@code numCategories()}, given an input
+ * vector {@code instance}. Higher scores indicate that the input vector
+ * is more likely to belong to that category. The categories are denoted by
+ * the integers {@code 0} through {@code n-1} (inclusive), and the
+ * scores in the returned vector correspond to categories 1 through
+ * {@code n-1} (leaving out category 0). It is assumed that the score for
+ * category 0 is one minus the sum of the scores in the returned vector.
+ *
+ * @param instance A feature vector to be classified.
+ * @return A vector of probabilities in 1 of {@code n-1} encoding.
+ */
+ public abstract Vector classify(Vector instance);
+
+ /**
+ * Compute and return a vector of scores before applying the inverse link
+ * function. For logistic regression and other generalized linear models, this
+ * is just the linear part of the classification.
+ *
+ * <p>
+ * The implementation of this method provided by {@code AbstractVectorClassifier} throws an
+ * {@link UnsupportedOperationException}. Your subclass must explicitly override this method to support
+ * this operation.
+ *
+ * @param features A feature vector to be classified.
+ * @return A vector of scores. If transformed by the link function, these will become probabilities.
+ */
+ public Vector classifyNoLink(Vector features) {
+ throw new UnsupportedOperationException(this.getClass().getName()
+ + " doesn't support classification without a link");
+ }
+
+ /**
+ * Classifies a vector in the special case of a binary classifier where
+ * {@link #classify(Vector)} would return a vector with only one element. As
+ * such, using this method can avoid the allocation of a vector.
+ *
+ * @param instance The feature vector to be classified.
+ * @return The score for category 1.
+ *
+ * @see #classify(Vector)
+ */
+ public abstract double classifyScalar(Vector instance);
+
+ /**
+ * Computes and returns a vector containing {@code n} scores, where
+ * {@code n} is {@code numCategories()}, given an input vector
+ * {@code instance}. Higher scores indicate that the input vector is more
+ * likely to belong to the corresponding category. The categories are denoted
+ * by the integers {@code 0} through {@code n-1} (inclusive).
+ *
+ * <p>
+ * Using this method it is possible to classify an input vector, for example,
+ * by selecting the category with the largest score. If
+ * {@code classifier} is an instance of
+ * {@code AbstractVectorClassifier} and {@code input} is a
+ * {@code Vector} of features describing an element to be classified,
+ * then the following code could be used to classify {@code input}.
+ * <pre>{@code
+ * Vector scores = classifier.classifyFull(input);
+ * int assignedCategory = scores.maxValueIndex();
+ * }</pre> Here {@code assignedCategory} is the index of the category
+ * with the maximum score.
+ *
+ * <p>
+ * If an {@code n-1} encoding is acceptable, and allocation performance
+ * is an issue, then the {@link #classify(Vector)} method is probably better
+ * to use.
+ *
+ * @see #classify(Vector)
+ * @see #classifyFull(Vector r, Vector instance)
+ *
+ * @param instance A vector of features to be classified.
+ * @return A vector of probabilities, one for each category.
+ */
+ public Vector classifyFull(Vector instance) {
+ return classifyFull(new DenseVector(numCategories()), instance);
+ }
+
+ /**
+ * Computes and returns a vector containing {@code n} scores, where
+ * {@code n} is {@code numCategories()}, given an input vector
+ * {@code instance}. Higher scores indicate that the input vector is more
+ * likely to belong to the corresponding category. The categories are denoted
+ * by the integers {@code 0} through {@code n-1} (inclusive). Unlike
+ * {@link #classifyFull(Vector)}, this method stores the scores in a caller-supplied,
+ * previously allocated {@code Vector r}. Entries {@code 0} through {@code n-1} of
+ * {@code r} are overwritten, so stale values from an earlier call do not affect the result.
+ *
+ * <p>
+ * Using this method it is possible to classify an input vector, for example,
+ * by selecting the category with the largest score. If
+ * {@code classifier} is an instance of
+ * {@code AbstractVectorClassifier}, {@code result} is a non-null
+ * {@code Vector}, and {@code input} is a {@code Vector} of
+ * features describing an element to be classified, then the following code
+ * could be used to classify {@code input}.
+ * <pre>{@code
+ * Vector scores = classifier.classifyFull(result, input); // Notice that scores == result
+ * int assignedCategory = scores.maxValueIndex();
+ * }</pre> Here {@code assignedCategory} is the index of the category
+ * with the maximum score.
+ *
+ * @param r Where to put the results; must have room for {@code n} scores.
+ * @param instance A vector of features to be classified.
+ * @return A vector of scores/probabilities, one for each category.
+ */
+ public Vector classifyFull(Vector r, Vector instance) {
+ Vector scores = r.viewPart(1, numCategories() - 1).assign(classify(instance));
+ r.setQuick(0, 1.0 - scores.zSum()); // sum only categories 1..n-1 so a stale r[0] cannot skew category 0
+ return r;
+ }
+
+
+ /**
+ * Returns n-1 probabilities, one for each of the categories 1 through
+ * {@code n-1}, for each row of a matrix, where {@code n} is equal
+ * to {@code numCategories()}. The probability of the missing 0-th
+ * category is 1 - rowSum(this result).
+ *
+ * @param data The matrix whose rows are the input vectors to classify
+ * @return A matrix of scores, one row per row of the input matrix, one column per category except category 0.
+ */
+ public Matrix classify(Matrix data) {
+ Matrix r = new DenseMatrix(data.numRows(), numCategories() - 1);
+ for (int row = 0; row < data.numRows(); row++) {
+ r.assignRow(row, classify(data.viewRow(row)));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a matrix where the rows of the matrix each contain {@code n} probabilities, one for each category.
+ *
+ * @param data The matrix whose rows are the input vectors to classify
+ * @return A matrix of scores, one row per row of the input matrix, one column for each category.
+ */
+ public Matrix classifyFull(Matrix data) {
+ Matrix r = new DenseMatrix(data.numRows(), numCategories());
+ for (int row = 0; row < data.numRows(); row++) {
+ classifyFull(r.viewRow(row), data.viewRow(row));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a vector of probabilities of category 1, one for each row
+ * of a matrix. This only makes sense if there are exactly two categories, but
+ * calling this method in that case can save a number of vector allocations.
+ *
+ * @param data The matrix whose rows are vectors to classify
+ * @return A vector of scores, with one value per row of the input matrix.
+ */
+ public Vector classifyScalar(Matrix data) {
+ Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
+
+ Vector r = new DenseVector(data.numRows());
+ for (int row = 0; row < data.numRows(); row++) {
+ r.set(row, classifyScalar(data.viewRow(row)));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a measure of how good the classification for a particular example
+ * actually is.
+ *
+ * @param actual The correct category for the example.
+ * @param data The vector to be classified.
+ * @return The log likelihood of the correct answer as estimated by the current model. This will always be
+ * {@code <= 0} and larger (closer to 0) indicates better accuracy. In order to simplify code that maintains
+ * running averages, we bound this value below at {@link #MIN_LOG_LIKELIHOOD}.
+ */
+ public double logLikelihood(int actual, Vector data) {
+ if (numCategories() == 2) {
+ double p = classifyScalar(data);
+ if (actual > 0) {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p));
+ } else {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p));
+ }
+ } else {
+ Vector p = classify(data);
+ if (actual > 0) {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p.get(actual - 1)));
+ } else {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p.zSum()));
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
new file mode 100644
index 0000000..29eaa0d
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+/**
+ * Result of a document classification. The label and the associated score (usually probabilty)
+ */
+public class ClassifierResult {
+
+ private String label;
+ private double score;
+ private double logLikelihood = Double.MAX_VALUE;
+
+ public ClassifierResult() { }
+
+ public ClassifierResult(String label, double score) {
+ this.label = label;
+ this.score = score;
+ }
+
+ public ClassifierResult(String label) {
+ this.label = label;
+ }
+
+ public ClassifierResult(String label, double score, double logLikelihood) {
+ this.label = label;
+ this.score = score;
+ this.logLikelihood = logLikelihood;
+ }
+
+ public double getLogLikelihood() {
+ return logLikelihood;
+ }
+
+ public void setLogLikelihood(double logLikelihood) {
+ this.logLikelihood = logLikelihood;
+ }
+
+ public String getLabel() {
+ return label;
+ }
+
+ public double getScore() {
+ return score;
+ }
+
+ public void setLabel(String label) {
+ this.label = label;
+ }
+
+ public void setScore(double score) {
+ this.score = score;
+ }
+
+ @Override
+ public String toString() {
+ return "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}';
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
new file mode 100644
index 0000000..73ba521
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
@@ -0,0 +1,444 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import com.google.common.base.Preconditions;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.math3.stat.descriptive.moment.Mean;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
+ *
+ * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
+ *
+ * See http://en.wikipedia.org/wiki/Confusion_matrix for background
+ */
+public class ConfusionMatrix {
+ private static final Logger LOG = LoggerFactory.getLogger(ConfusionMatrix.class);
+ private final Map<String,Integer> labelMap = new LinkedHashMap<>();
+ private final int[][] confusionMatrix;
+ private int samples = 0;
+ private String defaultLabel = "unknown";
+
+ public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
+ confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
+ this.defaultLabel = defaultLabel;
+ int i = 0;
+ for (String label : labels) {
+ labelMap.put(label, i++);
+ }
+ labelMap.put(defaultLabel, i);
+ }
+
+ public ConfusionMatrix(Matrix m) {
+ confusionMatrix = new int[m.numRows()][m.numRows()];
+ setMatrix(m);
+ }
+
+ public int[][] getConfusionMatrix() {
+ return confusionMatrix;
+ }
+
+ public Collection<String> getLabels() {
+ return Collections.unmodifiableCollection(labelMap.keySet());
+ }
+
+ private int numLabels() {
+ return labelMap.size();
+ }
+
+ public double getAccuracy(String label) {
+ int labelId = labelMap.get(label);
+ int labelTotal = 0;
+ int correct = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ labelTotal += confusionMatrix[labelId][i];
+ if (i == labelId) {
+ correct += confusionMatrix[labelId][i];
+ }
+ }
+ return 100.0 * correct / labelTotal;
+ }
+
+ // Producer accuracy
+ public double getAccuracy() {
+ int total = 0;
+ int correct = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ for (int j = 0; j < numLabels(); j++) {
+ total += confusionMatrix[i][j];
+ if (i == j) {
+ correct += confusionMatrix[i][j];
+ }
+ }
+ }
+ return 100.0 * correct / total;
+ }
+
+ /** Sum of true positives and false negatives */
+ private int getActualNumberOfTestExamplesForClass(String label) {
+ int labelId = labelMap.get(label);
+ int sum = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ sum += confusionMatrix[labelId][i];
+ }
+ return sum;
+ }
+
+ public double getPrecision(String label) {
+ int labelId = labelMap.get(label);
+ int truePositives = confusionMatrix[labelId][labelId];
+ int falsePositives = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ if (i == labelId) {
+ continue;
+ }
+ falsePositives += confusionMatrix[i][labelId];
+ }
+
+ if (truePositives + falsePositives == 0) {
+ return 0;
+ }
+
+ return ((double) truePositives) / (truePositives + falsePositives);
+ }
+
+ public double getWeightedPrecision() {
+ double[] precisions = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ precisions[index] = getPrecision(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(precisions, weights);
+ }
+
+ public double getRecall(String label) {
+ int labelId = labelMap.get(label);
+ int truePositives = confusionMatrix[labelId][labelId];
+ int falseNegatives = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ if (i == labelId) {
+ continue;
+ }
+ falseNegatives += confusionMatrix[labelId][i];
+ }
+ if (truePositives + falseNegatives == 0) {
+ return 0;
+ }
+ return ((double) truePositives) / (truePositives + falseNegatives);
+ }
+
+ public double getWeightedRecall() {
+ double[] recalls = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ recalls[index] = getRecall(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(recalls, weights);
+ }
+
+ public double getF1score(String label) {
+ double precision = getPrecision(label);
+ double recall = getRecall(label);
+ if (precision + recall == 0) {
+ return 0;
+ }
+ return 2 * precision * recall / (precision + recall);
+ }
+
+ public double getWeightedF1score() {
+ double[] f1Scores = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ f1Scores[index] = getF1score(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(f1Scores, weights);
+ }
+
+ // User accuracy
+ public double getReliability() {
+ int count = 0;
+ double accuracy = 0;
+ for (String label: labelMap.keySet()) {
+ if (!label.equals(defaultLabel)) {
+ accuracy += getAccuracy(label);
+ }
+ count++;
+ }
+ return accuracy / count;
+ }
+
+ /**
+ * Accuracy v.s. randomly classifying all samples.
+ * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
+ * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
+ * Educational And Psychological Measurement 20:37-46.
+ *
+ * Formula and variable names from:
+ * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
+ *
+ * @return double
+ */
+ public double getKappa() {
+ double a = 0.0;
+ double b = 0.0;
+ for (int i = 0; i < confusionMatrix.length; i++) {
+ a += confusionMatrix[i][i];
+ double br = 0;
+ for (int j = 0; j < confusionMatrix.length; j++) {
+ br += confusionMatrix[i][j];
+ }
+ double bc = 0;
+ for (int[] vec : confusionMatrix) {
+ bc += vec[i];
+ }
+ b += br * bc;
+ }
+ return (samples * a - b) / (samples * samples - b);
+ }
+
+ /**
+ * Standard deviation of normalized producer accuracy
+ * Not a standard score
+ * @return double
+ */
+ public RunningAverageAndStdDev getNormalizedStats() {
+ RunningAverageAndStdDev summer = new FullRunningAverageAndStdDev();
+ for (int d = 0; d < confusionMatrix.length; d++) {
+ double total = 0;
+ for (int j = 0; j < confusionMatrix.length; j++) {
+ total += confusionMatrix[d][j];
+ }
+ summer.addDatum(confusionMatrix[d][d] / (total + 0.000001));
+ }
+
+ return summer;
+ }
+
+ public int getCorrect(String label) {
+ int labelId = labelMap.get(label);
+ return confusionMatrix[labelId][labelId];
+ }
+
+ public int getTotal(String label) {
+ int labelId = labelMap.get(label);
+ int labelTotal = 0;
+ for (int i = 0; i < labelMap.size(); i++) {
+ labelTotal += confusionMatrix[labelId][i];
+ }
+ return labelTotal;
+ }
+
+ public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
+ samples++;
+ incrementCount(correctLabel, classifiedResult.getLabel());
+ }
+
+ public void addInstance(String correctLabel, String classifiedLabel) {
+ samples++;
+ incrementCount(correctLabel, classifiedLabel);
+ }
+
+ public int getCount(String correctLabel, String classifiedLabel) {
+ if(!labelMap.containsKey(correctLabel)) {
+ LOG.warn("Label {} did not appear in the training examples", correctLabel);
+ return 0;
+ }
+ Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
+ int correctId = labelMap.get(correctLabel);
+ int classifiedId = labelMap.get(classifiedLabel);
+ return confusionMatrix[correctId][classifiedId];
+ }
+
+ public void putCount(String correctLabel, String classifiedLabel, int count) {
+ if(!labelMap.containsKey(correctLabel)) {
+ LOG.warn("Label {} did not appear in the training examples", correctLabel);
+ return;
+ }
+ Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
+ int correctId = labelMap.get(correctLabel);
+ int classifiedId = labelMap.get(classifiedLabel);
+ if (confusionMatrix[correctId][classifiedId] == 0.0 && count != 0) {
+ samples++;
+ }
+ confusionMatrix[correctId][classifiedId] = count;
+ }
+
+ public String getDefaultLabel() {
+ return defaultLabel;
+ }
+
+ public void incrementCount(String correctLabel, String classifiedLabel, int count) {
+ putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
+ }
+
+ public void incrementCount(String correctLabel, String classifiedLabel) {
+ incrementCount(correctLabel, classifiedLabel, 1);
+ }
+
+ public ConfusionMatrix merge(ConfusionMatrix b) {
+ Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
+ for (String correctLabel : this.labelMap.keySet()) {
+ for (String classifiedLabel : this.labelMap.keySet()) {
+ incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
+ }
+ }
+ return this;
+ }
+
+ public Matrix getMatrix() {
+ int length = confusionMatrix.length;
+ Matrix m = new DenseMatrix(length, length);
+ for (int r = 0; r < length; r++) {
+ for (int c = 0; c < length; c++) {
+ m.set(r, c, confusionMatrix[r][c]);
+ }
+ }
+ Map<String,Integer> labels = new HashMap<>();
+ for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
+ labels.put(entry.getKey(), entry.getValue());
+ }
+ m.setRowLabelBindings(labels);
+ m.setColumnLabelBindings(labels);
+ return m;
+ }
+
+ public void setMatrix(Matrix m) {
+ int length = confusionMatrix.length;
+ if (m.numRows() != m.numCols()) {
+ throw new IllegalArgumentException(
+ "ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square");
+ }
+ for (int r = 0; r < length; r++) {
+ for (int c = 0; c < length; c++) {
+ confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+ }
+ }
+ Map<String,Integer> labels = m.getRowLabelBindings();
+ if (labels == null) {
+ labels = m.getColumnLabelBindings();
+ }
+ if (labels != null) {
+ String[] sorted = sortLabels(labels);
+ verifyLabels(length, sorted);
+ labelMap.clear();
+ for (int i = 0; i < length; i++) {
+ labelMap.put(sorted[i], i);
+ }
+ }
+ }
+
+ private static String[] sortLabels(Map<String,Integer> labels) {
+ String[] sorted = new String[labels.size()];
+ for (Map.Entry<String,Integer> entry : labels.entrySet()) {
+ sorted[entry.getValue()] = entry.getKey();
+ }
+ return sorted;
+ }
+
+ private static void verifyLabels(int length, String[] sorted) {
+ Preconditions.checkArgument(sorted.length == length, "One label, one row");
+ for (int i = 0; i < length; i++) {
+ if (sorted[i] == null) {
+ Preconditions.checkArgument(false, "One label, one row");
+ }
+ }
+ }
+
+ /**
+ * This is overloaded. toString() is not a formatted report you print for a manager :)
+ * Assume that if there are no default assignments, the default feature was not used
+ */
+ @Override
+ public String toString() {
+ StringBuilder returnString = new StringBuilder(200);
+ returnString.append("=======================================================").append('\n');
+ returnString.append("Confusion Matrix\n");
+ returnString.append("-------------------------------------------------------").append('\n');
+
+ int unclassified = getTotal(defaultLabel);
+ for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+ if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+
+ returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
+ }
+
+ returnString.append("<--Classified as").append('\n');
+ for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+ if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+ String correctLabel = entry.getKey();
+ int labelTotal = 0;
+ for (String classifiedLabel : this.labelMap.keySet()) {
+ if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+ returnString.append(
+ StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
+ labelTotal += getCount(correctLabel, classifiedLabel);
+ }
+ returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
+ .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
+ .append(" = ").append(correctLabel).append('\n');
+ }
+ if (unclassified > 0) {
+ returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
+ }
+ returnString.append('\n');
+ return returnString.toString();
+ }
+
+ static String getSmallLabel(int i) {
+ int val = i;
+ StringBuilder returnString = new StringBuilder();
+ do {
+ int n = val % 26;
+ returnString.insert(0, (char) ('a' + n));
+ val /= 26;
+ } while (val > 0);
+ return returnString.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
new file mode 100644
index 0000000..af1d5e7
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import org.apache.mahout.math.Vector;
+
+import java.io.Closeable;
+
+/**
+ * The simplest interface for online learning algorithms.
+ */
+public interface OnlineLearner extends Closeable {
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may an assumption that if multiple passes through the training data are necessary, then
+ * the training examples will be presented in the same order. This is because the order of
+ * training examples may be used to assign records to different data splits for evaluation by
+ * cross-validation. Without the order invariance, records might be assigned to training and test
+ * splits and error estimates could be seriously affected.
+ * <p/>
+ * If re-ordering is necessary, then using the alternative API which allows a tracking key to be
+ * added to the training example can be used.
+ *
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(int actual, Vector instance);
+
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may an assumption that if multiple passes through the training data are necessary that
+ * the tracking key for a record will be the same for each pass and that there will be a
+ * relatively large number of distinct tracking keys and that the low-order bits of the tracking
+ * keys will not correlate with any of the input variables. This tracking key is used to assign
+ * training examples to different test/training splits.
+ * <p/>
+ * Examples of useful tracking keys include id-numbers for the training records derived from
+ * a database id for the base table from the which the record is derived, or the offset of
+ * the original data record in a data file.
+ *
+ * @param trackingKey The tracking key for this training example.
+ * @param groupKey An optional value that allows examples to be grouped in the computation of
+ * the update to the model.
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(long trackingKey, String groupKey, int actual, Vector instance);
+
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may an assumption that if multiple passes through the training data are necessary that
+ * the tracking key for a record will be the same for each pass and that there will be a
+ * relatively large number of distinct tracking keys and that the low-order bits of the tracking
+ * keys will not correlate with any of the input variables. This tracking key is used to assign
+ * training examples to different test/training splits.
+ * <p/>
+ * Examples of useful tracking keys include id-numbers for the training records derived from
+ * a database id for the base table from the which the record is derived, or the offset of
+ * the original data record in a data file.
+ *
+ * @param trackingKey The tracking key for this training example.
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(long trackingKey, int actual, Vector instance);
+
+ /**
+ * Prepares the classifier for classification and deallocates any temporary data structures.
+ *
+ * An online classifier should be able to accept more training after being closed, but
+ * closing the classifier may make classification more efficient.
+ */
+ @Override
+ void close();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
new file mode 100644
index 0000000..35c11ee
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+ import java.text.DecimalFormat;
+ import java.text.DecimalFormatSymbols;
+ import java.text.NumberFormat;
+ import java.util.ArrayList;
+ import java.util.List;
+ import java.util.Locale;
+
+ import org.apache.commons.lang3.StringUtils;
+
+/**
+ * ResultAnalyzer captures the classification statistics and displays in a tabular manner
+ */
+public class RegressionResultAnalyzer {
+
+ private static class Result {
+ private final double actual;
+ private final double result;
+ Result(double actual, double result) {
+ this.actual = actual;
+ this.result = result;
+ }
+ double getActual() {
+ return actual;
+ }
+ double getResult() {
+ return result;
+ }
+ }
+
+ private List<Result> results;
+
+ /**
+ *
+ * @param actual
+ * The actual answer
+ * @param result
+ * The regression result
+ */
+ public void addInstance(double actual, double result) {
+ if (results == null) {
+ results = new ArrayList<>();
+ }
+ results.add(new Result(actual, result));
+ }
+
+ /**
+ *
+ * @param results
+ * The results table
+ */
+ public void setInstances(double[][] results) {
+ for (double[] res : results) {
+ addInstance(res[0], res[1]);
+ }
+ }
+
+ @Override
+ public String toString() {
+ double sumActual = 0.0;
+ double sumActualSquared = 0.0;
+ double sumResult = 0.0;
+ double sumResultSquared = 0.0;
+ double sumActualResult = 0.0;
+ double sumAbsolute = 0.0;
+ double sumAbsoluteSquared = 0.0;
+ int predictable = 0;
+ int unpredictable = 0;
+
+ for (Result res : results) {
+ double actual = res.getActual();
+ double result = res.getResult();
+ if (Double.isNaN(result)) {
+ unpredictable++;
+ } else {
+ sumActual += actual;
+ sumActualSquared += actual * actual;
+ sumResult += result;
+ sumResultSquared += result * result;
+ sumActualResult += actual * result;
+ double absolute = Math.abs(actual - result);
+ sumAbsolute += absolute;
+ sumAbsoluteSquared += absolute * absolute;
+ predictable++;
+ }
+ }
+
+ StringBuilder returnString = new StringBuilder();
+
+ returnString.append("=======================================================\n");
+ returnString.append("Summary\n");
+ returnString.append("-------------------------------------------------------\n");
+
+ if (predictable > 0) {
+ double varActual = sumActualSquared - sumActual * sumActual / predictable;
+ double varResult = sumResultSquared - sumResult * sumResult / predictable;
+ double varCo = sumActualResult - sumActual * sumResult / predictable;
+
+ double correlation;
+ if (varActual * varResult <= 0) {
+ correlation = 0.0;
+ } else {
+ correlation = varCo / Math.sqrt(varActual * varResult);
+ }
+
+ Locale.setDefault(Locale.US);
+ NumberFormat decimalFormatter = new DecimalFormat("0.####");
+
+ returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(sumAbsolute / predictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / predictable)),
+ 10)).append('\n');
+ }
+ returnString.append(StringUtils.rightPad("Predictable Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(predictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Unpredictable Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(unpredictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Total Regressed Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(results.size()), 10)).append('\n');
+ returnString.append('\n');
+
+ return returnString.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
new file mode 100644
index 0000000..1711f19
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.text.DecimalFormat;
+import java.text.NumberFormat;
+import java.util.Collection;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+ /**
+  * ResultAnalyzer accumulates classification outcomes (correct and incorrect counts, a confusion matrix
+  * and, when supplied, log-likelihoods) and displays them in a tabular manner.
+  */
+ public class ResultAnalyzer {
+
+   private final ConfusionMatrix confusionMatrix;
+   private final OnlineSummarizer summarizer;
+   // Becomes true once at least one result carrying a log-likelihood has been added.
+   private boolean hasLL;
+
+   /*
+    * === Summary ===
+    *
+    * Correctly Classified Instances 635 92.9722 % Incorrectly Classified Instances 48 7.0278 % Kappa statistic
+    * 0.923 Mean absolute error 0.0096 Root mean squared error 0.0817 Relative absolute error 9.9344 % Root
+    * relative squared error 37.2742 % Total Number of Instances 683
+    */
+   private int correctlyClassified;
+   private int incorrectlyClassified;
+
+   public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) {
+     this.summarizer = new OnlineSummarizer();
+     this.confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel);
+   }
+
+   public ConfusionMatrix getConfusionMatrix() {
+     return confusionMatrix;
+   }
+
+   /**
+    * Records one classified example.
+    *
+    * @param correctLabel
+    *          The correct label
+    * @param classifiedResult
+    *          The classified result
+    * @return whether the instance was correct or not
+    */
+   public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) {
+     boolean matches = correctLabel.equals(classifiedResult.getLabel());
+     if (!matches) {
+       incorrectlyClassified++;
+     } else {
+       correctlyClassified++;
+     }
+     confusionMatrix.addInstance(correctLabel, classifiedResult);
+
+     // Double.MAX_VALUE is the value a ClassifierResult carries when no log-likelihood was set.
+     double logLikelihood = classifiedResult.getLogLikelihood();
+     if (logLikelihood != Double.MAX_VALUE) {
+       summarizer.add(logLikelihood);
+       hasLL = true;
+     }
+     return matches;
+   }
+
+   /** Builds the "name : count" part of a summary row: name padded to 40, count right-aligned in 10. */
+   private static String countCell(String name, int count) {
+     return StringUtils.rightPad(name, 40) + ": " + StringUtils.leftPad(Integer.toString(count), 10);
+   }
+
+   /** Builds a statistics row without its line ending: name padded to 40, value right-aligned in 10. */
+   private static String statCell(String name, String value) {
+     return StringUtils.rightPad(name, 40) + StringUtils.leftPad(value, 10);
+   }
+
+   /** Builds a log-likelihood quartile row without its line ending. */
+   private static String quartileCell(String name, String value) {
+     return StringUtils.rightPad("", 30) + StringUtils.rightPad(name, 10) + StringUtils.leftPad(value, 10);
+   }
+
+   /**
+    * Renders the report: a summary of correct/incorrect counts, the confusion matrix, derived statistics
+    * and, if any log-likelihoods were recorded, their mean and quartiles.
+    */
+   @Override
+   public String toString() {
+     NumberFormat decimalFormatter = new DecimalFormat("0.####");
+     int totalClassified = correctlyClassified + incorrectlyClassified;
+     double percentageCorrect = (double) 100 * correctlyClassified / totalClassified;
+     double percentageIncorrect = (double) 100 * incorrectlyClassified / totalClassified;
+
+     StringBuilder report = new StringBuilder();
+
+     // Section 1: counts and percentages.
+     report.append('\n');
+     report.append("=======================================================\n");
+     report.append("Summary\n");
+     report.append("-------------------------------------------------------\n");
+     report.append(countCell("Correctly Classified Instances", correctlyClassified)).append('\t')
+         .append(StringUtils.leftPad(decimalFormatter.format(percentageCorrect), 10)).append("%\n");
+     report.append(countCell("Incorrectly Classified Instances", incorrectlyClassified)).append('\t')
+         .append(StringUtils.leftPad(decimalFormatter.format(percentageIncorrect), 10)).append("%\n");
+     report.append(countCell("Total Classified Instances", totalClassified)).append('\n');
+     report.append('\n');
+
+     // Section 2: the confusion matrix renders itself.
+     report.append(confusionMatrix);
+
+     // Section 3: statistics derived from the confusion matrix.
+     report.append("=======================================================\n");
+     report.append("Statistics\n");
+     report.append("-------------------------------------------------------\n");
+
+     RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats();
+     report.append(statCell("Kappa", decimalFormatter.format(confusionMatrix.getKappa()))).append('\n');
+     report.append(statCell("Accuracy", decimalFormatter.format(confusionMatrix.getAccuracy()))).append("%\n");
+     report.append(statCell("Reliability",
+         decimalFormatter.format(normStats.getAverage() * 100.00000001))).append("%\n");
+     report.append(statCell("Reliability (standard deviation)",
+         decimalFormatter.format(normStats.getStandardDeviation()))).append('\n');
+     report.append(statCell("Weighted precision",
+         decimalFormatter.format(confusionMatrix.getWeightedPrecision()))).append('\n');
+     report.append(statCell("Weighted recall",
+         decimalFormatter.format(confusionMatrix.getWeightedRecall()))).append('\n');
+     report.append(statCell("Weighted F1 score",
+         decimalFormatter.format(confusionMatrix.getWeightedF1score()))).append('\n');
+
+     // Section 4: log-likelihood summary, only when at least one value was recorded.
+     if (hasLL) {
+       report.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean : ")
+           .append(StringUtils.leftPad(decimalFormatter.format(summarizer.getMean()), 10)).append('\n');
+       report.append(quartileCell("25%-ile : ", decimalFormatter.format(summarizer.getQuartile(1)))).append('\n');
+       report.append(quartileCell("75%-ile : ", decimalFormatter.format(summarizer.getQuartile(3)))).append('\n');
+     }
+
+     return report.toString();
+   }
+ }
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
new file mode 100644
index 0000000..f79a429
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.node.Node;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * Builds a single tree from a bag of the training data obtained through {@code Data.bagging}.
+ */
+@Deprecated
+public class Bagging {
+
+  private static final Logger log = LoggerFactory.getLogger(Bagging.class);
+
+  /** Builder used to grow a tree from each bag. */
+  private final TreeBuilder treeBuilder;
+
+  /** Data the bags are drawn from. */
+  private final Data data;
+
+  /** One flag per instance of the data; cleared and passed to {@code Data.bagging} on every build. */
+  private final boolean[] sampled;
+
+  /**
+   * @param treeBuilder
+   *          builder used to grow the trees
+   * @param data
+   *          data the bags are drawn from
+   */
+  public Bagging(TreeBuilder treeBuilder, Data data) {
+    this.data = data;
+    this.treeBuilder = treeBuilder;
+    this.sampled = new boolean[data.size()];
+  }
+
+  /**
+   * Builds one tree
+   *
+   * @param rng
+   *          random number generator handed to both the bagging step and the tree builder
+   * @return the tree returned by the tree builder for the bag
+   */
+  public Node build(Random rng) {
+    log.debug("Bagging...");
+    // reset the flags left over from a previous call before drawing a new bag
+    Arrays.fill(sampled, false);
+    Data baggedData = data.bagging(rng, sampled);
+
+    log.debug("Building...");
+    Node tree = treeBuilder.build(rng, baggedData);
+    return tree;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
new file mode 100644
index 0000000..c94292c
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
@@ -0,0 +1,174 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+
+/**
+ * Static helpers used by the decision forest code: array (de)serialization, listing of output files,
+ * formatting of elapsed times and writing of small payloads to a file system.
+ */
+@Deprecated
+public final class DFUtils {
+
+  private DFUtils() {
+    // static helpers only, never instantiated
+  }
+
+  /**
+   * Writes a Node[] into a DataOutput: the array length first, then every node in order
+   * @throws java.io.IOException
+   */
+  public static void writeArray(DataOutput out, Node[] array) throws IOException {
+    int count = array.length;
+    out.writeInt(count);
+    for (int i = 0; i < count; i++) {
+      array[i].write(out);
+    }
+  }
+
+  /**
+   * Reads a Node[] from a DataInput: the array length first, then that many nodes
+   * @throws java.io.IOException
+   */
+  public static Node[] readNodeArray(DataInput in) throws IOException {
+    Node[] result = new Node[in.readInt()];
+    for (int i = 0; i < result.length; i++) {
+      result[i] = Node.read(in);
+    }
+    return result;
+  }
+
+  /**
+   * Writes a double[] into a DataOutput: the array length first, then every value in order
+   * @throws java.io.IOException
+   */
+  public static void writeArray(DataOutput out, double[] array) throws IOException {
+    int count = array.length;
+    out.writeInt(count);
+    for (int i = 0; i < count; i++) {
+      out.writeDouble(array[i]);
+    }
+  }
+
+  /**
+   * Reads a double[] from a DataInput: the array length first, then that many values
+   * @throws java.io.IOException
+   */
+  public static double[] readDoubleArray(DataInput in) throws IOException {
+    double[] result = new double[in.readInt()];
+    for (int i = 0; i < result.length; i++) {
+      result[i] = in.readDouble();
+    }
+    return result;
+  }
+
+  /**
+   * Writes an int[] into a DataOutput: the array length first, then every value in order
+   * @throws java.io.IOException
+   */
+  public static void writeArray(DataOutput out, int[] array) throws IOException {
+    int count = array.length;
+    out.writeInt(count);
+    for (int i = 0; i < count; i++) {
+      out.writeInt(array[i]);
+    }
+  }
+
+  /**
+   * Reads an int[] from a DataInput: the array length first, then that many values
+   * @throws java.io.IOException
+   */
+  public static int[] readIntArray(DataInput in) throws IOException {
+    int[] result = new int[in.readInt()];
+    for (int i = 0; i < result.length; i++) {
+      result[i] = in.readInt();
+    }
+    return result;
+  }
+
+  /**
+   * Return a list of all files in the output directory. Only entries accepted by
+   * {@code PathFilters.logsCRCFilter()} are considered; directories and entries whose name starts
+   * with "_" are left out.
+   * @throws IOException if no file is found
+   */
+  public static Path[] listOutputFiles(FileSystem fs, Path outputPath) throws IOException {
+    List<Path> found = new ArrayList<>();
+    FileStatus[] statuses = fs.listStatus(outputPath, PathFilters.logsCRCFilter());
+    for (FileStatus status : statuses) {
+      if (status.isDir() || status.getPath().getName().startsWith("_")) {
+        continue;
+      }
+      found.add(status.getPath());
+    }
+
+    if (found.isEmpty()) {
+      throw new IOException("No output found !");
+    }
+    return found.toArray(new Path[found.size()]);
+  }
+
+  /**
+   * Formats a time interval in milliseconds to a String in the form
+   * "&lt;hours&gt;h &lt;minutes&gt;m &lt;seconds&gt;s &lt;millis&gt;"
+   */
+  public static String elapsedTime(long milli) {
+    long totalSeconds = milli / 1000;
+    long totalMinutes = totalSeconds / 60;
+
+    long millisPart = milli % 1000;
+    long secondsPart = totalSeconds % 60;
+    long minutesPart = totalMinutes % 60;
+    long hoursPart = totalMinutes / 60;
+
+    return hoursPart + "h " + minutesPart + "m " + secondsPart + "s " + millisPart;
+  }
+
+  /**
+   * Writes a Writable to a path.
+   * @param conf From which the file system will be picked
+   * @param path Where the writable will be written
+   * @param writable The object to write
+   * @throws IOException if the file cannot be created or written
+   */
+  public static void storeWritable(Configuration conf, Path path, Writable writable) throws IOException {
+    try (FSDataOutputStream out = path.getFileSystem(conf).create(path)) {
+      writable.write(out);
+    }
+  }
+
+  /**
+   * Write a string to a path, encoded with the platform default charset.
+   * @param conf From which the file system will be picked
+   * @param path Where the string will be written
+   * @param string The string to write
+   * @throws IOException if things go poorly
+   */
+  public static void storeString(Configuration conf, Path path, String string) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    try (DataOutputStream out = fs.create(path)) {
+      byte[] encoded = string.getBytes(Charset.defaultCharset());
+      out.write(encoded);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
new file mode 100644
index 0000000..c11cf34
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
@@ -0,0 +1,241 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.Node;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Represents a forest of decision trees.
+ */
+@Deprecated
+public class DecisionForest implements Writable {
+
+  private final List<Node> trees;
+
+  /**
+   * Creates an empty forest; used by {@link #read(DataInput)} before the trees are loaded.
+   */
+  private DecisionForest() {
+    trees = new ArrayList<>();
+  }
+
+  /**
+   * @param trees
+   *          trees of the forest, must not be null or empty. The list is stored as is, not copied.
+   */
+  public DecisionForest(List<Node> trees) {
+    Preconditions.checkArgument(trees != null && !trees.isEmpty(), "trees argument must not be null or empty");
+
+    this.trees = trees;
+  }
+
+  List<Node> getTrees() {
+    return trees;
+  }
+
+  /**
+   * Classifies every instance of the data with every tree and stores the results in
+   * {@code predictions[instance index][tree index]}. Rows of {@code predictions} that are null are
+   * allocated with one slot per tree.
+   *
+   * @param data
+   *          instances to classify
+   * @param predictions
+   *          output array, must have one row per instance of {@code data}
+   */
+  public void classify(Data data, double[][] predictions) {
+    Preconditions.checkArgument(data.size() == predictions.length, "predictions.length must be equal to data.size()");
+
+    if (data.isEmpty()) {
+      return; // nothing to classify
+    }
+
+    int treeId = 0;
+    for (Node tree : trees) {
+      for (int index = 0; index < data.size(); index++) {
+        if (predictions[index] == null) {
+          predictions[index] = new double[trees.size()];
+        }
+        predictions[index][treeId] = tree.classify(data.get(index));
+      }
+      treeId++;
+    }
+  }
+
+  /**
+   * predicts the label for the instance. For a numerical label this is the mean of the trees'
+   * predictions; otherwise each tree votes for a label and the most voted one is returned. Trees
+   * whose prediction is NaN are ignored in both cases.
+   *
+   * @param rng
+   *          Random number generator, used to break ties randomly
+   * @return NaN if the label cannot be predicted
+   */
+  public double classify(Dataset dataset, Random rng, Instance instance) {
+    if (dataset.isNumerical(dataset.getLabelId())) {
+      double sum = 0;
+      int cnt = 0;
+      for (Node tree : trees) {
+        double prediction = tree.classify(instance);
+        if (!Double.isNaN(prediction)) {
+          sum += prediction;
+          cnt++;
+        }
+      }
+
+      // no tree produced a prediction: nothing to average
+      return cnt > 0 ? sum / cnt : Double.NaN;
+    }
+
+    // categorical label: count the votes for every label code
+    int[] predictions = new int[dataset.nblabels()];
+    for (Node tree : trees) {
+      double prediction = tree.classify(instance);
+      if (!Double.isNaN(prediction)) {
+        predictions[(int) prediction]++;
+      }
+    }
+
+    if (DataUtils.sum(predictions) == 0) {
+      return Double.NaN; // no prediction available
+    }
+
+    return DataUtils.maxindex(rng, predictions);
+  }
+
+  /**
+   * @return Mean number of nodes per tree (integer division), or 0 if the forest has no trees
+   */
+  public long meanNbNodes() {
+    if (trees.isEmpty()) {
+      return 0; // avoid dividing by zero on a forest read from an empty stream
+    }
+
+    long sum = 0;
+    for (Node tree : trees) {
+      sum += tree.nbNodes();
+    }
+
+    return sum / trees.size();
+  }
+
+  /**
+   * @return Total number of nodes in all the trees
+   */
+  public long nbNodes() {
+    long sum = 0;
+
+    for (Node tree : trees) {
+      sum += tree.nbNodes();
+    }
+
+    return sum;
+  }
+
+  /**
+   * @return Mean maximum depth per tree (integer division), or 0 if the forest has no trees
+   */
+  public long meanMaxDepth() {
+    if (trees.isEmpty()) {
+      return 0; // avoid dividing by zero on a forest read from an empty stream
+    }
+
+    long sum = 0;
+    for (Node tree : trees) {
+      sum += tree.maxDepth();
+    }
+
+    return sum / trees.size();
+  }
+
+  /**
+   * Two forests are equal when they hold the same trees, regardless of their order. Trees are
+   * matched one to one, so duplicated trees must appear the same number of times in both forests;
+   * this keeps the relation symmetric.
+   */
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (!(obj instanceof DecisionForest)) {
+      return false;
+    }
+
+    List<Node> others = ((DecisionForest) obj).getTrees();
+    if (trees.size() != others.size()) {
+      return false;
+    }
+
+    // pair every tree of this forest with a distinct, not yet used, equal tree of the other one
+    Object[] candidates = others.toArray();
+    boolean[] matched = new boolean[candidates.length];
+    for (Node tree : trees) {
+      boolean found = false;
+      for (int i = 0; i < candidates.length && !found; i++) {
+        if (!matched[i] && (tree == null ? candidates[i] == null : tree.equals(candidates[i]))) {
+          matched[i] = true;
+          found = true;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  /**
+   * Order-independent hash (sum of the trees' hashes), consistent with {@link #equals(Object)}
+   * which ignores the order of the trees.
+   */
+  @Override
+  public int hashCode() {
+    int hash = 0;
+    for (Node tree : trees) {
+      hash += tree == null ? 0 : tree.hashCode();
+    }
+    return hash;
+  }
+
+  /**
+   * Writes the number of trees followed by every tree.
+   */
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    dataOutput.writeInt(trees.size());
+    for (Node tree : trees) {
+      tree.write(dataOutput);
+    }
+  }
+
+  /**
+   * Reads the trees from the input and adds them to the existing trees
+   */
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    int size = dataInput.readInt();
+    for (int i = 0; i < size; i++) {
+      trees.add(Node.read(dataInput));
+    }
+  }
+
+  /**
+   * Read the forest from inputStream
+   * @param dataInput - input forest
+   * @return {@link org.apache.mahout.classifier.df.DecisionForest}
+   * @throws IOException
+   */
+  public static DecisionForest read(DataInput dataInput) throws IOException {
+    DecisionForest forest = new DecisionForest();
+    forest.readFields(dataInput);
+    return forest;
+  }
+
+  /**
+   * Load the forest from a single file or a directory of files. When a directory is given, the
+   * trees of all the files returned by {@link DFUtils#listOutputFiles(FileSystem, Path)} are
+   * gathered into one forest.
+   * @throws java.io.IOException
+   */
+  public static DecisionForest load(Configuration conf, Path forestPath) throws IOException {
+    FileSystem fs = forestPath.getFileSystem(conf);
+    Path[] files;
+    if (fs.getFileStatus(forestPath).isDir()) {
+      files = DFUtils.listOutputFiles(fs, forestPath);
+    } else {
+      files = new Path[]{forestPath};
+    }
+
+    DecisionForest forest = null;
+    for (Path path : files) {
+      // fs.open() already returns an FSDataInputStream, no extra wrapping needed
+      try (FSDataInputStream dataInput = fs.open(path)) {
+        if (forest == null) {
+          forest = read(dataInput);
+        } else {
+          forest.readFields(dataInput);
+        }
+      }
+    }
+
+    return forest;
+
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
new file mode 100644
index 0000000..13cd386
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Error measures computed from the predictions of a random forest
+ */
+@Deprecated
+public final class ErrorEstimate {
+
+  private ErrorEstimate() {
+    // static helpers only, never instantiated
+  }
+
+  /**
+   * Computes the proportion of wrong predictions among the classified instances. An instance whose
+   * prediction is -1 is considered not classified and is ignored.
+   *
+   * @param labels
+   *          expected label of each instance
+   * @param predictions
+   *          predicted label of each instance, same length as {@code labels}
+   * @return number of wrong predictions divided by the number of classified instances; NaN when no
+   *         instance was classified
+   */
+  public static double errorRate(double[] labels, double[] predictions) {
+    Preconditions.checkArgument(labels.length == predictions.length, "labels.length != predictions.length");
+    double misclassified = 0; // instances whose prediction differs from the label
+    double classified = 0; // instances that received a prediction
+
+    for (int i = 0; i < labels.length; i++) {
+      double predicted = predictions[i];
+      if (predicted != -1) {
+        classified++;
+        if (predicted != labels[i]) {
+          misclassified++;
+        }
+      }
+    }
+
+    return misclassified / classified;
+  }
+
+}