You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/04 14:29:11 UTC
[09/53] [abbrv] [partial] mahout git commit: end of day 6-2-2018
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java
new file mode 100644
index 0000000..ca4d2b2
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java
@@ -0,0 +1,236 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.common.Weighting;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.junit.Test;
+
+/**
+ * Tests {@link EuclideanDistanceSimilarity} for user-user and item-item similarity,
+ * both unweighted and with {@link Weighting#WEIGHTED}.
+ */
+public final class EuclideanDistanceSimilarityTest extends SimilarityTestCase {
+
+  /** Two users with identical preferences must have similarity 1.0. */
+  @Test
+  public void testFullCorrelation1() throws Exception {
+    Double[][] prefs = {{3.0, -2.0}, {3.0, -2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertCorrelationEquals(1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Same data as {@link #testFullCorrelation1()}, with weighting enabled. */
+  @Test
+  public void testFullCorrelation1Weighted() throws Exception {
+    Double[][] prefs = {{3.0, -2.0}, {3.0, -2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model, Weighting.WEIGHTED);
+    assertCorrelationEquals(1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Identical, constant preferences for both users must still give 1.0. */
+  @Test
+  public void testFullCorrelation2() throws Exception {
+    Double[][] prefs = {{3.0, 3.0}, {3.0, 3.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertEquals(1.0, similarity.userSimilarity(1, 2), EPSILON);
+  }
+
+  /** Users with opposite-signed preferences get a low similarity. */
+  @Test
+  public void testNoCorrelation1() throws Exception {
+    Double[][] prefs = {{3.0, -2.0}, {-3.0, 2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertCorrelationEquals(0.1639607805437114, similarity.userSimilarity(1, 2));
+  }
+
+  /** Same data as {@link #testNoCorrelation1()}, with weighting enabled. */
+  @Test
+  public void testNoCorrelation1Weighted() throws Exception {
+    Double[][] prefs = {{3.0, -2.0}, {-3.0, 2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model, Weighting.WEIGHTED);
+    assertCorrelationEquals(0.7213202601812372, similarity.userSimilarity(1, 2));
+  }
+
+  /** The two users have no position where both hold a preference, so NaN is expected. */
+  @Test
+  public void testNoCorrelation2() throws Exception {
+    Double[][] prefs = {{null, 1.0, null}, {null, null, 1.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertTrue(Double.isNaN(similarity.userSimilarity(1, 2)));
+  }
+
+  /** Users whose preferences run in opposite order over the same values. */
+  @Test
+  public void testNoCorrelation3() throws Exception {
+    Double[][] prefs = {{90.0, 80.0, 70.0}, {70.0, 80.0, 90.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertCorrelationEquals(0.05770363219029305, similarity.userSimilarity(1, 2));
+  }
+
+  /** Pins the unweighted user similarity for a small three-item example. */
+  @Test
+  public void testSimple() throws Exception {
+    Double[][] prefs = {{1.0, 2.0, 3.0}, {2.0, 5.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertCorrelationEquals(0.2843646522044218, similarity.userSimilarity(1, 2));
+  }
+
+  /** Same data as {@link #testSimple()}, with weighting enabled. */
+  @Test
+  public void testSimpleWeighted() throws Exception {
+    Double[][] prefs = {{1.0, 2.0, 3.0}, {2.0, 5.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model, Weighting.WEIGHTED);
+    assertCorrelationEquals(0.8210911630511055, similarity.userSimilarity(1, 2));
+  }
+
+  /** Items 0 and 1 receive the same value from every user, so similarity must be 1.0. */
+  @Test
+  public void testFullItemCorrelation1() throws Exception {
+    Double[][] prefs = {{3.0, 3.0}, {-2.0, -2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertCorrelationEquals(1.0, similarity.itemSimilarity(0, 1));
+  }
+
+  /** Items 0 and 1 with identical, constant values must still give 1.0. */
+  @Test
+  public void testFullItemCorrelation2() throws Exception {
+    Double[][] prefs = {{3.0, 3.0}, {3.0, 3.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertEquals(1.0, similarity.itemSimilarity(0, 1), EPSILON);
+  }
+
+  /** Items 0 and 1 receive opposite-signed values from every user. */
+  @Test
+  public void testNoItemCorrelation1() throws Exception {
+    Double[][] prefs = {{3.0, -3.0}, {-2.0, 2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertCorrelationEquals(0.1639607805437114, similarity.itemSimilarity(0, 1));
+  }
+
+  /** Items 1 and 2 are never rated by the same user, so NaN is expected. */
+  @Test
+  public void testNoItemCorrelation2() throws Exception {
+    Double[][] prefs = {{null, 1.0, null}, {null, null, 1.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertTrue(Double.isNaN(similarity.itemSimilarity(1, 2)));
+  }
+
+  /** Item-based counterpart of {@link #testNoCorrelation3()} using three users. */
+  @Test
+  public void testNoItemCorrelation3() throws Exception {
+    Double[][] prefs = {{90.0, 70.0}, {80.0, 80.0}, {70.0, 90.0}};
+    DataModel model = getDataModel(new long[] {1, 2, 3}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertCorrelationEquals(0.05770363219029305, similarity.itemSimilarity(0, 1));
+  }
+
+  /** Item-based counterpart of {@link #testSimple()} using three users. */
+  @Test
+  public void testSimpleItem() throws Exception {
+    Double[][] prefs = {{1.0, 2.0}, {2.0, 5.0}, {3.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2, 3}, prefs);
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(model);
+    assertCorrelationEquals(0.2843646522044218, similarity.itemSimilarity(0, 1));
+  }
+
+  /** Same data as {@link #testSimpleItem()}, weighted and accessed through {@link ItemSimilarity}. */
+  @Test
+  public void testSimpleItemWeighted() throws Exception {
+    Double[][] prefs = {{1.0, 2.0}, {2.0, 5.0}, {3.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2, 3}, prefs);
+    ItemSimilarity weighted = new EuclideanDistanceSimilarity(model, Weighting.WEIGHTED);
+    assertCorrelationEquals(0.8210911630511055, weighted.itemSimilarity(0, 1));
+  }
+
+  /** Refreshing with a {@code null} collection must complete without throwing. */
+  @Test
+  public void testRefresh() throws TasteException {
+    EuclideanDistanceSimilarity similarity = new EuclideanDistanceSimilarity(getDataModel());
+    similarity.refresh(null);
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java
new file mode 100644
index 0000000..5ce255c
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Tests {@link GenericItemSimilarity}. */
+public final class GenericItemSimilarityTest extends SimilarityTestCase {
+
+  /**
+   * Builds a similarity from explicit item pairs and checks the lookups: the self-pair
+   * reports 1.0 even though 0.5 was supplied, the pair (1,2)/(2,1) reports 0.6 in both
+   * directions, and a pair that was never supplied reports NaN.
+   */
+  @Test
+  public void testSimple() {
+    List<GenericItemSimilarity.ItemItemSimilarity> similarities = Lists.newArrayList();
+    similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 2, 0.5));
+    similarities.add(new GenericItemSimilarity.ItemItemSimilarity(2, 1, 0.6));
+    similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 1, 0.5));
+    similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 3, 0.3));
+    GenericItemSimilarity itemCorrelation = new GenericItemSimilarity(similarities);
+    assertEquals(1.0, itemCorrelation.itemSimilarity(1, 1), EPSILON);
+    assertEquals(0.6, itemCorrelation.itemSimilarity(1, 2), EPSILON);
+    assertEquals(0.6, itemCorrelation.itemSimilarity(2, 1), EPSILON);
+    assertEquals(0.3, itemCorrelation.itemSimilarity(1, 3), EPSILON);
+    assertTrue(Double.isNaN(itemCorrelation.itemSimilarity(3, 4)));
+  }
+
+  /** Builds a {@link GenericItemSimilarity} from a Pearson similarity over a small data model. */
+  @Test
+  public void testFromCorrelation() throws Exception {
+    DataModel dataModel = getDataModel(
+        new long[] {1, 2, 3},
+        new Double[][] {
+            {1.0, 2.0},
+            {2.0, 5.0},
+            {3.0, 6.0},
+        });
+    ItemSimilarity otherSimilarity = new PearsonCorrelationSimilarity(dataModel);
+    ItemSimilarity itemSimilarity = new GenericItemSimilarity(otherSimilarity, dataModel);
+    assertCorrelationEquals(1.0, itemSimilarity.itemSimilarity(0, 0));
+    assertCorrelationEquals(0.960768922830523, itemSimilarity.itemSimilarity(0, 1));
+  }
+
+  /**
+   * Checks {@code allSimilarItemIDs} for every item mentioned in the supplied pairs.
+   * NOTE(review): this body is identical to {@link #testAllSimilaritiesWithIndex()};
+   * confirm whether the two tests were meant to construct the similarity differently.
+   */
+  @Test
+  public void testAllSimilaritiesWithoutIndex() throws TasteException {
+
+    List<GenericItemSimilarity.ItemItemSimilarity> itemItemSimilarities =
+        Arrays.asList(new GenericItemSimilarity.ItemItemSimilarity(1L, 2L, 0.2),
+                      new GenericItemSimilarity.ItemItemSimilarity(1L, 3L, 0.2),
+                      new GenericItemSimilarity.ItemItemSimilarity(2L, 1L, 0.2),
+                      new GenericItemSimilarity.ItemItemSimilarity(3L, 5L, 0.2),
+                      new GenericItemSimilarity.ItemItemSimilarity(3L, 4L, 0.2));
+
+    ItemSimilarity similarity = new GenericItemSimilarity(itemItemSimilarities);
+
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(1L), 2L, 3L));
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(2L), 1L));
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(3L), 1L, 5L, 4L));
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(4L), 3L));
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(5L), 3L));
+  }
+
+  /**
+   * Checks {@code allSimilarItemIDs} for every item mentioned in the supplied pairs.
+   * NOTE(review): this body is identical to {@link #testAllSimilaritiesWithoutIndex()};
+   * confirm whether the two tests were meant to construct the similarity differently.
+   */
+  @Test
+  public void testAllSimilaritiesWithIndex() throws TasteException {
+
+    List<GenericItemSimilarity.ItemItemSimilarity> itemItemSimilarities =
+        Arrays.asList(new GenericItemSimilarity.ItemItemSimilarity(1L, 2L, 0.2),
+                      new GenericItemSimilarity.ItemItemSimilarity(1L, 3L, 0.2),
+                      new GenericItemSimilarity.ItemItemSimilarity(2L, 1L, 0.2),
+                      new GenericItemSimilarity.ItemItemSimilarity(3L, 5L, 0.2),
+                      new GenericItemSimilarity.ItemItemSimilarity(3L, 4L, 0.2));
+
+    ItemSimilarity similarity = new GenericItemSimilarity(itemItemSimilarities);
+
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(1L), 2L, 3L));
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(2L), 1L));
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(3L), 1L, 5L, 4L));
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(4L), 3L));
+    assertTrue(containsExactly(similarity.allSimilarItemIDs(5L), 3L));
+  }
+
+  /**
+   * Returns whether {@code allIDs} holds exactly the IDs in {@code shouldContainID},
+   * ignoring order.
+   *
+   * @param allIDs the IDs actually returned
+   * @param shouldContainID the IDs that are expected, in any order
+   * @return {@code true} only if both hold the same set of IDs
+   */
+  private static boolean containsExactly(long[] allIDs, long... shouldContainID) {
+    FastIDSet actual = new FastIDSet(allIDs);
+    FastIDSet expected = new FastIDSet(shouldContainID);
+    // Comparing the intersection alone would accept unexpected extra IDs in allIDs,
+    // so the set sizes must match as well.
+    return actual.size() == expected.size()
+        && actual.intersectionSize(expected) == expected.size();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
new file mode 100644
index 0000000..ae9df5c
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.junit.Test;
+
+/** Tests {@link LogLikelihoodSimilarity} item-item similarity. */
+public final class LogLikelihoodSimilarityTest extends SimilarityTestCase {
+
+  /**
+   * Asserts that the similarity of {@code (first, second)} and then of
+   * {@code (second, first)} both equal {@code expected}.
+   */
+  private static void assertSymmetric(LogLikelihoodSimilarity similarity,
+                                      long first,
+                                      long second,
+                                      double expected) throws Exception {
+    assertCorrelationEquals(expected, similarity.itemSimilarity(first, second));
+    assertCorrelationEquals(expected, similarity.itemSimilarity(second, first));
+  }
+
+  /** Pins the similarity of neighbouring item pairs and checks each pair in both directions. */
+  @Test
+  public void testCorrelation() throws Exception {
+    Double[][] prefs = {
+        {1.0, 1.0},
+        {1.0, null, 1.0},
+        {null, null, 1.0, 1.0, 1.0},
+        {1.0, 1.0, 1.0, 1.0, 1.0},
+        {null, 1.0, 1.0, 1.0, 1.0},
+    };
+    DataModel model = getDataModel(new long[] {1, 2, 3, 4, 5}, prefs);
+    LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(model);
+
+    assertSymmetric(similarity, 1, 0, 0.12160727029227925);
+    assertSymmetric(similarity, 1, 2, 0.5423213660693732);
+    assertSymmetric(similarity, 2, 3, 0.6905400104897509);
+    assertSymmetric(similarity, 3, 4, 0.8706358464330881);
+  }
+
+  /** Items 0 and 1 are never rated by the same user (NaN); items 2 and 3 are rated by everyone (0.0). */
+  @Test
+  public void testNoSimilarity() throws Exception {
+    Double[][] prefs = {
+        {1.0, null, 1.0, 1.0},
+        {1.0, null, 1.0, 1.0},
+        {null, 1.0, 1.0, 1.0},
+        {null, 1.0, 1.0, 1.0},
+    };
+    DataModel model = getDataModel(new long[] {1, 2, 3, 4}, prefs);
+    LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(model);
+
+    assertSymmetric(similarity, 1, 0, Double.NaN);
+    assertSymmetric(similarity, 2, 3, 0.0);
+  }
+
+  /** Refreshing with a {@code null} collection must complete without throwing. */
+  @Test
+  public void testRefresh() {
+    LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(getDataModel());
+    similarity.refresh(null);
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java
new file mode 100644
index 0000000..bb3ad3e
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java
@@ -0,0 +1,265 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.Weighting;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.junit.Test;
+
+/**
+ * Tests {@link PearsonCorrelationSimilarity} for user-user and item-item similarity,
+ * both unweighted and with {@link Weighting#WEIGHTED}, plus preference inference.
+ */
+public final class PearsonCorrelationSimilarityTest extends SimilarityTestCase {
+
+  /** Two users with identical, varying preferences must correlate at 1.0. */
+  @Test
+  public void testFullCorrelation1() throws Exception {
+    Double[][] prefs = {{3.0, -2.0}, {3.0, -2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertCorrelationEquals(1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Same data as {@link #testFullCorrelation1()}, with weighting enabled. */
+  @Test
+  public void testFullCorrelation1Weighted() throws Exception {
+    Double[][] prefs = {{3.0, -2.0}, {3.0, -2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model, Weighting.WEIGHTED);
+    assertCorrelationEquals(1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Constant preferences for both users: the correlation is undefined, so NaN is expected. */
+  @Test
+  public void testFullCorrelation2() throws Exception {
+    Double[][] prefs = {{3.0, 3.0}, {3.0, 3.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertTrue(Double.isNaN(similarity.userSimilarity(1, 2)));
+  }
+
+  /** Users with opposite-signed preferences must correlate at -1.0. */
+  @Test
+  public void testNoCorrelation1() throws Exception {
+    Double[][] prefs = {{3.0, -2.0}, {-3.0, 2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertCorrelationEquals(-1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Same data as {@link #testNoCorrelation1()}, with weighting enabled. */
+  @Test
+  public void testNoCorrelation1Weighted() throws Exception {
+    Double[][] prefs = {{3.0, -2.0}, {-3.0, 2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model, Weighting.WEIGHTED);
+    assertCorrelationEquals(-1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** The two users have no position where both hold a preference, so NaN is expected. */
+  @Test
+  public void testNoCorrelation2() throws Exception {
+    Double[][] prefs = {{null, 1.0, null}, {null, null, 1.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertTrue(Double.isNaN(similarity.userSimilarity(1, 2)));
+  }
+
+  /** Users whose preferences run in opposite order must correlate at -1.0. */
+  @Test
+  public void testNoCorrelation3() throws Exception {
+    Double[][] prefs = {{90.0, 80.0, 70.0}, {70.0, 80.0, 90.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertCorrelationEquals(-1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Pins the unweighted user correlation for a small three-item example. */
+  @Test
+  public void testSimple() throws Exception {
+    Double[][] prefs = {{1.0, 2.0, 3.0}, {2.0, 5.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertCorrelationEquals(0.9607689228305227, similarity.userSimilarity(1, 2));
+  }
+
+  /** Same data as {@link #testSimple()}, with weighting enabled. */
+  @Test
+  public void testSimpleWeighted() throws Exception {
+    Double[][] prefs = {{1.0, 2.0, 3.0}, {2.0, 5.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model, Weighting.WEIGHTED);
+    assertCorrelationEquals(0.9901922307076306, similarity.userSimilarity(1, 2));
+  }
+
+  /** Items 0 and 1 receive the same value from every user, so correlation must be 1.0. */
+  @Test
+  public void testFullItemCorrelation1() throws Exception {
+    Double[][] prefs = {{3.0, 3.0}, {-2.0, -2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertCorrelationEquals(1.0, similarity.itemSimilarity(0, 1));
+  }
+
+  /** Constant values for both items: the correlation is undefined, so NaN is expected. */
+  @Test
+  public void testFullItemCorrelation2() throws Exception {
+    Double[][] prefs = {{3.0, 3.0}, {3.0, 3.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertTrue(Double.isNaN(similarity.itemSimilarity(0, 1)));
+  }
+
+  /** Items 0 and 1 receive opposite-signed values from every user, so correlation must be -1.0. */
+  @Test
+  public void testNoItemCorrelation1() throws Exception {
+    Double[][] prefs = {{3.0, -3.0}, {2.0, -2.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertCorrelationEquals(-1.0, similarity.itemSimilarity(0, 1));
+  }
+
+  /** Items 1 and 2 are never rated by the same user, so NaN is expected. */
+  @Test
+  public void testNoItemCorrelation2() throws Exception {
+    Double[][] prefs = {{null, 1.0, null}, {null, null, 1.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertTrue(Double.isNaN(similarity.itemSimilarity(1, 2)));
+  }
+
+  /** Item-based counterpart of {@link #testNoCorrelation3()} using three users. */
+  @Test
+  public void testNoItemCorrelation3() throws Exception {
+    Double[][] prefs = {{90.0, 70.0}, {80.0, 80.0}, {70.0, 90.0}};
+    DataModel model = getDataModel(new long[] {1, 2, 3}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertCorrelationEquals(-1.0, similarity.itemSimilarity(0, 1));
+  }
+
+  /** Item-based counterpart of {@link #testSimple()} using three users. */
+  @Test
+  public void testSimpleItem() throws Exception {
+    Double[][] prefs = {{1.0, 2.0}, {2.0, 5.0}, {3.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2, 3}, prefs);
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    assertCorrelationEquals(0.9607689228305227, similarity.itemSimilarity(0, 1));
+  }
+
+  /** Same data as {@link #testSimpleItem()}, weighted and accessed through {@link ItemSimilarity}. */
+  @Test
+  public void testSimpleItemWeighted() throws Exception {
+    Double[][] prefs = {{1.0, 2.0}, {2.0, 5.0}, {3.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2, 3}, prefs);
+    ItemSimilarity weighted = new PearsonCorrelationSimilarity(model, Weighting.WEIGHTED);
+    assertCorrelationEquals(0.9901922307076306, weighted.itemSimilarity(0, 1));
+  }
+
+  /** Refreshing with a {@code null} collection must complete without throwing. */
+  @Test
+  public void testRefresh() throws Exception {
+    PearsonCorrelationSimilarity similarity = new PearsonCorrelationSimilarity(getDataModel());
+    similarity.refresh(null);
+  }
+
+  /** Pins the user correlation when an inferrer that always answers 1.0 is installed. */
+  @Test
+  public void testInferrer() throws Exception {
+    Double[][] prefs = {
+        {null, 1.0, 2.0, null, null, 6.0},
+        {1.0, 8.0, null, 3.0, 4.0, null},
+    };
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+
+    // Stub inferrer: every inferred preference is 1.0 and refreshing does nothing.
+    PreferenceInferrer alwaysOne = new PreferenceInferrer() {
+      @Override
+      public float inferPreference(long userID, long itemID) {
+        return 1.0f;
+      }
+      @Override
+      public void refresh(Collection<Refreshable> alreadyRefreshed) {
+      }
+    };
+
+    UserSimilarity similarity = new PearsonCorrelationSimilarity(model);
+    similarity.setPreferenceInferrer(alwaysOne);
+
+    assertEquals(-0.435285750066007, similarity.userSimilarity(1L, 2L), EPSILON);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java
new file mode 100644
index 0000000..ad1e4b7
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java
@@ -0,0 +1,35 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+
+/** Base class for the similarity tests; supplies a shared correlation assertion. */
+abstract class SimilarityTestCase extends TasteTestCase {
+
+  /**
+   * Asserts that {@code actual} matches {@code expected}. A NaN expectation only requires
+   * a NaN result; otherwise the result must not be NaN, must lie within [-1.0, 1.0],
+   * and must equal {@code expected} within {@code EPSILON}.
+   */
+  static void assertCorrelationEquals(double expected, double actual) {
+    boolean actualIsNaN = Double.isNaN(actual);
+    if (Double.isNaN(expected)) {
+      assertTrue("Correlation is not NaN", actualIsNaN);
+      return;
+    }
+    assertTrue("Correlation is NaN", !actualIsNaN);
+    assertTrue("Correlation > 1.0", 1.0 >= actual);
+    assertTrue("Correlation < -1.0", -1.0 <= actual);
+    assertEquals(expected, actual, EPSILON);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java
new file mode 100644
index 0000000..6034f4b
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.junit.Test;
+
+/** Tests {@link SpearmanCorrelationSimilarity} user-user similarity. */
+public final class SpearmanCorrelationSimilarityTest extends SimilarityTestCase {
+
+  /** Two users with identical preferences must correlate at 1.0. */
+  @Test
+  public void testFullCorrelation1() throws Exception {
+    Double[][] prefs = {{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    SpearmanCorrelationSimilarity similarity = new SpearmanCorrelationSimilarity(model);
+    assertCorrelationEquals(1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Different values in the same relative order must also correlate at 1.0. */
+  @Test
+  public void testFullCorrelation2() throws Exception {
+    Double[][] prefs = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    SpearmanCorrelationSimilarity similarity = new SpearmanCorrelationSimilarity(model);
+    assertCorrelationEquals(1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Preferences in exactly reversed order must correlate at -1.0. */
+  @Test
+  public void testAnticorrelation() throws Exception {
+    Double[][] prefs = {{1.0, 2.0, 3.0}, {3.0, 2.0, 1.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    SpearmanCorrelationSimilarity similarity = new SpearmanCorrelationSimilarity(model);
+    assertCorrelationEquals(-1.0, similarity.userSimilarity(1, 2));
+  }
+
+  /** Pins the correlation for a partially reordered preference list. */
+  @Test
+  public void testSimple() throws Exception {
+    Double[][] prefs = {{1.0, 2.0, 3.0}, {2.0, 3.0, 1.0}};
+    DataModel model = getDataModel(new long[] {1, 2}, prefs);
+    SpearmanCorrelationSimilarity similarity = new SpearmanCorrelationSimilarity(model);
+    assertCorrelationEquals(-0.5, similarity.userSimilarity(1, 2));
+  }
+
+  /** Refreshing with a {@code null} collection must complete without throwing. */
+  @Test
+  public void testRefresh() {
+    SpearmanCorrelationSimilarity similarity = new SpearmanCorrelationSimilarity(getDataModel());
+    similarity.refresh(null);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java
new file mode 100644
index 0000000..87f82b9
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.junit.Test;
+
+/** <p>Tests {@link TanimotoCoefficientSimilarity}.</p> */
+public final class TanimotoCoefficientSimilarityTest extends SimilarityTestCase {
+
+ @Test
+ public void testNoCorrelation() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 2.0, 3.0},
+ {1.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(Double.NaN, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0},
+ {1.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {1.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(0.3333333333333333, correlation);
+ }
+
+ @Test
+ public void testCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 2.0, 3.0},
+ {1.0, 1.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertEquals(0.3333333333333333, correlation, EPSILON);
+ }
+
+ @Test
+ public void testCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 2.0, 3.0, 1.0},
+ {1.0, 1.0, null, 0.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertEquals(0.5, correlation, EPSILON);
+ }
+
+ @Test
+ public void testRefresh() {
+ // Make sure this doesn't throw an exception
+ new TanimotoCoefficientSimilarity(getDataModel()).refresh(null);
+ }
+
+ @Test
+ public void testReturnNaNDoubleWhenNoSimilaritiesForTwoItems() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, null, 3.0},
+ {1.0, 1.0, null},
+ });
+ Double similarity = new TanimotoCoefficientSimilarity(dataModel).itemSimilarity(1, 2);
+ assertEquals(Double.NaN, similarity, EPSILON);
+ }
+
+ @Test
+ public void testItemsSimilarities() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {2.0, null, 2.0},
+ {1.0, 1.0, 1.0},
+ });
+ TanimotoCoefficientSimilarity tCS = new TanimotoCoefficientSimilarity(dataModel);
+ assertEquals(0.5, tCS.itemSimilarity(0, 1), EPSILON);
+ assertEquals(1, tCS.itemSimilarity(0, 2), EPSILON);
+
+ double[] similarities = tCS.itemSimilarities(0, new long [] {1, 2});
+ assertEquals(0.5, similarities[0], EPSILON);
+ assertEquals(1, similarities[1], EPSILON);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java
new file mode 100644
index 0000000..d9d28ab
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.file;
+
+import java.io.File;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity.ItemItemSimilarity;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * <p>Tests {@link FileItemSimilarity}.</p>
+ *
+ * <p>Each test starts from a fresh temp file containing {@link #DATA}; the refresh tests then
+ * overwrite it with {@link #CHANGED_DATA}.</p>
+ */
+public final class FileItemSimilarityTest extends TasteTestCase {
+
+  /** Initial file content: one {@code itemID,itemID,similarity} triple per line. */
+  private static final String[] DATA = {
+      "1,5,0.125",
+      "1,7,0.5" };
+
+  /** Replacement content: changes the (1,7) similarity to 0.9 and adds the pair (7,8). */
+  private static final String[] CHANGED_DATA = {
+      "1,5,0.125",
+      "1,7,0.9",
+      "7,8,0.112" };
+
+  /**
+   * Pause before rewriting the file. We have to wait at least a second to see the change in the
+   * file's lastModified timestamp.
+   */
+  private static final long MODIFICATION_WAIT_MS = 2000L;
+
+  private File testFile;
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    testFile = getTestTempFile("test.txt");
+    writeLines(testFile, DATA);
+  }
+
+  /** Similarities loaded from the file must be symmetric; unknown pairs must be NaN. */
+  @Test
+  public void testLoadFromFile() throws Exception {
+    ItemSimilarity similarity = new FileItemSimilarity(testFile);
+    assertMatchesInitialData(similarity);
+  }
+
+  /** Changing the file on disk must not affect an already loaded instance until it is refreshed. */
+  @Test
+  public void testNoRefreshAfterFileUpdate() throws Exception {
+    ItemSimilarity similarity = new FileItemSimilarity(testFile, 0L);
+
+    /* call a method to make sure the original file is loaded */
+    similarity.itemSimilarity(1L, 5L);
+
+    overwriteWithChangedData();
+
+    /* we shouldn't see any changes in the data as we have not yet refreshed */
+    assertEquals(0.5, similarity.itemSimilarity(1L, 7L), EPSILON);
+    assertEquals(0.5, similarity.itemSimilarity(7L, 1L), EPSILON);
+    assertTrue(Double.isNaN(similarity.itemSimilarity(7L, 8L)));
+  }
+
+  /** After {@code refresh(null)} the instance must expose the rewritten file content. */
+  @Test
+  public void testRefreshAfterFileUpdate() throws Exception {
+    ItemSimilarity similarity = new FileItemSimilarity(testFile, 0L);
+
+    /* call a method to make sure the original file is loaded */
+    similarity.itemSimilarity(1L, 5L);
+
+    overwriteWithChangedData();
+
+    similarity.refresh(null);
+
+    /* we should now see the changes in the data */
+    assertEquals(0.9, similarity.itemSimilarity(1L, 7L), EPSILON);
+    assertEquals(0.9, similarity.itemSimilarity(7L, 1L), EPSILON);
+    assertEquals(0.125, similarity.itemSimilarity(1L, 5L), EPSILON);
+    assertEquals(0.125, similarity.itemSimilarity(5L, 1L), EPSILON);
+
+    assertFalse(Double.isNaN(similarity.itemSimilarity(7L, 8L)));
+    assertEquals(0.112, similarity.itemSimilarity(7L, 8L), EPSILON);
+    assertEquals(0.112, similarity.itemSimilarity(8L, 7L), EPSILON);
+  }
+
+  /**
+   * A missing file is expected to be rejected with an {@link IllegalArgumentException}
+   * (not a FileNotFoundException, despite the method name).
+   */
+  @Test(expected = IllegalArgumentException.class)
+  public void testFileNotFoundExceptionForNonExistingFile() throws Exception {
+    new FileItemSimilarity(new File("xKsdfksdfsdf"));
+  }
+
+  /** The iterable view of the file must feed a {@link GenericItemSimilarity} with the same data. */
+  @Test
+  public void testFileItemItemSimilarityIterable() throws Exception {
+    Iterable<ItemItemSimilarity> similarityIterable = new FileItemItemSimilarityIterable(testFile);
+    GenericItemSimilarity similarity = new GenericItemSimilarity(similarityIterable);
+    assertMatchesInitialData(similarity);
+  }
+
+  /** {@code toString()} must produce a non-empty description. */
+  @Test
+  public void testToString() throws Exception {
+    ItemSimilarity similarity = new FileItemSimilarity(testFile);
+    assertFalse(similarity.toString().isEmpty());
+  }
+
+  /** Waits for the timestamp granularity to elapse, then replaces the file with {@link #CHANGED_DATA}. */
+  private void overwriteWithChangedData() throws Exception {
+    Thread.sleep(MODIFICATION_WAIT_MS);
+    writeLines(testFile, CHANGED_DATA);
+  }
+
+  /** Asserts that {@code similarity} reflects exactly the pairs listed in {@link #DATA}. */
+  private static void assertMatchesInitialData(ItemSimilarity similarity) throws Exception {
+    assertEquals(0.125, similarity.itemSimilarity(1L, 5L), EPSILON);
+    assertEquals(0.125, similarity.itemSimilarity(5L, 1L), EPSILON);
+    assertEquals(0.5, similarity.itemSimilarity(1L, 7L), EPSILON);
+    assertEquals(0.5, similarity.itemSimilarity(7L, 1L), EPSILON);
+
+    assertTrue(Double.isNaN(similarity.itemSimilarity(7L, 8L)));
+
+    double[] valuesForOne = similarity.itemSimilarities(1L, new long[] { 5L, 7L });
+    assertNotNull(valuesForOne);
+    assertEquals(2, valuesForOne.length);
+    assertEquals(0.125, valuesForOne[0], EPSILON);
+    assertEquals(0.5, valuesForOne[1], EPSILON);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java
new file mode 100644
index 0000000..868e41a
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.precompute;
+
+import java.io.IOException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.TanimotoCoefficientSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItemsWriter;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests {@link MultithreadedBatchItemSimilarities} against a tiny model of two users and four
+ * distinct items (IDs 1-4).
+ */
+public class MultithreadedBatchItemSimilaritiesTest {
+
+  /** Smoke test: computing with a single worker over fewer items than one batch must not throw. */
+  @Test
+  public void lessItemsThanBatchSize() throws Exception {
+    BatchItemSimilarities batchSimilarities = new MultithreadedBatchItemSimilarities(newRecommender(), 10);
+
+    batchSimilarities.computeItemSimilarities(1, 1, mock(SimilarItemsWriter.class));
+  }
+
+  /** Asking for more parallelism than there are batches is expected to fail with an IOException. */
+  @Test(expected = IOException.class)
+  public void higherDegreeOfParallelismThanBatches() throws Exception {
+    BatchItemSimilarities batchSimilarities = new MultithreadedBatchItemSimilarities(newRecommender(), 10);
+
+    // Batch size is 100, so we only get 1 batch from the 4 items, but we use a degreeOfParallelism of 2
+    batchSimilarities.computeItemSimilarities(2, 1, mock(SimilarItemsWriter.class));
+    // Only reached if no exception was thrown; the resulting AssertionError is not the expected type.
+    fail();
+  }
+
+  /**
+   * With the three-argument constructor (third argument 2; presumably the batch size - TODO confirm)
+   * two workers are accepted and the reported number of similarities must be 10.
+   */
+  @Test
+  public void testCorrectNumberOfOutputSimilarities() throws Exception {
+    BatchItemSimilarities batchSimilarities = new MultithreadedBatchItemSimilarities(newRecommender(), 10, 2);
+
+    int numOutputSimilarities = batchSimilarities.computeItemSimilarities(2, 1, mock(SimilarItemsWriter.class));
+    // NOTE(review): 10 presumably is 3 + 3 + 2 + 2 - items 1 and 2 relate to all three others,
+    // while items 3 and 4 are never rated by the same user - confirm.
+    assertEquals(10, numOutputSimilarities);
+  }
+
+  /**
+   * Builds the shared fixture: user 1 prefers items 1, 2 and 3, user 2 prefers items 1, 2 and 4
+   * (all with value 1), wrapped in an item-based recommender using Tanimoto similarity.
+   */
+  private static ItemBasedRecommender newRecommender() throws Exception {
+    FastByIDMap<PreferenceArray> userData = new FastByIDMap<>();
+    userData.put(1, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1, 1, 1),
+        new GenericPreference(1, 2, 1), new GenericPreference(1, 3, 1))));
+    userData.put(2, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2, 1, 1),
+        new GenericPreference(2, 2, 1), new GenericPreference(2, 4, 1))));
+
+    DataModel dataModel = new GenericDataModel(userData);
+    return new GenericItemBasedRecommender(dataModel, new TanimotoCoefficientSimilarity(dataModel));
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsTest.java
new file mode 100644
index 0000000..afce3cf
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsTest.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.recommender.GenericRecommendedItem;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.hamcrest.Matchers;
+import org.junit.Test;
+
+/** Tests that {@link SimilarItems} exposes the items it was built from, in the same order. */
+public class SimilarItemsTest extends TasteTestCase {
+
+  /**
+   * Builds a {@link SimilarItems} for item 1 from eight recommended items (IDs 2-9) and checks
+   * that iterating it yields the same IDs and values, position by position.
+   */
+  @Test
+  public void testIterator() {
+    List<RecommendedItem> recommendedItems = new ArrayList<>();
+    for (long itemId = 2; itemId < 10; itemId++) {
+      // The item ID doubles as the value so every entry is distinguishable.
+      recommendedItems.add(new GenericRecommendedItem(itemId, itemId));
+    }
+
+    SimilarItems similarItems = new SimilarItems(1, recommendedItems);
+
+    assertThat(similarItems.getSimilarItems(), Matchers.<SimilarItem> iterableWithSize(recommendedItems.size()));
+
+    int byHandIndex = 0;
+    for (SimilarItem simItem : similarItems.getSimilarItems()) {
+      RecommendedItem recItem = recommendedItems.get(byHandIndex++);
+      // The input list is the expectation; the iterated item is the value under test.
+      assertEquals(recItem.getItemID(), simItem.getItemID());
+      assertEquals(recItem.getValue(), simItem.getSimilarity(), EPSILON);
+    }
+
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java
new file mode 100644
index 0000000..f037209
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java
@@ -0,0 +1,102 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+/**
+ * Holder for a small labelled text corpus used as classifier test data: sample passages about
+ * the Mahout, Lucene and SpamAssassin projects (the latter labelled "spamassasin" in the data).
+ *
+ * <p>This class only exposes the {@link #DATA} constant and cannot be instantiated.</p>
+ */
+public final class ClassifierData {
+
+ /**
+ * The corpus: each row is a {@code {label, text}} pair. There are four rows for each of the
+ * labels "mahout", "lucene" and "spamassasin", in that order. The strings (including their
+ * spelling and spacing) are test input and should be treated as fixed.
+ */
+ public static final String[][] DATA = {
+ {
+ "mahout",
+ "Mahout's goal is to build scalable machine learning libraries. With scalable we mean: "
+ + "Scalable to reasonably large data sets. Our core algorithms for clustering,"
+ + " classfication and batch based collaborative filtering are implemented on top "
+ + "of Apache Hadoop using the map/reduce paradigm. However we do not restrict "
+ + "contributions to Hadoop based implementations: Contributions that run on"},
+ {
+ "mahout",
+ " a single node or on a non-Hadoop cluster are welcome as well. The core"
+ + " libraries are highly optimized to allow for good performance also for"
+ + " non-distributed algorithms. Scalable to support your business case. "
+ + "Mahout is distributed under a commercially friendly Apache Software license. "
+ + "Scalable community. The goal of Mahout is to build a vibrant, responsive, "},
+ {
+ "mahout",
+ "diverse community to facilitate discussions not only on the project itself"
+ + " but also on potential use cases. Come to the mailing lists to find out more."
+ + " Currently Mahout supports mainly four use cases: Recommendation mining takes "
+ + "users' behavior and from that tries to find items users might like. Clustering "},
+ {
+ "mahout",
+ "takes e.g. text documents and groups them into groups of topically related documents."
+ + " Classification learns from exisiting categorized documents what documents of"
+ + " a specific category look like and is able to assign unlabelled documents to "
+ + "the (hopefully) correct category. Frequent itemset mining takes a set of item"
+ + " groups (terms in a query session, shopping cart content) and identifies, which"
+ + " individual items usually appear together."},
+ {
+ "lucene",
+ "Apache Lucene is a high-performance, full-featured text search engine library"
+ + " written entirely in Java. It is a technology suitable for nearly any application "
+ + "that requires full-text search, especially cross-platform. Apache Lucene is an open source"
+ + " project available for free download. Please use the links on the left to access Lucene. "
+ + "The new version is mostly a cleanup release without any new features. "},
+ {
+ "lucene",
+ "All deprecations targeted to be removed in version 3.0 were removed. If you "
+ + "are upgrading from version 2.9.1 of Lucene, you have to fix all deprecation warnings"
+ + " in your code base to be able to recompile against this version. This is the first Lucene"},
+ {
+ "lucene",
+ " release with Java 5 as a minimum requirement. The API was cleaned up to make use of Java 5's "
+ + "generics, varargs, enums, and autoboxing. New users of Lucene are advised to use this version "
+ + "for new developments, because it has a clean, type safe new API. Upgrading users can now remove"},
+ {
+ "lucene",
+ " unnecessary casts and add generics to their code, too. If you have not upgraded your installation "
+ + "to Java 5, please read the file JRE_VERSION_MIGRATION.txt (please note that this is not related to"
+ + " Lucene 3.0, it will also happen with any previous release when you upgrade your Java environment)."},
+ {
+ "spamassasin",
+ "SpamAssassin is a mail filter to identify spam. It is an intelligent email filter which uses a diverse "
+ + "range of tests to identify unsolicited bulk email, more commonly known as Spam. These tests are applied "
+ + "to email headers and content to classify email using advanced statistical methods. In addition, "},
+ {
+ "spamassasin",
+ "SpamAssassin has a modular architecture that allows other technologies to be quickly wielded against spam"
+ + " and is designed for easy integration into virtually any email system."
+ + "SpamAssassin's practical multi-technique approach, modularity, and extensibility continue to give it an "},
+ {
+ "spamassasin",
+ "advantage over other anti-spam systems. Due to these advantages, SpamAssassin is widely used in all aspects "
+ + "of email management. You can readily find SpamAssassin in use in both email clients and servers, on many "
+ + "different operating systems, filtering incoming as well as outgoing email, and implementing a "
+ + "very broad range "},
+ {
+ "spamassasin",
+ "of policy actions. These installations include service providers, businesses, not-for-profit and "
+ + "educational organizations, and end-user systems. SpamAssassin also forms the basis for numerous "
+ + "commercial anti-spam products available on the market today."}};
+
+
+ /** Not instantiable: this class is a constant holder only. */
+ private ClassifierData() { }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
new file mode 100644
index 0000000..3ffff85
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
@@ -0,0 +1,119 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+/**
+ * Tests {@link ConfusionMatrix}: count bookkeeping, per-label accuracy, export as a
+ * {@link Matrix} and the weighted precision/recall/F1 scores.
+ */
+public final class ConfusionMatrixTest extends MahoutTestCase {
+
+  /** Counts stored for the two explicit labels; {@code VALUES[i][j]} is put for (LABELS[i], LABELS[j]). */
+  private static final int[][] VALUES = {{2, 3}, {10, 20}};
+  private static final String[] LABELS = {"Label1", "Label2"};
+  /** Counts stored for (LABELS[i], default label). */
+  private static final int[] OTHER = {3, 6};
+  private static final String DEFAULT_LABEL = "other";
+
+  /** The filled matrix must report the stored counts, labels and per-label accuracies. */
+  @Test
+  public void testBuild() {
+    ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
+    checkValues(confusionMatrix);
+    checkAccuracy(confusionMatrix);
+  }
+
+  /** The exported matrix must carry row bindings for all three labels and match the label count. */
+  @Test
+  public void testGetMatrix() {
+    ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
+    Matrix m = confusionMatrix.getMatrix();
+    Map<String, Integer> rowLabels = m.getRowLabelBindings();
+
+    assertEquals(confusionMatrix.getLabels().size(), m.numCols());
+    assertTrue(rowLabels.containsKey(LABELS[0]));
+    assertTrue(rowLabels.containsKey(LABELS[1]));
+    assertTrue(rowLabels.containsKey(DEFAULT_LABEL));
+    assertEquals(2, confusionMatrix.getCorrect(LABELS[0]));
+    assertEquals(20, confusionMatrix.getCorrect(LABELS[1]));
+    assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL));
+  }
+
+  /**
+   * Example taken from
+   * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
+   */
+  @Test
+  public void testPrecisionRecallAndF1ScoreAsScikitLearn() {
+    Collection<String> labelList = Arrays.asList("0", "1", "2");
+
+    ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, "DEFAULT");
+    confusionMatrix.putCount("0", "0", 2);
+    confusionMatrix.putCount("1", "0", 1);
+    confusionMatrix.putCount("1", "2", 1);
+    confusionMatrix.putCount("2", "1", 2);
+
+    // Expected values are only given to three decimals, so compare with a matching tolerance.
+    double delta = 0.001;
+    assertEquals(0.222, confusionMatrix.getWeightedPrecision(), delta);
+    assertEquals(0.333, confusionMatrix.getWeightedRecall(), delta);
+    assertEquals(0.266, confusionMatrix.getWeightedF1score(), delta);
+  }
+
+  /** Checks the raw 3x3 count table and the label set of a matrix built by {@code fillConfusionMatrix}. */
+  private static void checkValues(ConfusionMatrix cm) {
+    int[][] counts = cm.getConfusionMatrix();
+    // Result deliberately ignored: this only verifies that toString() does not throw.
+    cm.toString();
+    assertEquals(counts.length, counts[0].length);
+    assertEquals(3, counts.length);
+    assertEquals(VALUES[0][0], counts[0][0]);
+    assertEquals(VALUES[0][1], counts[0][1]);
+    assertEquals(VALUES[1][0], counts[1][0]);
+    assertEquals(VALUES[1][1], counts[1][1]);
+    // Nothing was ever put for the default label's own row, so it must be all zeros.
+    assertArrayEquals(new int[3], counts[2]);
+    assertEquals(OTHER[0], counts[0][2]);
+    assertEquals(OTHER[1], counts[1][2]);
+    assertEquals(3, cm.getLabels().size());
+    assertTrue(cm.getLabels().contains(LABELS[0]));
+    assertTrue(cm.getLabels().contains(LABELS[1]));
+    assertTrue(cm.getLabels().contains(DEFAULT_LABEL));
+  }
+
+  /**
+   * Checks per-label accuracy as a percentage. The expected values are consistent with
+   * 100 * 2 / (2 + 3 + 3) = 25 for "Label1" and 100 * 20 / (10 + 20 + 6) = 55.55... for "Label2";
+   * the default label has no counts of its own and must yield NaN.
+   */
+  private static void checkAccuracy(ConfusionMatrix cm) {
+    Collection<String> labelstrs = cm.getLabels();
+    assertEquals(3, labelstrs.size());
+    assertEquals(25.0, cm.getAccuracy("Label1"), EPSILON);
+    assertEquals(55.5555555, cm.getAccuracy("Label2"), EPSILON);
+    assertTrue(Double.isNaN(cm.getAccuracy("other")));
+  }
+
+  /**
+   * Builds a matrix over {@code labels[0]} and {@code labels[1]} plus {@code defaultLabel}.
+   *
+   * @param values 2x2 counts; {@code values[i][j]} is put for ({@code labels[i]}, {@code labels[j]})
+   * @param labels the two explicit labels
+   * @param defaultLabel label used for the extra column, filled from {@link #OTHER}
+   * @return the populated confusion matrix
+   */
+  private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels, String defaultLabel) {
+    Collection<String> labelList = Lists.newArrayList();
+    labelList.add(labels[0]);
+    labelList.add(labels[1]);
+    ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel);
+
+    // Use the parameters rather than hard-coded label names so the helper honours its arguments.
+    confusionMatrix.putCount(labels[0], labels[0], values[0][0]);
+    confusionMatrix.putCount(labels[0], labels[1], values[0][1]);
+    confusionMatrix.putCount(labels[1], labels[0], values[1][0]);
+    confusionMatrix.putCount(labels[1], labels[1], values[1][1]);
+    confusionMatrix.putCount(labels[0], defaultLabel, OTHER[0]);
+    confusionMatrix.putCount(labels[1], defaultLabel, OTHER[1]);
+    return confusionMatrix;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
new file mode 100644
index 0000000..86234f8
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public class RegressionResultAnalyzerTest extends MahoutTestCase {
+
+  /** Patterns capturing the three numeric metrics, in the order {@code parseAnalysis} returns them. */
+  private static final Pattern[] METRIC_PATTERNS = {
+    Pattern.compile("Correlation coefficient *: *(.*)\n"),
+    Pattern.compile("Mean absolute error *: *(.*)\n"),
+    Pattern.compile("Root mean squared error *: *(.*)\n"),
+  };
+
+  /** Patterns capturing the three instance counters, in the order {@code parseAnalysisCount} returns them. */
+  private static final Pattern[] COUNT_PATTERNS = {
+    Pattern.compile("Predictable Instances *: *(.*)\n"),
+    Pattern.compile("Unpredictable Instances *: *(.*)\n"),
+    Pattern.compile("Total Regressed Instances *: *(.*)\n"),
+  };
+
+  /**
+   * Extracts the correlation coefficient, mean absolute error and root mean squared error from
+   * an analysis report.
+   *
+   * @param analysis textual report to scan
+   * @return the three values in that order, or {@code null} as soon as one of them is missing
+   */
+  private static double[] parseAnalysis(CharSequence analysis) {
+    double[] metrics = new double[METRIC_PATTERNS.length];
+    for (int i = 0; i < METRIC_PATTERNS.length; i++) {
+      Matcher matcher = METRIC_PATTERNS[i].matcher(analysis);
+      if (!matcher.find()) {
+        return null;
+      }
+      metrics[i] = Double.parseDouble(matcher.group(1));
+    }
+    return metrics;
+  }
+
+  /**
+   * Extracts the predictable, unpredictable and total instance counters from an analysis report.
+   *
+   * @param analysis textual report to scan
+   * @return the three counters in that order; a counter that is not found in the report stays 0
+   */
+  private static int[] parseAnalysisCount(CharSequence analysis) {
+    int[] counts = new int[COUNT_PATTERNS.length];
+    for (int i = 0; i < COUNT_PATTERNS.length; i++) {
+      Matcher matcher = COUNT_PATTERNS[i].matcher(analysis);
+      if (matcher.find()) {
+        counts[i] = Integer.parseInt(matcher.group(1));
+      }
+    }
+    return counts;
+  }
+
+  /** Feeds the instances to a fresh analyzer and returns its textual report. */
+  private static String analyze(double[][] instances) {
+    RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer();
+    analyzer.setInstances(instances);
+    return analyzer.toString();
+  }
+
+  @Test
+  public void testAnalyze() {
+    double[][] results = new double[10][2];
+
+    // second column is the first column shifted by one
+    for (int row = 0; row < results.length; row++) {
+      results[row][0] = row;
+      results[row][1] = row + 1;
+    }
+    assertArrayEquals(new double[]{1.0, 1.0, 1.0}, parseAnalysis(analyze(results)), 0);
+
+    // second column becomes the square root of the row index
+    for (int row = 0; row < results.length; row++) {
+      results[row][1] = Math.sqrt(row);
+    }
+    assertArrayEquals(new double[]{0.9573, 2.5694, 3.2848}, parseAnalysis(analyze(results)), 0);
+
+    // first column now decreases while the second keeps increasing
+    for (int row = 0; row < results.length; row++) {
+      results[row][0] = results.length - row;
+    }
+    assertArrayEquals(new double[]{-0.9573, 4.1351, 5.1573}, parseAnalysis(analyze(results)), 0);
+  }
+
+  @Test
+  public void testUnpredictable() {
+    double[][] results = new double[10][2];
+
+    // every row carries NaN in its second column
+    for (int row = 0; row < results.length; row++) {
+      results[row][0] = row;
+      results[row][1] = Double.NaN;
+    }
+    String report = analyze(results);
+    assertNull(parseAnalysis(report));
+    assertArrayEquals(new int[]{0, 10, 10}, parseAnalysisCount(report));
+
+    // replace NaN in all but the last three rows
+    int filledRows = results.length - 3;
+    for (int row = 0; row < filledRows; row++) {
+      results[row][1] = Math.sqrt(row);
+    }
+    report = analyze(results);
+    assertArrayEquals(new double[]{0.9552, 1.4526, 1.9345}, parseAnalysis(report), 0);
+    assertArrayEquals(new int[]{7, 3, 10}, parseAnalysisCount(report));
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
new file mode 100644
index 0000000..036d473
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
@@ -0,0 +1,206 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+@Deprecated
+public final class DecisionForestTest extends MahoutTestCase {
+
+ private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no",
+ "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
+ "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
+ "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
+ "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
+ "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
+ "rainy,71,91,TRUE,no"};
+
+ private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-",
+ "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
+
+ private Random rng;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ rng = RandomUtils.getRandom();
+ }
+
+ private static Data[] generateTrainingDataA() throws DescriptorException {
+ // Dataset
+ Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
+
+ // Training data
+ Data data = DataLoader.loadData(dataset, TRAIN_DATA);
+ @SuppressWarnings("unchecked")
+ List<Instance>[] instances = new List[3];
+ for (int i = 0; i < instances.length; i++) {
+ instances[i] = Lists.newArrayList();
+ }
+ for (int i = 0; i < data.size(); i++) {
+ if (data.get(i).get(0) == 0.0d) {
+ instances[0].add(data.get(i));
+ } else {
+ instances[1].add(data.get(i));
+ }
+ }
+ Data[] datas = new Data[instances.length];
+ for (int i = 0; i < datas.length; i++) {
+ datas[i] = new Data(dataset, instances[i]);
+ }
+
+ return datas;
+ }
+
+ private static Data[] generateTrainingDataB() throws DescriptorException {
+
+ // Training data
+ String[] trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 3 == 0) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ } else if (i % 3 == 1) {
+ trainData[i] = "B," + (i + 20) + ',' + (40 - i);
+ } else {
+ trainData[i] = "C," + (i + 20) + ',' + (i + 20);
+ }
+ }
+ // Dataset
+ Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
+ Data[] datas = new Data[3];
+ datas[0] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 2 == 0) {
+ trainData[i] = "A," + (50 - i) + ',' + (i + 10);
+ } else {
+ trainData[i] = "B," + (i + 10) + ',' + (50 - i);
+ }
+ }
+ datas[1] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[10];
+ for (int i = 0; i < trainData.length; i++) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ }
+ datas[2] = DataLoader.loadData(dataset, trainData);
+
+ return datas;
+ }
+
+ private DecisionForest buildForest(Data[] datas) {
+ List<Node> trees = Lists.newArrayList();
+ for (Data data : datas) {
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ builder.setM(data.getDataset().nbAttributes() - 1);
+ builder.setMinSplitNum(0);
+ builder.setComplemented(false);
+ trees.add(builder.build(rng, data));
+ }
+ return new DecisionForest(trees);
+ }
+
+ @Test
+ public void testClassify() throws DescriptorException {
+ // Training data
+ Data[] datas = generateTrainingDataA();
+ // Build Forest
+ DecisionForest forest = buildForest(datas);
+ // Test data
+ Dataset dataset = datas[0].getDataset();
+ Data testData = DataLoader.loadData(dataset, TEST_DATA);
+
+ double noValue = dataset.valueOf(4, "no");
+ double yesValue = dataset.valueOf(4, "yes");
+ assertEquals(noValue, forest.classify(testData.getDataset(), rng, testData.get(0)), EPSILON);
+ // This one is tie-broken -- 1 is OK too
+ //assertEquals(yesValue, forest.classify(testData.getDataset(), rng, testData.get(1)), EPSILON);
+ assertEquals(noValue, forest.classify(testData.getDataset(), rng, testData.get(2)), EPSILON);
+ }
+
+ @Test
+ public void testClassifyData() throws DescriptorException {
+ // Training data
+ Data[] datas = generateTrainingDataA();
+ // Build Forest
+ DecisionForest forest = buildForest(datas);
+ // Test data
+ Dataset dataset = datas[0].getDataset();
+ Data testData = DataLoader.loadData(dataset, TEST_DATA);
+
+ double[][] predictions = new double[testData.size()][];
+ forest.classify(testData, predictions);
+ double noValue = dataset.valueOf(4, "no");
+ double yesValue = dataset.valueOf(4, "yes");
+ assertArrayEquals(new double[][]{{noValue, Double.NaN, Double.NaN},
+ {noValue, yesValue, Double.NaN}, {noValue, noValue, Double.NaN}}, predictions);
+ }
+
+ @Test
+ public void testRegression() throws DescriptorException {
+ Data[] datas = generateTrainingDataB();
+ DecisionForest[] forests = new DecisionForest[datas.length];
+ for (int i = 0; i < datas.length; i++) {
+ Data[] subDatas = new Data[datas.length - 1];
+ int k = 0;
+ for (int j = 0; j < datas.length; j++) {
+ if (j != i) {
+ subDatas[k] = datas[j];
+ k++;
+ }
+ }
+ forests[i] = buildForest(subDatas);
+ }
+
+ double[][] predictions = new double[datas[0].size()][];
+ forests[0].classify(datas[0], predictions);
+ assertArrayEquals(new double[]{20.0, 20.0}, predictions[0], EPSILON);
+ assertArrayEquals(new double[]{39.0, 29.0}, predictions[1], EPSILON);
+ assertArrayEquals(new double[]{Double.NaN, 29.0}, predictions[2], EPSILON);
+ assertArrayEquals(new double[]{Double.NaN, 23.0}, predictions[17], EPSILON);
+
+ predictions = new double[datas[1].size()][];
+ forests[1].classify(datas[1], predictions);
+ assertArrayEquals(new double[]{30.0, 29.0}, predictions[19], EPSILON);
+
+ predictions = new double[datas[2].size()][];
+ forests[2].classify(datas[2], predictions);
+ assertArrayEquals(new double[]{29.0, 28.0}, predictions[9], EPSILON);
+
+ assertEquals(20.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(0)), EPSILON);
+ assertEquals(34.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(1)), EPSILON);
+ assertEquals(29.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(2)), EPSILON);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java
new file mode 100644
index 0000000..56b4787
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import java.lang.reflect.Method;
+import java.util.Random;
+import java.util.Arrays;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+@Deprecated
+public final class DecisionTreeBuilderTest extends MahoutTestCase {
+
+  /**
+   * Makes sure that DecisionTreeBuilder.randomAttributes() returns the correct number of
+   * attributes, none of which has been selected yet.
+   */
+  @Test
+  public void testRandomAttributes() throws Exception {
+    Random rng = RandomUtils.getRandom();
+    int nbAttributes = rng.nextInt(100) + 1;
+    boolean[] selected = new boolean[nbAttributes];
+
+    // randomAttributes() is reached through reflection; the lookup does not depend on the loop
+    // state, so it is done once up front
+    Method randomAttributes = DecisionTreeBuilder.class.getDeclaredMethod("randomAttributes",
+        Random.class, boolean[].class, int.class);
+    randomAttributes.setAccessible(true);
+
+    for (int nloop = 0; nloop < 100; nloop++) {
+      Arrays.fill(selected, false);
+
+      // randomly select some attributes; nbAttributes may be 1, and Random.nextInt(0) throws
+      // IllegalArgumentException, so in that case nothing is pre-selected
+      int nbSelected = nbAttributes > 1 ? rng.nextInt(nbAttributes - 1) : 0;
+      for (int index = 0; index < nbSelected; index++) {
+        int attr;
+        do {
+          attr = rng.nextInt(nbAttributes);
+        } while (selected[attr]);
+
+        selected[attr] = true;
+      }
+
+      int m = rng.nextInt(nbAttributes);
+
+      int[] attrs = (int[]) randomAttributes.invoke(null, rng, selected, m);
+
+      assertNotNull(attrs);
+      assertEquals(Math.min(m, nbAttributes - nbSelected), attrs.length);
+
+      for (int attr : attrs) {
+        // the attribute should not be already selected
+        assertFalse("an attribute has already been selected", selected[attr]);
+
+        // each attribute should be in the range [0, nbAttributes[
+        assertTrue(attr >= 0);
+        assertTrue(attr < nbAttributes);
+
+        // each attribute should appear only once
+        assertEquals(ArrayUtils.indexOf(attrs, attr), ArrayUtils.lastIndexOf(attrs, attr));
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java
new file mode 100644
index 0000000..87fd44b
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import java.util.Random;
+import java.util.Arrays;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+@Deprecated
+public final class DefaultTreeBuilderTest extends MahoutTestCase {
+
+  /**
+   * Makes sure that DefaultTreeBuilder.randomAttributes() returns the correct number of
+   * attributes, none of which has been selected yet.
+   */
+  @Test
+  public void testRandomAttributes() throws Exception {
+    Random rng = RandomUtils.getRandom();
+    int nbAttributes = rng.nextInt(100) + 1;
+    boolean[] selected = new boolean[nbAttributes];
+
+    for (int nloop = 0; nloop < 100; nloop++) {
+      Arrays.fill(selected, false);
+
+      // randomly select some attributes; nbAttributes may be 1, and Random.nextInt(0) throws
+      // IllegalArgumentException, so in that case nothing is pre-selected
+      int nbSelected = nbAttributes > 1 ? rng.nextInt(nbAttributes - 1) : 0;
+      for (int index = 0; index < nbSelected; index++) {
+        int attr;
+        do {
+          attr = rng.nextInt(nbAttributes);
+        } while (selected[attr]);
+
+        selected[attr] = true;
+      }
+
+      int m = rng.nextInt(nbAttributes);
+
+      int[] attrs = DefaultTreeBuilder.randomAttributes(rng, selected, m);
+
+      assertNotNull(attrs);
+      assertEquals(Math.min(m, nbAttributes - nbSelected), attrs.length);
+
+      for (int attr : attrs) {
+        // the attribute should not be already selected
+        assertFalse("an attribute has already been selected", selected[attr]);
+
+        // each attribute should be in the range [0, nbAttributes[
+        assertTrue(attr >= 0);
+        assertTrue(attr < nbAttributes);
+
+        // each attribute should appear only once
+        assertEquals(ArrayUtils.indexOf(attrs, attr), ArrayUtils.lastIndexOf(attrs, attr));
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
new file mode 100644
index 0000000..8ebc721
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Utils;
+import org.junit.Test;
+
+import java.util.Random;
+@Deprecated
+public final class InfiniteRecursionTest extends MahoutTestCase {
+
+  /** Four rows whose attribute columns repeat pairwise while the last column differs. */
+  private static final double[][] SAMPLES = {
+    { 0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 4 },
+    { 0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 3 },
+    { 0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 4 },
+    { 0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 3 }
+  };
+
+  private static final String DESCRIPTOR = "N N N N N N N N L";
+
+  /**
+   * Makes sure DecisionTreeBuilder.build() completes on this data without overflowing the stack
+   * (StackOverflowError), both with the dataset's regression flag cleared and with it set.
+   */
+  @Test
+  public void testBuild() throws Exception {
+    Random rng = RandomUtils.getRandom();
+    String[] source = Utils.double2String(SAMPLES);
+
+    // first pass: regression flag off; second pass: regression flag on
+    for (boolean regression : new boolean[] {false, true}) {
+      Dataset dataset = DataLoader.generateDataset(DESCRIPTOR, regression, source);
+      Data data = DataLoader.loadData(dataset, source);
+      TreeBuilder builder = new DecisionTreeBuilder();
+      builder.build(rng, data);
+    }
+  }
+}