You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/04 14:29:10 UTC
[08/53] [abbrv] [partial] mahout git commit: end of day 6-2-2018
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java
new file mode 100644
index 0000000..dfae61d
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+@Deprecated
+public final class DataConverterTest extends MahoutTestCase {
+
+  private static final int ATTRIBUTE_COUNT = 10;
+
+  private static final int INSTANCE_COUNT = 100;
+
+  /**
+   * Checks that a {@code DataConverter} built from a dataset converts every input line into an
+   * instance equal to the one {@code DataLoader.loadData} produced for that same line. The check
+   * runs once in classification mode and once in regression mode, sharing one random descriptor.
+   */
+  @Test
+  public void testConvert() throws Exception {
+    Random rng = RandomUtils.getRandom();
+    String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
+
+    // classification pass first, then the regression pass
+    assertConverterMatchesLoader(rng, descriptor, false);
+    assertConverterMatchesLoader(rng, descriptor, true);
+  }
+
+  /**
+   * Generates random data for {@code descriptor}, loads it, and compares each loaded instance with
+   * the converter's output for the corresponding line.
+   *
+   * @param regression value forwarded to {@code Utils.randomDoubles} and {@code DataLoader.generateDataset}
+   */
+  private static void assertConverterMatchesLoader(Random rng, String descriptor, boolean regression)
+      throws Exception {
+    double[][] raw = Utils.randomDoubles(rng, descriptor, regression, INSTANCE_COUNT);
+    String[] lines = Utils.double2String(raw);
+    Dataset dataset = DataLoader.generateDataset(descriptor, regression, lines);
+    Data loaded = DataLoader.loadData(dataset, lines);
+
+    DataConverter converter = new DataConverter(dataset);
+
+    int row = 0;
+    while (row < loaded.size()) {
+      assertEquals(loaded.get(row), converter.convert(lines[row]));
+      row++;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java
new file mode 100644
index 0000000..aeb69fc
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java
@@ -0,0 +1,350 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.util.Collection;
+import java.util.Random;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.Dataset.Attribute;
+import org.junit.Test;
+@Deprecated
+public final class DataLoaderTest extends MahoutTestCase {
+
+ private Random rng;
+
+  /** Runs the base-class set-up, then fetches the random generator shared by the tests of this class. */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    this.rng = RandomUtils.getRandom();
+  }
+
+  /**
+   * Loads randomly generated in-memory data and verifies both the loaded instances and the
+   * generated dataset, first in classification mode and then in regression mode.
+   */
+  @Test
+  public void testLoadDataWithDescriptor() throws Exception {
+    final int attributeCount = 10;
+    final int instanceCount = 100;
+
+    // one descriptor is shared by both passes
+    String descriptor = Utils.randomDescriptor(rng, attributeCount);
+    Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+    for (boolean regression : new boolean[] {false, true}) {
+      double[][] raw = Utils.randomDoubles(rng, descriptor, regression, instanceCount);
+      Collection<Integer> missings = Lists.newArrayList();
+      String[] lines = prepareData(raw, attrs, missings);
+      Dataset dataset = DataLoader.generateDataset(descriptor, regression, lines);
+      Data loaded = DataLoader.loadData(dataset, lines);
+
+      testLoadedData(raw, attrs, missings, loaded);
+      testLoadedDataset(raw, attrs, missings, loaded);
+    }
+  }
+
+  /**
+   * Test method for
+   * {@link DataLoader#generateDataset(CharSequence, boolean, String[])}: two calls with the same
+   * arguments must yield equal datasets, in classification mode and in regression mode.
+   */
+  @Test
+  public void testGenerateDataset() throws Exception {
+    final int attributeCount = 10;
+    final int instanceCount = 100;
+
+    String descriptor = Utils.randomDescriptor(rng, attributeCount);
+    Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+    for (boolean regression : new boolean[] {false, true}) {
+      double[][] raw = Utils.randomDoubles(rng, descriptor, regression, instanceCount);
+      Collection<Integer> missings = Lists.newArrayList();
+      String[] lines = prepareData(raw, attrs, missings);
+
+      Dataset first = DataLoader.generateDataset(descriptor, regression, lines);
+      Dataset second = DataLoader.generateDataset(descriptor, regression, lines);
+
+      assertEquals(first, second);
+    }
+  }
+
+  /**
+   * Renders every row of {@code data} as a string in which each value is followed by a comma.
+   * A row selected for a missing value gets {@code '?'} in place of one non-IGNORED attribute
+   * and its index is recorded in {@code missings}.
+   *
+   * @param missings receives the indexes of the rows that were given a missing value
+   * @return one string per row of {@code data}
+   */
+  private String[] prepareData(double[][] data, Attribute[] attrs, Collection<Integer> missings) {
+    int width = attrs.length;
+    String[] lines = new String[data.length];
+
+    for (int row = 0; row < data.length; row++) {
+      int hole = -1;
+
+      // One draw is consumed per row. With a threshold of 0.0 this branch is not expected to run,
+      // because Random.nextDouble() is documented to return values >= 0.0.
+      if (rng.nextDouble() < 0.0) {
+        missings.add(row);
+
+        // keep drawing until the chosen attribute is not IGNORED
+        do {
+          hole = rng.nextInt(width);
+        } while (attrs[hole].isIgnored());
+      }
+
+      StringBuilder line = new StringBuilder();
+      for (int col = 0; col < width; col++) {
+        if (col == hole) {
+          line.append('?');
+        } else {
+          line.append(data[row][col]);
+        }
+        line.append(',');
+      }
+
+      lines[row] = line.toString();
+    }
+
+    return lines;
+  }
+
+  /**
+   * Verifies that {@code loaded} holds the rows of {@code data}, skipping the rows listed in
+   * {@code missings} and the IGNORED attributes.
+   *
+   * @param missings indexes of the source rows that are not expected in {@code loaded}
+   */
+  static void testLoadedData(double[][] data, Attribute[] attrs, Collection<Integer> missings, Data loaded) {
+    int width = attrs.length;
+
+    // rows with a missing value are not loaded
+    assertEquals("number of instance", data.length - missings.size(), loaded.size());
+
+    int loadedRow = 0;
+    for (int row = 0; row < data.length; row++) {
+      if (missings.contains(row)) {
+        continue;
+      }
+
+      double[] expected = data[row];
+      Instance actual = loaded.get(loadedRow++);
+
+      // position of the attribute inside the loaded instance (IGNORED attributes take no slot)
+      int loadedAttr = 0;
+      for (int attr = 0; attr < width; attr++) {
+        Attribute type = attrs[attr];
+        if (type.isIgnored()) {
+          continue;
+        }
+
+        boolean compareAsNumber;
+        if (type.isNumerical()) {
+          compareAsNumber = true;
+        } else if (type.isCategorical()) {
+          compareAsNumber = false;
+        } else if (type.isLabel()) {
+          // the label is compared as a number only when the dataset reports it as numerical
+          compareAsNumber = loaded.getDataset().isNumerical(loadedAttr);
+        } else {
+          continue;
+        }
+
+        if (compareAsNumber) {
+          assertEquals(expected[attr], actual.get(loadedAttr), EPSILON);
+        } else {
+          checkCategorical(data, missings, loaded, attr, loadedAttr, expected[attr], actual.get(loadedAttr));
+        }
+        loadedAttr++;
+      }
+    }
+  }
+
+  /**
+   * Verifies that the dataset attached to {@code loaded} agrees with the source: attribute kinds
+   * must match, and each non-numerical value must equal the code the dataset's {@code valueOf}
+   * returns for the source value's string form.
+   *
+   * @param missings indexes of the source rows that are not expected in {@code loaded}
+   */
+  static void testLoadedDataset(double[][] data,
+                                Attribute[] attrs,
+                                Collection<Integer> missings,
+                                Data loaded) {
+    int width = attrs.length;
+    Dataset dataset = loaded.getDataset();
+
+    int loadedRow = 0;
+    for (int row = 0; row < data.length; row++) {
+      if (missings.contains(row)) {
+        continue;
+      }
+
+      Instance instance = loaded.get(loadedRow);
+      loadedRow++;
+
+      int loadedAttr = 0;
+      for (int attr = 0; attr < width; attr++) {
+        Attribute type = attrs[attr];
+        if (type.isIgnored()) {
+          continue;
+        }
+
+        boolean checkCode;
+        if (type.isLabel()) {
+          // a numerical label carries no code to look up
+          checkCode = !dataset.isNumerical(loadedAttr);
+        } else {
+          assertEquals(type.isNumerical(), dataset.isNumerical(loadedAttr));
+          checkCode = type.isCategorical();
+        }
+
+        if (checkCode) {
+          double stored = instance.get(loadedAttr);
+          String original = Double.toString(data[row][attr]);
+          assertEquals(dataset.valueOf(loadedAttr, original), stored, EPSILON);
+        }
+
+        loadedAttr++;
+      }
+    }
+  }
+
+  /**
+   * Writes randomly generated data to a test file, loads it back through the file system and
+   * compares the result with the source, in classification mode and then in regression mode.
+   */
+  @Test
+  public void testLoadDataFromFile() throws Exception {
+    final int attributeCount = 10;
+    final int instanceCount = 100;
+
+    String descriptor = Utils.randomDescriptor(rng, attributeCount);
+    Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+    for (boolean regression : new boolean[] {false, true}) {
+      double[][] raw = Utils.randomDoubles(rng, descriptor, regression, instanceCount);
+      Collection<Integer> missings = Lists.newArrayList();
+      String[] lines = prepareData(raw, attrs, missings);
+      Dataset dataset = DataLoader.generateDataset(descriptor, regression, lines);
+
+      Path file = Utils.writeDataToTestFile(lines);
+      FileSystem fileSystem = file.getFileSystem(getConfiguration());
+      Data loaded = DataLoader.loadData(dataset, fileSystem, file);
+
+      testLoadedData(raw, attrs, missings, loaded);
+    }
+  }
+
+  /**
+   * Test method for
+   * {@link DataLoader#generateDataset(CharSequence, boolean, FileSystem, Path)}: the dataset
+   * generated from a file must equal the dataset generated from the same lines held in memory.
+   *
+   * <p>The check runs in classification mode and in regression mode. The regression pass used to
+   * pass {@code false} for the regression flag everywhere, so it merely repeated the
+   * classification pass; it now really uses {@code true}.
+   */
+  @Test
+  public void testGenerateDatasetFromFile() throws Exception {
+    final int attributeCount = 10;
+    final int instanceCount = 100;
+
+    // prepare the descriptors (shared by both passes)
+    String descriptor = Utils.randomDescriptor(rng, attributeCount);
+    Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+    for (boolean regression : new boolean[] {false, true}) {
+      // prepare the data
+      double[][] source = Utils.randomDoubles(rng, descriptor, regression, instanceCount);
+      Collection<Integer> missings = Lists.newArrayList();
+      String[] sData = prepareData(source, attrs, missings);
+      Dataset expected = DataLoader.generateDataset(descriptor, regression, sData);
+
+      Path path = Utils.writeDataToTestFile(sData);
+      FileSystem fs = path.getFileSystem(getConfiguration());
+
+      Dataset dataset = DataLoader.generateDataset(descriptor, regression, fs, path);
+
+      assertEquals(expected, dataset);
+    }
+  }
+
+  /**
+   * Checks that the mapping between a source value and its loaded value is consistent across all
+   * loaded rows: wherever {@code oValue} appears in column {@code attr} of the source the loaded
+   * column {@code aId} must hold {@code nValue}, and wherever another source value appears the
+   * loaded value must differ from {@code nValue}.
+   *
+   * @param missings indexes of the source rows that have no counterpart in {@code loaded}
+   * @param attr attribute's index in source
+   * @param aId attribute's index in loaded
+   * @param oValue old value in source
+   * @param nValue new value in loaded
+   */
+  static void checkCategorical(double[][] source,
+                               Collection<Integer> missings,
+                               Data loaded,
+                               int attr,
+                               int aId,
+                               double oValue,
+                               double nValue) {
+    int loadedRow = 0;
+
+    for (int row = 0; row < source.length; row++) {
+      if (missings.contains(row)) {
+        // skipped rows do not advance the position in the loaded data
+        continue;
+      }
+
+      double loadedValue = loaded.get(loadedRow).get(aId);
+      loadedRow++;
+
+      if (source[row][attr] == oValue) {
+        assertEquals(nValue, loadedValue, EPSILON);
+      } else {
+        assertFalse(nValue == loadedValue);
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java
new file mode 100644
index 0000000..70ed7f6
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java
@@ -0,0 +1,396 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.junit.Test;
+@Deprecated
+public class DataTest extends MahoutTestCase {
+
+ private static final int ATTRIBUTE_COUNT = 10;
+
+ private static final int DATA_SIZE = 100;
+
+ private Random rng;
+
+ private Data classifierData;
+
+ private Data regressionData;
+
+  /**
+   * Runs the base-class set-up, then creates the random generator and the two random data sets
+   * (classification first, regression second) used by the tests.
+   */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    this.rng = RandomUtils.getRandom();
+    this.classifierData = Utils.randomData(this.rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+    this.regressionData = Utils.randomData(this.rng, ATTRIBUTE_COUNT, true, DATA_SIZE);
+  }
+
+  /**
+   * Test method for
+   * {@link org.apache.mahout.classifier.df.data.Data#subset(org.apache.mahout.classifier.df.data.conditions.Condition)}.
+   */
+  @Test
+  public void testSubset() {
+    final int rounds = 10;
+
+    for (int round = 0; round < rounds; round++) {
+      // within every round: classification data first, regression data second
+      checkSubsetsAroundRandomValue(classifierData);
+      checkSubsetsAroundRandomValue(regressionData);
+    }
+  }
+
+  /**
+   * Picks a random attribute and one of its values, builds the "equals", "lesser" and
+   * "greater or equals" subsets for that value and checks the membership of the first
+   * {@code DATA_SIZE} instances of {@code data} in each subset.
+   */
+  private void checkSubsetsAroundRandomValue(Data data) {
+    int attr = rng.nextInt(data.getDataset().nbAttributes());
+    double[] candidates = data.values(attr);
+    double pivot = candidates[rng.nextInt(candidates.length)];
+
+    Data equal = data.subset(Condition.equals(attr, pivot));
+    Data lesser = data.subset(Condition.lesser(attr, pivot));
+    Data greaterOrEqual = data.subset(Condition.greaterOrEquals(attr, pivot));
+
+    for (int row = 0; row < DATA_SIZE; row++) {
+      Instance instance = data.get(row);
+      double actual = instance.get(attr);
+      boolean below = actual < pivot;
+      boolean same = actual == pivot;
+
+      // an instance is in "lesser" iff below, in "equal" iff same, in "greaterOrEqual" iff not below
+      assertEquals(below, lesser.contains(instance));
+      assertEquals(same, equal.contains(instance));
+      assertEquals(!below, greaterOrEqual.contains(instance));
+    }
+  }
+
+  /**
+   * For both data sets, every attribute value held by one of the first {@code DATA_SIZE}
+   * instances must occur exactly once in the array returned by {@code values(attr)}.
+   */
+  @Test
+  public void testValues() throws Exception {
+    for (Data data : new Data[] {classifierData, regressionData}) {
+      for (int attr = 0; attr < data.getDataset().nbAttributes(); attr++) {
+        double[] distinct = data.values(attr);
+
+        for (int row = 0; row < DATA_SIZE; row++) {
+          double held = data.get(row).get(attr);
+          assertEquals(1, count(distinct, held));
+        }
+      }
+    }
+  }
+
+  /**
+   * Counts the elements of {@code values} that compare equal ({@code ==}) to {@code value}.
+   *
+   * @return the number of matching elements
+   */
+  private static int count(double[] values, double value) {
+    int matches = 0;
+    for (int i = 0; i < values.length; i++) {
+      matches += (values[i] == value) ? 1 : 0;
+    }
+    return matches;
+  }
+
+ @Test
+ public void testIdenticalTrue() throws Exception {
+ // generate a small data, only to get the dataset
+ Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset();
+
+ // test empty data
+ Data empty = new Data(dataset);
+ assertTrue(empty.isIdentical());
+
+ // test identical data, except for the labels
+ Data identical = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+ Instance model = identical.get(0);
+ for (int index = 1; index < DATA_SIZE; index++) {
+ for (int attr = 0; attr < identical.getDataset().nbAttributes(); attr++) {
+ identical.get(index).set(attr, model.get(attr));
+ }
+ }
+
+ assertTrue(identical.isIdentical());
+ }
+
+  /**
+   * After adding 1 to a random attribute of a random instance of freshly generated data,
+   * {@code isIdentical()} must be false. Repeated ten times.
+   */
+  @Test
+  public void testIdenticalFalse() throws Exception {
+    final int rounds = 10;
+
+    for (int round = 0; round < rounds; round++) {
+      Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+
+      // random instance, then random attribute (the draw order matters for reproducibility)
+      Instance victim = data.get(rng.nextInt(DATA_SIZE));
+      int attr = rng.nextInt(data.getDataset().nbAttributes());
+
+      double bumped = victim.get(attr) + 1;
+      victim.set(attr, bumped);
+
+      assertFalse(data.isIdentical());
+    }
+  }
+
+  /**
+   * {@code identicalLabel()} must hold for an empty data set and for data generated with one
+   * shared label value.
+   */
+  @Test
+  public void testIdenticalLabelTrue() throws Exception {
+    // a one-instance data set is generated only to obtain a dataset description
+    Dataset emptyDataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset();
+    assertTrue(new Data(emptyDataset).identicalLabel());
+
+    // all rows share the same randomly drawn label
+    String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
+    int sharedLabel = rng.nextInt();
+    double[][] raw = Utils.randomDoublesWithSameLabel(rng, descriptor, false, DATA_SIZE, sharedLabel);
+    String[] lines = Utils.double2String(raw);
+
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, lines);
+    Data loaded = DataLoader.loadData(dataset, lines);
+
+    assertTrue(loaded.identicalLabel());
+  }
+
+  /**
+   * Starting from rows that all share one label, incrementing the label of a single random row
+   * must make {@code identicalLabel()} false. Repeated ten times.
+   */
+  @Test
+  public void testIdenticalLabelFalse() throws Exception {
+    final int rounds = 10;
+
+    for (int round = 0; round < rounds; round++) {
+      String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
+      int labelColumn = Utils.findLabel(descriptor);
+
+      int sharedLabel = rng.nextInt();
+      double[][] raw = Utils.randomDoublesWithSameLabel(rng, descriptor, false, DATA_SIZE, sharedLabel);
+
+      // give one random row a different label
+      int changedRow = rng.nextInt(DATA_SIZE);
+      raw[changedRow][labelColumn] += 1;
+
+      String[] lines = Utils.double2String(raw);
+      Dataset dataset = DataLoader.generateDataset(descriptor, false, lines);
+      Data loaded = DataLoader.loadData(dataset, lines);
+
+      assertFalse(loaded.identicalLabel());
+    }
+  }
+
+  /**
+   * Test method for
+   * {@link org.apache.mahout.classifier.df.data.Data#bagging(java.util.Random)}.
+   */
+  @Test
+  public void testBagging() {
+    // classification data first, then regression data
+    checkBag(classifierData);
+    checkBag(regressionData);
+  }
+
+  /**
+   * Draws a bag from {@code data} and checks that it has the same size as {@code data} while
+   * leaving out at least one of its instances.
+   */
+  private void checkBag(Data data) {
+    Data bag = data.bagging(rng);
+
+    // the bag should have the same size as the data
+    assertEquals(data.size(), bag.size());
+
+    // look for the first instance of the data that did not make it into the bag
+    boolean someLeftOut = false;
+    int row = 0;
+    while (row < data.size()) {
+      if (!bag.contains(data.get(row))) {
+        someLeftOut = true;
+        break;
+      }
+      row++;
+    }
+
+    assertTrue("some instances from data should not be in the bag", someLeftOut);
+  }
+
+  /**
+   * Test method for
+   * {@link org.apache.mahout.classifier.df.data.Data#rsplit(java.util.Random, int)}.
+   */
+  @Test
+  public void testRsplit() {
+    // classification data first, then regression data
+    checkRsplitSizes(classifierData);
+    checkRsplitSizes(regressionData);
+  }
+
+  /**
+   * Splits fresh clones of {@code original} with a subset size of 0, of {@code DATA_SIZE} and of
+   * a random value below {@code DATA_SIZE}, checking the sizes of both parts each time.
+   */
+  private void checkRsplitSizes(Data original) {
+    // rsplit should handle empty subsets
+    Data remaining = original.clone();
+    Data taken = remaining.rsplit(rng, 0);
+    assertTrue("subset should be empty", taken.isEmpty());
+    assertEquals("source.size is incorrect", DATA_SIZE, remaining.size());
+
+    // rsplit should handle full size subsets
+    remaining = original.clone();
+    taken = remaining.rsplit(rng, DATA_SIZE);
+    assertEquals("subset.size is incorrect", DATA_SIZE, taken.size());
+    assertTrue("source should be empty", remaining.isEmpty());
+
+    // random case
+    int wanted = rng.nextInt(DATA_SIZE);
+    remaining = original.clone();
+    taken = remaining.rsplit(rng, wanted);
+    assertEquals("subset.size is incorrect", wanted, taken.size());
+    assertEquals("source.size is incorrect", DATA_SIZE - wanted, remaining.size());
+  }
+
+  /**
+   * Checks {@code Data.countLabels}: after the counts are filled in, each instance's label entry
+   * is decremented once, so every entry of {@code counts} must end up at zero.
+   *
+   * <p>The final loop used to assert on {@code counts[0]} at every iteration, which left all
+   * other labels unchecked; it now asserts on {@code counts[label]}.
+   */
+  @Test
+  public void testCountLabel() throws Exception {
+    Dataset dataset = classifierData.getDataset();
+    int[] counts = new int[dataset.nblabels()];
+
+    int n = 10;
+
+    for (int nloop = 0; nloop < n; nloop++) {
+      Arrays.fill(counts, 0);
+      classifierData.countLabels(counts);
+
+      // undo the counting by hand, one instance at a time
+      for (int index = 0; index < classifierData.size(); index++) {
+        counts[(int) dataset.getLabel(classifierData.get(index))]--;
+      }
+
+      // every label must balance out, not only the first one
+      for (int label = 0; label < dataset.nblabels(); label++) {
+        assertEquals("Wrong label 'equals' count", 0, counts[label]);
+      }
+    }
+  }
+
+ /**
+ * Checks {@code Data.majorityLabel} in three stages that all mutate the same {@code source}
+ * array: every row carries label1, then 51 of the 100 rows carry label2, then the labels are
+ * split 50/50. The order of the random draws is part of the scenario.
+ */
+ @Test
+ public void testMajorityLabel() throws Exception {
+
+ // all instances have the same label
+ String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
+ int label = Utils.findLabel(descriptor);
+
+ int label1 = rng.nextInt();
+ double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100,
+ label1);
+ String[] sData = Utils.double2String(source);
+
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+ Data data = DataLoader.loadData(dataset, sData);
+
+ int code1 = dataset.labelCode(Double.toString(label1));
+
+ assertEquals(code1, data.majorityLabel(rng));
+
+ // 51/100 vectors have label2
+ int label2 = label1 + 1;
+ int nblabel2 = 51;
+ // keep drawing random rows; only a row that does not yet carry label2 counts towards the 51
+ while (nblabel2 > 0) {
+ double[] vector = source[rng.nextInt(100)];
+ if (vector[label] != label2) {
+ vector[label] = label2;
+ nblabel2--;
+ }
+ }
+ // the dataset is regenerated here because label2 is a value it has not seen before
+ sData = Utils.double2String(source);
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
+ data = DataLoader.loadData(dataset, sData);
+ int code2 = dataset.labelCode(Double.toString(label2));
+
+ // label2 should be the majority label
+ assertEquals(code2, data.majorityLabel(rng));
+
+ // 50 vectors with label1 and 50 vectors with label2
+ // flip exactly one label2 row back to label1 (51 - 1 = 50 and 49 + 1 = 50)
+ do {
+ double[] vector = source[rng.nextInt(100)];
+ if (vector[label] == label2) {
+ vector[label] = label1;
+ break;
+ }
+ } while (true);
+ sData = Utils.double2String(source);
+
+ // the data is reloaded against the dataset generated in the previous stage
+ data = DataLoader.loadData(dataset, sData);
+ code1 = dataset.labelCode(Double.toString(label1));
+ code2 = dataset.labelCode(Double.toString(label2));
+
+ // majorityLabel should return label1 and label2 at random
+ boolean found1 = false;
+ boolean found2 = false;
+ // at most 10 calls; the loop stops early once both codes have been returned
+ for (int index = 0; index < 10 && (!found1 || !found2); index++) {
+ int major = data.majorityLabel(rng);
+ if (major == code1) {
+ found1 = true;
+ }
+ if (major == code2) {
+ found2 = true;
+ }
+ }
+ assertTrue(found1 && found2);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java
new file mode 100644
index 0000000..e5c9ee7
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package org.apache.mahout.classifier.df.data;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+@Deprecated
+public final class DatasetTest extends MahoutTestCase {
+
+  /**
+   * Builds a regression dataset from the descriptor {@code "N C I L"}, checks its attribute
+   * layout, then checks that the layout survives a {@code toJSON()} / {@code fromJSON()} round trip.
+   */
+  @Test
+  public void jsonEncoding() throws DescriptorException {
+    Dataset to = DataLoader.generateDataset("N C I L", true, new String[]{"1 foo 2 3", "4 bar 5 6"});
+
+    // dataset as generated
+    assertEquals(3, to.nbAttributes());
+    assertEquals(1, to.getIgnored().length);
+    assertEquals(2, to.getIgnored()[0]);
+    assertEquals(2, to.getLabelId());
+    assertTrue(to.isNumerical(0));
+
+    // from JSON
+    Dataset fromJson = Dataset.fromJSON(to.toJSON());
+    assertEquals(3, fromJson.nbAttributes());
+    assertEquals(1, fromJson.getIgnored().length);
+    assertEquals(2, fromJson.getIgnored()[0]);
+    assertTrue(fromJson.isNumerical(0));
+
+    // read values for a nominal
+    assertNotEquals(fromJson.valueOf(1, "bar"), fromJson.valueOf(1, "foo"));
+  }
+
+  /**
+   * Same round trip as {@link #jsonEncoding()} but in classification mode, so the label is nominal
+   * too: nominal values located before and after the ignored attribute must keep distinct codes.
+   */
+  @Test
+  public void jsonEncodingIgnoreFeatures() throws DescriptorException {
+    Dataset to = DataLoader.generateDataset("N C I L", false, new String[]{"1 foo 2 Red", "4 bar 5 Blue"});
+
+    // dataset as generated
+    assertEquals(3, to.nbAttributes());
+    assertEquals(1, to.getIgnored().length);
+    assertEquals(2, to.getIgnored()[0]);
+    assertEquals(2, to.getLabelId());
+    assertTrue(to.isNumerical(0));
+    assertNotEquals(to.valueOf(1, "bar"), to.valueOf(1, "foo"));
+    assertNotEquals(to.valueOf(2, "Red"), to.valueOf(2, "Blue"));
+
+    // from JSON
+    Dataset fromJson = Dataset.fromJSON(to.toJSON());
+    assertEquals(3, fromJson.nbAttributes());
+    assertEquals(1, fromJson.getIgnored().length);
+    assertEquals(2, fromJson.getIgnored()[0]);
+    assertTrue(fromJson.isNumerical(0));
+
+    // read values for a nominal, one before and one after the ignore feature
+    assertNotEquals(fromJson.valueOf(1, "bar"), fromJson.valueOf(1, "foo"));
+    assertNotEquals(fromJson.valueOf(2, "Red"), fromJson.valueOf(2, "Blue"));
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java
new file mode 100644
index 0000000..619f067
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.Dataset.Attribute;
+import org.junit.Test;
+/**
+ * Tests for {@link DescriptorUtils}: parsing a descriptor into attributes and
+ * expanding a description into a descriptor.
+ */
+@Deprecated
+public final class DescriptorUtilsTest extends MahoutTestCase {
+
+  /**
+   * Test method for
+   * {@link org.apache.mahout.classifier.df.data.DescriptorUtils#parseDescriptor(java.lang.CharSequence)}.
+   * Parses randomly generated descriptors and checks that each parsed attribute has
+   * the type announced by the token at the same position.
+   */
+  @Test
+  public void testParseDescriptor() throws Exception {
+    int n = 10;
+    int maxnbAttributes = 100;
+
+    Random rng = RandomUtils.getRandom();
+
+    for (int nloop = 0; nloop < n; nloop++) {
+      int nbAttributes = rng.nextInt(maxnbAttributes) + 1;
+
+      char[] tokens = Utils.randomTokens(rng, nbAttributes);
+      Attribute[] attrs = DescriptorUtils.parseDescriptor(Utils.generateDescriptor(tokens));
+
+      // verify that the attributes match the token list
+      assertEquals("attributes size", nbAttributes, attrs.length);
+
+      for (int attr = 0; attr < nbAttributes; attr++) {
+        switch (tokens[attr]) {
+          case 'I':
+            assertTrue("attribute " + attr + " should be ignored", attrs[attr].isIgnored());
+            break;
+          case 'N':
+            assertTrue("attribute " + attr + " should be numerical", attrs[attr].isNumerical());
+            break;
+          case 'C':
+            assertTrue("attribute " + attr + " should be categorical", attrs[attr].isCategorical());
+            break;
+          case 'L':
+            assertTrue("attribute " + attr + " should be the label", attrs[attr].isLabel());
+            break;
+          default:
+            // a token this test does not know about must not pass unchecked
+            fail("unexpected token '" + tokens[attr] + "' at position " + attr);
+        }
+      }
+    }
+  }
+
+  /**
+   * Checks that {@code DescriptorUtils.generateDescriptor} expands a number followed
+   * by a token into that many copies of the token, tolerates surrounding spaces, and
+   * rejects two numbers in a row as well as a negative number.
+   */
+  @Test
+  public void testGenerateDescription() throws Exception {
+    validate("", "");
+    validate("I L C C N N N C", "I L C C N N N C");
+    validate("I L C C N N N C", "I L 2 C 3 N C");
+    validate("I L C C N N N C", " I L 2 C 3 N C ");
+
+    try {
+      validate("", "I L 2 2 C 2 N C");
+      fail("2 consecutive multiplicators");
+    } catch (DescriptorException expected) {
+      // expected: the description is malformed, nothing more to check
+    }
+
+    try {
+      validate("", "I L 2 C -2 N C");
+      fail("negative multiplicator");
+    } catch (DescriptorException expected) {
+      // expected: the description is malformed, nothing more to check
+    }
+  }
+
+  /**
+   * Asserts that expanding {@code description} yields exactly {@code descriptor}.
+   *
+   * @param descriptor expected expanded descriptor
+   * @param description compact description handed to {@code generateDescriptor}
+   * @throws DescriptorException if the description is rejected
+   */
+  private static void validate(String descriptor, CharSequence description) throws DescriptorException {
+    assertEquals(descriptor, DescriptorUtils.generateDescriptor(description));
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java
new file mode 100644
index 0000000..9b51ec9
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java
@@ -0,0 +1,284 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Random;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.data.Dataset.Attribute;
+import org.apache.mahout.common.MahoutTestCase;
+
+/**
+ * Helper methods used by the tests: random descriptor and data generation,
+ * conversion of numeric rows to text lines, and writing/splitting of test data.
+ */
+@Deprecated
+public final class Utils {
+
+  /** Used when generating random CATEGORICAL values */
+  private static final int CATEGORICAL_RANGE = 100;
+
+  private Utils() {}
+
+  /**
+   * Generates a random list of tokens
+   * <ul>
+   * <li>10% of the attributes are IGNORED ('I')</li>
+   * <li>40% of the attributes are NUMERICAL ('N')</li>
+   * <li>50% of the attributes are CATEGORICAL ('C')</li>
+   * <li>one randomly chosen position is then overwritten with the LABEL ('L')</li>
+   * </ul>
+   *
+   * @param rng Random number generator
+   * @param nbTokens number of tokens to generate; must be positive because the label
+   *        position is drawn with {@code rng.nextInt(nbTokens)}
+   * @return the generated tokens, exactly one of which is 'L'
+   */
+  public static char[] randomTokens(Random rng, int nbTokens) {
+    char[] result = new char[nbTokens];
+
+    for (int token = 0; token < nbTokens; token++) {
+      double rand = rng.nextDouble();
+      if (rand < 0.1) {
+        result[token] = 'I'; // IGNORED
+      } else if (rand >= 0.5) {
+        result[token] = 'C'; // CATEGORICAL
+      } else {
+        result[token] = 'N'; // NUMERICAL
+      }
+    }
+
+    // choose the label; it replaces whatever token was drawn at that position
+    result[rng.nextInt(nbTokens)] = 'L';
+
+    return result;
+  }
+
+  /**
+   * Generates a space-separated String that contains all the tokens. Every token,
+   * including the last one, is followed by a single space.
+   *
+   * @param tokens tokens to join
+   * @return the tokens in order, each followed by a space
+   */
+  public static String generateDescriptor(char[] tokens) {
+    StringBuilder builder = new StringBuilder();
+
+    for (char token : tokens) {
+      builder.append(token).append(' ');
+    }
+
+    return builder.toString();
+  }
+
+  /**
+   * Generates a random descriptor by joining the result of
+   * {@link #randomTokens(Random, int)} with {@link #generateDescriptor(char[])}:<br>
+   * <ul>
+   * <li>10% of the attributes are IGNORED</li>
+   * <li>40% of the attributes are NUMERICAL</li>
+   * <li>50% of the attributes are CATEGORICAL</li>
+   * <li>one randomly chosen attribute becomes the LABEL</li>
+   * </ul>
+   *
+   * @param rng Random number generator
+   * @param nbAttributes number of attributes; must be positive
+   */
+  public static String randomDescriptor(Random rng, int nbAttributes) {
+    return generateDescriptor(randomTokens(rng, nbAttributes));
+  }
+
+  /**
+   * generates random data based on the given descriptor
+   *
+   * @param rng Random number generator
+   * @param descriptor attributes description
+   * @param regression true if the label should be numerical
+   * @param number number of data lines to generate
+   * @return {@code number} rows, one value per descriptor attribute
+   * @throws DescriptorException if the descriptor cannot be parsed
+   */
+  public static double[][] randomDoubles(Random rng, CharSequence descriptor, boolean regression, int number)
+    throws DescriptorException {
+    Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+    double[][] data = new double[number][];
+
+    for (int index = 0; index < number; index++) {
+      data[index] = randomVector(rng, attrs, regression);
+    }
+
+    return data;
+  }
+
+  /**
+   * Generates random data: a random descriptor, random rows for it, and the
+   * {@link Data} loaded from those rows.
+   *
+   * @param rng Random number generator
+   * @param nbAttributes number of attributes
+   * @param regression true if the label should be numerical
+   * @param size data size
+   * @throws DescriptorException if the generated descriptor cannot be parsed
+   */
+  public static Data randomData(Random rng, int nbAttributes, boolean regression, int size) throws DescriptorException {
+    String descriptor = randomDescriptor(rng, nbAttributes);
+    double[][] source = randomDoubles(rng, descriptor, regression, size);
+    String[] sData = double2String(source);
+    Dataset dataset = DataLoader.generateDataset(descriptor, regression, sData);
+
+    return DataLoader.loadData(dataset, sData);
+  }
+
+  /**
+   * generates a random vector based on the given attributes.<br>
+   * the attributes' values are generated as follows :<br>
+   * <ul>
+   * <li>each IGNORED attribute receives a Double.NaN</li>
+   * <li>each NUMERICAL attribute receives a random double</li>
+   * <li>each CATEGORICAL attribute receives a random integer in the
+   * range [0, CATEGORICAL_RANGE[</li>
+   * <li>the LABEL receives a random double when {@code regression} is true, a random
+   * integer in the range [0, CATEGORICAL_RANGE[ otherwise</li>
+   * </ul>
+   *
+   * @param rng Random number generator
+   * @param attrs attributes description
+   * @param regression true if the label should be numerical
+   */
+  private static double[] randomVector(Random rng, Attribute[] attrs, boolean regression) {
+    double[] vector = new double[attrs.length];
+
+    for (int attr = 0; attr < attrs.length; attr++) {
+      if (attrs[attr].isIgnored()) {
+        vector[attr] = Double.NaN;
+      } else if (attrs[attr].isNumerical()) {
+        vector[attr] = rng.nextDouble();
+      } else if (attrs[attr].isCategorical()) {
+        vector[attr] = rng.nextInt(CATEGORICAL_RANGE);
+      } else { // LABEL
+        if (regression) {
+          vector[attr] = rng.nextDouble();
+        } else {
+          vector[attr] = rng.nextInt(CATEGORICAL_RANGE);
+        }
+      }
+    }
+
+    return vector;
+  }
+
+  /**
+   * converts a double array to a comma-separated string; every value, including the
+   * last one, is followed by a comma
+   *
+   * @param v double array
+   * @return comma-separated string
+   */
+  private static String double2String(double[] v) {
+    StringBuilder builder = new StringBuilder();
+
+    for (double aV : v) {
+      builder.append(aV).append(',');
+    }
+
+    return builder.toString();
+  }
+
+  /**
+   * converts an array of double arrays to an array of comma-separated strings
+   *
+   * @param source array of double arrays
+   * @return array of comma-separated strings, one per input row
+   */
+  public static String[] double2String(double[][] source) {
+    String[] output = new String[source.length];
+
+    for (int index = 0; index < source.length; index++) {
+      output[index] = double2String(source[index]);
+    }
+
+    return output;
+  }
+
+  /**
+   * Generates random data with same label value
+   *
+   * @param rng Random number generator
+   * @param descriptor attributes description
+   * @param regression true if the label should be numerical
+   * @param number data size
+   * @param value label value
+   * @throws DescriptorException if the descriptor cannot be parsed
+   */
+  public static double[][] randomDoublesWithSameLabel(Random rng,
+                                                      CharSequence descriptor,
+                                                      boolean regression,
+                                                      int number,
+                                                      int value) throws DescriptorException {
+    int label = findLabel(descriptor);
+
+    double[][] source = randomDoubles(rng, descriptor, regression, number);
+
+    // overwrite the randomly drawn label of every row
+    for (int index = 0; index < number; index++) {
+      source[index][label] = value;
+    }
+
+    return source;
+  }
+
+  /**
+   * finds the label attribute's index
+   *
+   * @param descriptor attributes description
+   * @return position of the LABEL attribute in the parsed descriptor, as reported by
+   *         {@code ArrayUtils.indexOf}
+   * @throws DescriptorException if the descriptor cannot be parsed
+   */
+  public static int findLabel(CharSequence descriptor) throws DescriptorException {
+    Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+    return ArrayUtils.indexOf(attrs, Attribute.LABEL);
+  }
+
+  /**
+   * Writes every string of {@code sData} to the file named by {@code path}, encoded
+   * as UTF-8, each followed by a newline character.
+   *
+   * @param sData lines to write
+   * @param path destination; its string form is used as a local file name
+   * @throws IOException if the file cannot be opened, written or closed
+   */
+  private static void writeDataToFile(String[] sData, Path path) throws IOException {
+    BufferedWriter output = null;
+    try {
+      output = Files.newWriter(new File(path.toString()), Charsets.UTF_8);
+      for (String line : sData) {
+        output.write(line);
+        output.write('\n');
+      }
+    } finally {
+      Closeables.close(output, false);
+    }
+
+  }
+
+  /**
+   * Writes the data to "testdata/Data/DataLoaderTest.data", creating the
+   * "testdata/Data" directory first when it does not exist.
+   *
+   * @param sData lines to write
+   * @return path of the written file
+   * @throws IOException if the directory or the file cannot be created or written
+   */
+  public static Path writeDataToTestFile(String[] sData) throws IOException {
+    Path testData = new Path("testdata/Data");
+    // the test case instance is only used as a source of Configuration
+    MahoutTestCase ca = new MahoutTestCase();
+    FileSystem fs = testData.getFileSystem(ca.getConfiguration());
+    if (!fs.exists(testData)) {
+      fs.mkdirs(testData);
+    }
+
+    Path path = new Path(testData, "DataLoaderTest.data");
+
+    writeDataToFile(sData, path);
+
+    return path;
+  }
+
+  /**
+   * Split the data into numMaps splits. Every split but the last holds
+   * {@code sData.length / numMaps} consecutive lines; the last split also receives
+   * the remaining lines.
+   *
+   * @param sData lines to split
+   * @param numMaps number of splits; must be positive
+   * @return {@code numMaps} arrays that together contain every line exactly once
+   * @throws IllegalArgumentException if {@code numMaps} is not positive
+   */
+  public static String[][] splitData(String[] sData, int numMaps) {
+    if (numMaps <= 0) {
+      // fail with a clear message instead of a division by zero or a negative array size
+      throw new IllegalArgumentException("numMaps must be positive: " + numMaps);
+    }
+
+    int nbInstances = sData.length;
+    int partitionSize = nbInstances / numMaps;
+
+    String[][] splits = new String[numMaps][];
+
+    for (int partition = 0; partition < numMaps; partition++) {
+      int from = partition * partitionSize;
+      // the last partition absorbs the remainder of the integer division
+      int to = partition == (numMaps - 1) ? nbInstances : (partition + 1) * partitionSize;
+
+      splits[partition] = Arrays.copyOfRange(sData, from, to);
+    }
+
+    return splits;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java
new file mode 100644
index 0000000..6a17aa2
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit;
+import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemRecordReader;
+import org.junit.Test;
+/**
+ * Tests for {@link InMemInputFormat}: how trees are distributed over the splits and
+ * which keys the record reader produces for a split.
+ */
+@Deprecated
+public final class InMemInputFormatTest extends MahoutTestCase {
+
+  /**
+   * For a random number of splits and trees, checks that the requested number of
+   * splits is returned, that first ids are consecutive, that every split but the
+   * last holds {@code nbTrees / numSplits} trees and that the last split holds the
+   * rest.
+   */
+  @Test
+  public void testSplits() throws Exception {
+    int n = 1;
+    int maxNumSplits = 100;
+    int maxNbTrees = 1000;
+
+    Random rng = RandomUtils.getRandom();
+
+    for (int nloop = 0; nloop < n; nloop++) {
+      int numSplits = rng.nextInt(maxNumSplits) + 1;
+      int nbTrees = rng.nextInt(maxNbTrees) + 1;
+
+      Configuration conf = getConfiguration();
+      Builder.setNbTrees(conf, nbTrees);
+
+      InMemInputFormat inputFormat = new InMemInputFormat();
+      List<InputSplit> splits = inputFormat.getSplits(conf, numSplits);
+
+      assertEquals(numSplits, splits.size());
+
+      int nbTreesPerSplit = nbTrees / numSplits;
+      int totalTrees = 0;
+      int expectedId = 0;
+
+      for (int index = 0; index < numSplits; index++) {
+        assertTrue(splits.get(index) instanceof InMemInputSplit);
+
+        InMemInputSplit split = (InMemInputSplit) splits.get(index);
+
+        assertEquals(expectedId, split.getFirstId());
+
+        if (index < numSplits - 1) {
+          assertEquals(nbTreesPerSplit, split.getNbTrees());
+        } else {
+          // the last split takes whatever the previous splits left over
+          assertEquals(nbTrees - totalTrees, split.getNbTrees());
+        }
+
+        totalTrees += split.getNbTrees();
+        expectedId += split.getNbTrees();
+      }
+    }
+  }
+
+  /**
+   * For every split, checks that the record reader yields one key per tree, numbered
+   * from the split's first id, and then reports that no key is left.
+   */
+  @Test
+  public void testRecordReader() throws Exception {
+    int n = 1;
+    int maxNumSplits = 100;
+    int maxNbTrees = 1000;
+
+    Random rng = RandomUtils.getRandom();
+
+    for (int nloop = 0; nloop < n; nloop++) {
+      int numSplits = rng.nextInt(maxNumSplits) + 1;
+      int nbTrees = rng.nextInt(maxNbTrees) + 1;
+
+      Configuration conf = getConfiguration();
+      Builder.setNbTrees(conf, nbTrees);
+
+      InMemInputFormat inputFormat = new InMemInputFormat();
+      List<InputSplit> splits = inputFormat.getSplits(conf, numSplits);
+
+      for (int index = 0; index < numSplits; index++) {
+        InMemInputSplit split = (InMemInputSplit) splits.get(index);
+        InMemRecordReader reader = new InMemRecordReader(split);
+
+        reader.initialize(split, null);
+
+        for (int tree = 0; tree < split.getNbTrees(); tree++) {
+          // the reader must have a key for every tree of the split
+          assertTrue("reader ran out of keys before tree " + tree, reader.nextKeyValue());
+          assertEquals(split.getFirstId() + tree, reader.getCurrentKey().get());
+        }
+
+        // ... and must stop once there is no tree left
+        assertFalse("reader returned more keys than the split has trees", reader.nextKeyValue());
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java
new file mode 100644
index 0000000..aeea084
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java
@@ -0,0 +1,77 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit;
+import org.junit.Before;
+import org.junit.Test;
+/**
+ * Serialization round-trip tests for {@link InMemInputSplit}: a split written to a
+ * byte buffer must be equal to the split read back from that buffer.
+ */
+@Deprecated
+public final class InMemInputSplitTest extends MahoutTestCase {
+
+  private Random random;
+  private ByteArrayOutputStream buffer;
+  private DataOutput sink;
+
+  /** Creates a fresh random generator and an empty in-memory output for each test. */
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    random = RandomUtils.getRandom();
+    buffer = new ByteArrayOutputStream();
+    sink = new DataOutputStream(buffer);
+  }
+
+  /**
+   * Make sure that all the fields are processed correctly
+   */
+  @Test
+  public void testWritable() throws Exception {
+    assertSurvivesRoundTrip(new InMemInputSplit(random.nextInt(), random.nextInt(1000), random.nextLong()));
+  }
+
+  /**
+   * test the case seed == null
+   */
+  @Test
+  public void testNullSeed() throws Exception {
+    assertSurvivesRoundTrip(new InMemInputSplit(random.nextInt(), random.nextInt(1000), null));
+  }
+
+  /**
+   * Writes the split to the in-memory output, reads a split back from the bytes
+   * written so far and asserts that both are equal.
+   *
+   * @param written split to serialize
+   * @throws IOException if writing or reading fails
+   */
+  private void assertSurvivesRoundTrip(InMemInputSplit written) throws IOException {
+    written.write(sink);
+
+    DataInput source = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()));
+    assertEquals(written, InMemInputSplit.read(source));
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java
new file mode 100644
index 0000000..2821034
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java
@@ -0,0 +1,197 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Random;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.Writer;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.junit.Test;
+@Deprecated
+public final class PartialBuilderTest extends MahoutTestCase {
+
+ private static final int NUM_MAPS = 5;
+
+ private static final int NUM_TREES = 32;
+
+ /** instances per partition */
+ private static final int NUM_INSTANCES = 20;
+
+ @Test
+ public void testProcessOutput() throws Exception {
+ Configuration conf = getConfiguration();
+ conf.setInt("mapred.map.tasks", NUM_MAPS);
+
+ Random rng = RandomUtils.getRandom();
+
+ // prepare the output
+ TreeID[] keys = new TreeID[NUM_TREES];
+ MapredOutput[] values = new MapredOutput[NUM_TREES];
+ int[] firstIds = new int[NUM_MAPS];
+ randomKeyValues(rng, keys, values, firstIds);
+
+ // store the output in a sequence file
+ Path base = getTestTempDirPath("testdata");
+ FileSystem fs = base.getFileSystem(conf);
+
+ Path outputFile = new Path(base, "PartialBuilderTest.seq");
+ Writer writer = SequenceFile.createWriter(fs, conf, outputFile,
+ TreeID.class, MapredOutput.class);
+
+ try {
+ for (int index = 0; index < NUM_TREES; index++) {
+ writer.append(keys[index], values[index]);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+
+ // load the output and make sure its valid
+ TreeID[] newKeys = new TreeID[NUM_TREES];
+ Node[] newTrees = new Node[NUM_TREES];
+
+ PartialBuilder.processOutput(new Job(conf), base, newKeys, newTrees);
+
+ // check the forest
+ for (int tree = 0; tree < NUM_TREES; tree++) {
+ assertEquals(values[tree].getTree(), newTrees[tree]);
+ }
+
+ assertTrue("keys not equal", Arrays.deepEquals(keys, newKeys));
+ }
+
+ /**
+ * Make sure that the builder passes the good parameters to the job
+ *
+ */
+ @Test
+ public void testConfigure() {
+ TreeBuilder treeBuilder = new DefaultTreeBuilder();
+ Path dataPath = new Path("notUsedDataPath");
+ Path datasetPath = new Path("notUsedDatasetPath");
+ Long seed = 5L;
+
+ new PartialBuilderChecker(treeBuilder, dataPath, datasetPath, seed);
+ }
+
+ /**
+ * Generates random (key, value) pairs. Shuffles the partition's order
+ *
+ * @param rng
+ * @param keys
+ * @param values
+ * @param firstIds partitions's first ids in hadoop's order
+ */
+ private static void randomKeyValues(Random rng, TreeID[] keys, MapredOutput[] values, int[] firstIds) {
+ int index = 0;
+ int firstId = 0;
+ Collection<Integer> partitions = Lists.newArrayList();
+
+ for (int p = 0; p < NUM_MAPS; p++) {
+ // select a random partition, not yet selected
+ int partition;
+ do {
+ partition = rng.nextInt(NUM_MAPS);
+ } while (partitions.contains(partition));
+
+ partitions.add(partition);
+
+ int nbTrees = Step1Mapper.nbTrees(NUM_MAPS, NUM_TREES, partition);
+
+ for (int treeId = 0; treeId < nbTrees; treeId++) {
+ Node tree = new Leaf(rng.nextInt(100));
+
+ keys[index] = new TreeID(partition, treeId);
+ values[index] = new MapredOutput(tree, nextIntArray(rng, NUM_INSTANCES));
+
+ index++;
+ }
+
+ firstIds[p] = firstId;
+ firstId += NUM_INSTANCES;
+ }
+
+ }
+
+ private static int[] nextIntArray(Random rng, int size) {
+ int[] array = new int[size];
+ for (int index = 0; index < size; index++) {
+ array[index] = rng.nextInt(101) - 1;
+ }
+
+ return array;
+ }
+
+ static class PartialBuilderChecker extends PartialBuilder {
+
+ private final Long seed;
+
+ private final TreeBuilder treeBuilder;
+
+ private final Path datasetPath;
+
+ PartialBuilderChecker(TreeBuilder treeBuilder, Path dataPath,
+ Path datasetPath, Long seed) {
+ super(treeBuilder, dataPath, datasetPath, seed);
+
+ this.seed = seed;
+ this.treeBuilder = treeBuilder;
+ this.datasetPath = datasetPath;
+ }
+
+ @Override
+ protected boolean runJob(Job job) throws IOException {
+ // no need to run the job, just check if the params are correct
+
+ Configuration conf = job.getConfiguration();
+
+ assertEquals(seed, getRandomSeed(conf));
+
+ // PartialBuilder should detect the 'local' mode and overrides the number
+ // of map tasks
+ assertEquals(1, conf.getInt("mapred.map.tasks", -1));
+
+ assertEquals(NUM_TREES, getNbTrees(conf));
+
+ assertFalse(isOutput(conf));
+
+ assertEquals(treeBuilder, getTreeBuilder(conf));
+
+ assertEquals(datasetPath, getDistributedCacheFile(conf, 0));
+
+ return true;
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
new file mode 100644
index 0000000..c5aec7f
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import org.easymock.EasyMock;
+import java.util.Random;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Utils;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.MahoutTestCase;
+import org.easymock.Capture;
+import org.easymock.CaptureType;
+import org.junit.Test;
+/**
+ * Runs {@link Step1Mapper} over each partition of a random dataset and checks the
+ * tree ids it emits.
+ */
+@Deprecated
+public final class Step1MapperTest extends MahoutTestCase {
+
+  /** nb attributes per generated data instance */
+  static final int NUM_ATTRIBUTES = 4;
+
+  /** nb generated data instances */
+  static final int NUM_INSTANCES = 100;
+
+  /** nb trees to build */
+  static final int NUM_TREES = 10;
+
+  /** nb mappers to use */
+  static final int NUM_MAPPERS = 2;
+
+  /**
+   * Make sure that the data used to build the trees is from the mapper's
+   * partition
+   */
+  private static class MockTreeBuilder implements TreeBuilder {
+
+    private Data expected;
+
+    public void setExpected(Data data) {
+      expected = data;
+    }
+
+    @Override
+    public Node build(Random rng, Data data) {
+      // every instance handed to the builder must belong to the expected partition
+      int position = 0;
+      while (position < data.size()) {
+        assertTrue(expected.contains(data.get(position)));
+        position++;
+      }
+
+      return new Leaf(Double.NaN);
+    }
+  }
+
+  /**
+   * Special Step1Mapper that can be configured without using a Configuration
+   */
+  private static class MockStep1Mapper extends Step1Mapper {
+    private MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed,
+                            int partition, int numMapTasks, int numTrees) {
+      configure(false, treeBuilder, dataset);
+      configure(seed, partition, numMapTasks, numTrees);
+    }
+  }
+
+  /**
+   * Capture that keeps every captured TreeID and stores a clone of each one rather
+   * than the instance it was given.
+   */
+  private static class TreeIDCapture extends Capture<TreeID> {
+
+    private TreeIDCapture() {
+      super(CaptureType.ALL);
+    }
+
+    @Override
+    public void setValue(final TreeID value) {
+      super.setValue(value.clone());
+    }
+  }
+
+  /**
+   * Feeds each partition's lines to a mapper and checks that the mapper starts at
+   * the expected first tree id, emits the expected number of keys, and that the keys
+   * carry the partition number and consecutive tree ids.
+   */
+  @SuppressWarnings({ "rawtypes", "unchecked" })
+  @Test
+  public void testMapper() throws Exception {
+    Random rng = RandomUtils.getRandom();
+
+    // prepare the data
+    String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES);
+    String[] lines = Utils.double2String(Utils.randomDoubles(rng, descriptor, false, NUM_INSTANCES));
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, lines);
+    String[][] partitions = Utils.splitData(lines, NUM_MAPPERS);
+
+    MockTreeBuilder builder = new MockTreeBuilder();
+
+    LongWritable offset = new LongWritable();
+    Text line = new Text();
+
+    // id expected for the next tree, counted across all partitions
+    int nextTreeId = 0;
+
+    for (int partition = 0; partition < NUM_MAPPERS; partition++) {
+      String[] partitionLines = partitions[partition];
+      builder.setExpected(DataLoader.loadData(dataset, partitionLines));
+
+      // expected number of trees that this mapper will build
+      int expectedTrees = Step1Mapper.nbTrees(NUM_MAPPERS, NUM_TREES, partition);
+
+      // record phase: accept any number of writes and capture their keys
+      Mapper.Context context = EasyMock.createNiceMock(Mapper.Context.class);
+      Capture<TreeID> emittedKeys = new TreeIDCapture();
+      context.write(EasyMock.capture(emittedKeys), EasyMock.anyObject());
+      EasyMock.expectLastCall().anyTimes();
+
+      EasyMock.replay(context);
+
+      MockStep1Mapper mapper = new MockStep1Mapper(builder, dataset, null,
+          partition, NUM_MAPPERS, NUM_TREES);
+
+      // make sure the mapper computed firstTreeId correctly
+      assertEquals(nextTreeId, mapper.getFirstTreeId());
+
+      int row = 0;
+      for (String text : partitionLines) {
+        offset.set(row);
+        line.set(text);
+        mapper.map(offset, line, context);
+        row++;
+      }
+
+      mapper.cleanup(context);
+      EasyMock.verify(context);
+
+      // make sure the mapper built all its trees
+      assertEquals(expectedTrees, emittedKeys.getValues().size());
+
+      // check the returned keys
+      for (TreeID emitted : emittedKeys.getValues()) {
+        assertEquals(partition, emitted.partition());
+        assertEquals(nextTreeId, emitted.treeId());
+
+        nextTreeId++;
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java
new file mode 100644
index 0000000..c4beeaf
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+/**
+ * Round-trip test for {@link TreeID}: a (partition, treeId) pair stored in a
+ * {@code TreeID} must be read back unchanged, whether it was supplied through
+ * the two-argument constructor or through {@code set(int, int)}.
+ */
+@Deprecated
+public final class TreeIDTest extends MahoutTestCase {
+
+  /** Number of random (partition, treeId) pairs checked by {@link #testTreeID()}. */
+  private static final int NUM_TRIALS = 1000000;
+
+  /**
+   * Draws random non-negative partitions and tree ids below
+   * {@code TreeID.MAX_TREEID}, then verifies that {@code partition()} and
+   * {@code treeId()} return the values that were stored.
+   */
+  @Test
+  public void testTreeID() {
+    Random rng = RandomUtils.getRandom();
+
+    for (int nloop = 0; nloop < NUM_TRIALS; nloop++) {
+      // Clear the sign bit instead of calling Math.abs(): Math.abs(Integer.MIN_VALUE)
+      // is still Integer.MIN_VALUE, so the old code could produce a negative partition.
+      int partition = rng.nextInt() & Integer.MAX_VALUE;
+      int treeId = rng.nextInt(TreeID.MAX_TREEID);
+
+      // values passed through the constructor
+      TreeID t1 = new TreeID(partition, treeId);
+
+      assertEquals(partition, t1.partition());
+      assertEquals(treeId, t1.treeId());
+
+      // values passed through the setter of a default-constructed instance
+      TreeID t2 = new TreeID();
+      t2.set(partition, treeId);
+
+      assertEquals(partition, t2.partition());
+      assertEquals(treeId, t2.treeId());
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java
new file mode 100644
index 0000000..1300926
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Before;
+import org.junit.Test;
+@Deprecated
+public final class NodeTest extends MahoutTestCase {
+
+ private Random rng;
+
+ private ByteArrayOutputStream byteOutStream;
+ private DataOutput out;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ rng = RandomUtils.getRandom();
+
+ byteOutStream = new ByteArrayOutputStream();
+ out = new DataOutputStream(byteOutStream);
+ }
+
+ /**
+ * Test method for
+ * {@link org.apache.mahout.classifier.df.node.Node#read(java.io.DataInput)}.
+ */
+ @Test
+ public void testReadTree() throws Exception {
+ Node node1 = new CategoricalNode(rng.nextInt(),
+ new double[] { rng.nextDouble(), rng.nextDouble() },
+ new Node[] { new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()) });
+ Node node2 = new NumericalNode(rng.nextInt(), rng.nextDouble(),
+ new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()));
+
+ Node root = new CategoricalNode(rng.nextInt(),
+ new double[] { rng.nextDouble(), rng.nextDouble(), rng.nextDouble() },
+ new Node[] { node1, node2, new Leaf(rng.nextDouble()) });
+
+ // write the node to a DataOutput
+ root.write(out);
+
+ // read the node back
+ assertEquals(root, readNode());
+ }
+
+ Node readNode() throws IOException {
+ ByteArrayInputStream byteInStream = new ByteArrayInputStream(byteOutStream.toByteArray());
+ DataInput in = new DataInputStream(byteInStream);
+ return Node.read(in);
+ }
+
+ @Test
+ public void testReadLeaf() throws Exception {
+
+ Node leaf = new Leaf(rng.nextDouble());
+ leaf.write(out);
+ assertEquals(leaf, readNode());
+ }
+
+ @Test
+ public void testParseNumerical() throws Exception {
+
+ Node node = new NumericalNode(rng.nextInt(), rng.nextDouble(), new Leaf(rng
+ .nextInt()), new Leaf(rng.nextDouble()));
+ node.write(out);
+ assertEquals(node, readNode());
+ }
+
+ @Test
+ public void testCategoricalNode() throws Exception {
+
+ Node node = new CategoricalNode(rng.nextInt(), new double[]{rng.nextDouble(),
+ rng.nextDouble(), rng.nextDouble()}, new Node[]{
+ new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()),
+ new Leaf(rng.nextDouble())});
+
+ node.write(out);
+ assertEquals(node, readNode());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java
new file mode 100644
index 0000000..94d0ad9
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Utils;
+import org.junit.Test;
+/**
+ * Checks {@code DefaultIgSplit.entropy} against hand-computed base-2 entropies
+ * for three label distributions over the same 100 random vectors.
+ */
+@Deprecated
+public final class DefaultIgSplitTest extends MahoutTestCase {
+
+  private static final int NUM_ATTRIBUTES = 10;
+
+  @Test
+  public void testEntropy() throws Exception {
+    Random rng = RandomUtils.getRandom();
+    String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES);
+    int label = Utils.findLabel(descriptor);
+
+    // all the vectors have the same label (0)
+    double[][] vectors = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100, 0);
+    double singleLabel = 0.0 - 1.0 * Math.log(1.0) / Math.log(2.0);
+    assertEquals(singleLabel, entropyOf(descriptor, vectors), EPSILON);
+
+    // 50/100 of the vectors have the label (1)
+    // 50/100 of the vectors have the label (0)
+    relabelFirst(vectors, label, 50, 1.0);
+    double twoLabels = 2.0 * -0.5 * Math.log(0.5) / Math.log(2.0);
+    assertEquals(twoLabels, entropyOf(descriptor, vectors), EPSILON);
+
+    // 15/100 of the vectors have the label (2)
+    // 35/100 of the vectors have the label (1)
+    // 50/100 of the vectors have the label (0)
+    relabelFirst(vectors, label, 15, 2.0);
+    double threeLabels = -0.15 * Math.log(0.15) / Math.log(2.0) - 0.35 * Math.log(0.35)
+        / Math.log(2.0) - 0.5 * Math.log(0.5) / Math.log(2.0);
+    assertEquals(threeLabels, entropyOf(descriptor, vectors), EPSILON);
+  }
+
+  /** Overwrites the label column of the first {@code count} vectors with {@code value}. */
+  private static void relabelFirst(double[][] vectors, int labelColumn, int count, double value) {
+    for (int row = 0; row < count; row++) {
+      vectors[row][labelColumn] = value;
+    }
+  }
+
+  /**
+   * Converts the vectors to strings, builds a fresh dataset and data from them,
+   * and returns the entropy computed by a new {@link DefaultIgSplit}.
+   */
+  private static double entropyOf(String descriptor, double[][] vectors) throws Exception {
+    String[] rows = Utils.double2String(vectors);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, rows);
+    Data data = DataLoader.loadData(dataset, rows);
+    DefaultIgSplit splitter = new DefaultIgSplit();
+    return splitter.entropy(data);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java
new file mode 100644
index 0000000..9c5893a
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java
@@ -0,0 +1,87 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+@Deprecated
+public final class RegressionSplitTest extends MahoutTestCase {
+
+ private static Data[] generateTrainingData() throws DescriptorException {
+ // Training data
+ String[] trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 3 == 0) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ } else if (i % 3 == 1) {
+ trainData[i] = "B," + (i + 20) + ',' + (40 - i);
+ } else {
+ trainData[i] = "C," + (i + 20) + ',' + (i + 20);
+ }
+ }
+ // Dataset
+ Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
+ Data[] datas = new Data[3];
+ datas[0] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 2 == 0) {
+ trainData[i] = "A," + (50 - i) + ',' + (i + 10);
+ } else {
+ trainData[i] = "B," + (i + 10) + ',' + (50 - i);
+ }
+ }
+ datas[1] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[10];
+ for (int i = 0; i < trainData.length; i++) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ }
+ datas[2] = DataLoader.loadData(dataset, trainData);
+
+ return datas;
+ }
+
+ @Test
+ public void testComputeSplit() throws DescriptorException {
+ Data[] datas = generateTrainingData();
+
+ RegressionSplit igSplit = new RegressionSplit();
+ Split split = igSplit.computeSplit(datas[0], 1);
+ assertEquals(180.0, split.getIg(), EPSILON);
+ assertEquals(38.0, split.getSplit(), EPSILON);
+ split = igSplit.computeSplit(datas[0].subset(Condition.lesser(1, 38.0)), 1);
+ assertEquals(76.5, split.getIg(), EPSILON);
+ assertEquals(21.5, split.getSplit(), EPSILON);
+
+ split = igSplit.computeSplit(datas[1], 0);
+ assertEquals(2205.0, split.getIg(), EPSILON);
+ assertEquals(Double.NaN, split.getSplit(), EPSILON);
+ split = igSplit.computeSplit(datas[1].subset(Condition.equals(0, 0.0)), 1);
+ assertEquals(250.0, split.getIg(), EPSILON);
+ assertEquals(41.0, split.getSplit(), EPSILON);
+ }
+}