You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/09 21:32:14 UTC
[02/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75]
Support Sparse Vector Format as the input of RandomForest
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
index 5f8518b..d682093 100644
--- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
@@ -18,15 +18,23 @@
*/
package hivemall.smile.classification;
+import hivemall.classifier.KernelExpansionPassiveAggressiveUDTF;
+import hivemall.utils.codec.Base91;
import hivemall.utils.lang.mutable.MutableInt;
import java.io.BufferedInputStream;
+import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
+import java.io.InputStreamReader;
import java.net.URL;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;
+import java.util.StringTokenizer;
+import java.util.zip.GZIPInputStream;
+
+import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.Collector;
@@ -34,6 +42,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
@@ -43,7 +53,7 @@ import smile.data.parser.ArffParser;
public class RandomForestClassifierUDTFTest {
@Test
- public void testIris() throws IOException, ParseException, HiveException {
+ public void testIrisDense() throws IOException, ParseException, HiveException {
URL url = new URL(
"https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
InputStream is = new BufferedInputStream(url.openStream());
@@ -85,4 +95,278 @@ public class RandomForestClassifierUDTFTest {
Assert.assertEquals(49, count.getValue());
}
+ @Test
+ public void testIrisSparse() throws IOException, ParseException, HiveException {
+ URL url = new URL(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+ int[] y = iris.toArray(new int[size]);
+
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ final List<String> xi = new ArrayList<String>(x[0].length);
+ for (int i = 0; i < size; i++) {
+ double[] row = x[i];
+ for (int j = 0; j < row.length; j++) {
+ xi.add(j + ":" + row[j]);
+ }
+ udtf.process(new Object[] {xi, y[i]});
+ xi.clear();
+ }
+
+ final MutableInt count = new MutableInt(0);
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ count.addValue(1);
+ }
+ };
+
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Assert.assertEquals(49, count.getValue());
+ }
+
+ @Test
+ public void testIrisSparseDenseEquals() throws IOException, ParseException, HiveException {
+ String urlString = "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
+ DecisionTree.Node denseNode = getDecisionTreeFromDenseInput(urlString);
+ DecisionTree.Node sparseNode = getDecisionTreeFromSparseInput(urlString);
+
+ URL url = new URL(urlString);
+ InputStream is = new BufferedInputStream(url.openStream());
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+
+ int diff = 0;
+ for (int i = 0; i < size; i++) {
+ if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) {
+ diff++;
+ }
+ }
+
+ Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
+ }
+
+ private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
+ throws IOException, ParseException, HiveException {
+ URL url = new URL(urlString);
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+ int[] y = iris.toArray(new int[size]);
+
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ final List<Double> xi = new ArrayList<Double>(x[0].length);
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < x[i].length; j++) {
+ xi.add(j, x[i][j]);
+ }
+ udtf.process(new Object[] {xi, y[i]});
+ xi.clear();
+ }
+
+ final Text[] placeholder = new Text[1];
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ Object[] forward = (Object[]) input;
+ placeholder[0] = (Text) forward[2];
+ }
+ };
+
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Text modelTxt = placeholder[0];
+ Assert.assertNotNull(modelTxt);
+
+ byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
+ DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true);
+ return node;
+ }
+
+ private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString)
+ throws IOException, ParseException, HiveException {
+ URL url = new URL(urlString);
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+ int[] y = iris.toArray(new int[size]);
+
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ final List<String> xi = new ArrayList<String>(x[0].length);
+ for (int i = 0; i < size; i++) {
+ final double[] row = x[i];
+ for (int j = 0; j < row.length; j++) {
+ xi.add(j + ":" + row[j]);
+ }
+ udtf.process(new Object[] {xi, y[i]});
+ xi.clear();
+ }
+
+ final Text[] placeholder = new Text[1];
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ Object[] forward = (Object[]) input;
+ placeholder[0] = (Text) forward[2];
+ }
+ };
+
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Text modelTxt = placeholder[0];
+ Assert.assertNotNull(modelTxt);
+
+ byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
+ DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true);
+ return node;
+ }
+
+ @Test
+ public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException {
+ final int numTrees = 10;
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ "-stratified_sampling -seed 71 -trees " + numTrees);
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+
+ BufferedReader news20 = readFile("news20-multiclass.gz");
+ ArrayList<String> features = new ArrayList<String>();
+ String line = news20.readLine();
+ while (line != null) {
+ StringTokenizer tokens = new StringTokenizer(line, " ");
+ int label = Integer.parseInt(tokens.nextToken());
+ while (tokens.hasMoreTokens()) {
+ features.add(tokens.nextToken());
+ }
+ Assert.assertFalse(features.isEmpty());
+ udtf.process(new Object[] {features, label});
+
+ features.clear();
+ line = news20.readLine();
+ }
+ news20.close();
+
+ final MutableInt count = new MutableInt(0);
+ final MutableInt oobErrors = new MutableInt(0);
+ final MutableInt oobTests = new MutableInt(0);
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ Object[] forward = (Object[]) input;
+ oobErrors.addValue(((IntWritable) forward[4]).get());
+ oobTests.addValue(((IntWritable) forward[5]).get());
+ count.addValue(1);
+ }
+ };
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Assert.assertEquals(numTrees, count.getValue());
+ float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
+ // TODO why multi-class classification so bad??
+ Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8);
+ }
+
+ @Test
+ public void testNews20BinarySparse() throws IOException, ParseException, HiveException {
+ final int numTrees = 10;
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-seed 71 -trees "
+ + numTrees);
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ BufferedReader news20 = readFile("news20-small.binary.gz");
+ ArrayList<String> features = new ArrayList<String>();
+ String line = news20.readLine();
+ while (line != null) {
+ StringTokenizer tokens = new StringTokenizer(line, " ");
+ int label = Integer.parseInt(tokens.nextToken());
+ if (label == -1) {
+ label = 0;
+ }
+ while (tokens.hasMoreTokens()) {
+ features.add(tokens.nextToken());
+ }
+ if (!features.isEmpty()) {
+ udtf.process(new Object[] {features, label});
+ features.clear();
+ }
+ line = news20.readLine();
+ }
+ news20.close();
+
+ final MutableInt count = new MutableInt(0);
+ final MutableInt oobErrors = new MutableInt(0);
+ final MutableInt oobTests = new MutableInt(0);
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ Object[] forward = (Object[]) input;
+ oobErrors.addValue(((IntWritable) forward[4]).get());
+ oobTests.addValue(((IntWritable) forward[5]).get());
+ count.addValue(1);
+ }
+ };
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Assert.assertEquals(numTrees, count.getValue());
+ float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
+ Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.3);
+ }
+
+
+ @Nonnull
+ private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+ InputStream is = KernelExpansionPassiveAggressiveUDTF.class.getResourceAsStream(fileName);
+ if (fileName.endsWith(".gz")) {
+ is = new GZIPInputStream(is);
+ }
+ return new BufferedReader(new InputStreamReader(is));
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
index 20f44b3..eae625d 100644
--- a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
@@ -18,7 +18,16 @@
*/
package hivemall.smile.regression;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
+import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.smile.data.Attribute;
+import hivemall.smile.data.Attribute.NumericAttribute;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.junit.Assert;
@@ -30,7 +39,7 @@ import smile.validation.LOOCV;
public class RegressionTreeTest {
@Test
- public void testPredict() {
+ public void testPredictDense() {
double[][] longley = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323},
{259.426, 232.5, 145.6, 108.632, 1948, 61.122},
@@ -53,10 +62,51 @@ public class RegressionTreeTest {
112.6, 114.2, 115.7, 116.9};
Attribute[] attrs = new Attribute[longley[0].length];
- for (int i = 0; i < attrs.length; i++) {
- attrs[i] = new Attribute.NumericAttribute(i);
+ Arrays.fill(attrs, new NumericAttribute());
+
+ int n = longley.length;
+ LOOCV loocv = new LOOCV(n);
+ double rss = 0.0;
+ for (int i = 0; i < n; i++) {
+ double[][] trainx = Math.slice(longley, loocv.train[i]);
+ double[] trainy = Math.slice(y, loocv.train[i]);
+ int maxLeafs = 10;
+ RegressionTree tree = new RegressionTree(attrs, matrix(trainx, true), trainy, maxLeafs,
+ RandomNumberGeneratorFactory.createPRNG(i));
+
+ double r = y[loocv.test[i]] - tree.predict(longley[loocv.test[i]]);
+ rss += r * r;
}
+ Assert.assertTrue("MSE = " + (rss / n), (rss / n) < 42);
+ }
+
+ @Test
+ public void testPredictSparse() {
+
+ double[][] longley = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+ {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+ {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+ {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+ {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+ {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+ {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+ {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+ {397.469, 290.4, 304.8, 117.388, 1955, 66.019},
+ {419.180, 282.2, 285.7, 118.734, 1956, 67.857},
+ {442.769, 293.6, 279.8, 120.445, 1957, 68.169},
+ {444.546, 468.1, 263.7, 121.950, 1958, 66.513},
+ {482.704, 381.3, 255.2, 123.366, 1959, 68.655},
+ {502.601, 393.1, 251.4, 125.368, 1960, 69.564},
+ {518.173, 480.6, 257.2, 127.852, 1961, 69.331},
+ {554.894, 400.7, 282.7, 130.081, 1962, 70.551}};
+
+ double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
+ 112.6, 114.2, 115.7, 116.9};
+
+ Attribute[] attrs = new Attribute[longley[0].length];
+ Arrays.fill(attrs, new NumericAttribute());
+
int n = longley.length;
LOOCV loocv = new LOOCV(n);
double rss = 0.0;
@@ -64,8 +114,8 @@ public class RegressionTreeTest {
double[][] trainx = Math.slice(longley, loocv.train[i]);
double[] trainy = Math.slice(y, loocv.train[i]);
int maxLeafs = 10;
- smile.math.Random rand = new smile.math.Random(i);
- RegressionTree tree = new RegressionTree(attrs, trainx, trainy, maxLeafs, rand);
+ RegressionTree tree = new RegressionTree(attrs, matrix(trainx, false), trainy,
+ maxLeafs, RandomNumberGeneratorFactory.createPRNG(i));
double r = y[loocv.test[i]] - tree.predict(longley[loocv.test[i]]);
rss += r * r;
@@ -98,9 +148,7 @@ public class RegressionTreeTest {
112.6, 114.2, 115.7, 116.9};
Attribute[] attrs = new Attribute[longley[0].length];
- for (int i = 0; i < attrs.length; i++) {
- attrs[i] = new Attribute.NumericAttribute(i);
- }
+ Arrays.fill(attrs, new NumericAttribute());
int n = longley.length;
LOOCV loocv = new LOOCV(n);
@@ -108,7 +156,7 @@ public class RegressionTreeTest {
double[][] trainx = Math.slice(longley, loocv.train[i]);
double[] trainy = Math.slice(y, loocv.train[i]);
int maxLeafs = Integer.MAX_VALUE;
- RegressionTree tree = new RegressionTree(attrs, trainx, trainy, maxLeafs);
+ RegressionTree tree = new RegressionTree(attrs, matrix(trainx, true), trainy, maxLeafs);
byte[] b = tree.predictSerCodegen(true);
RegressionTree.Node node = RegressionTree.deserializeNode(b, b.length, true);
@@ -119,4 +167,19 @@ public class RegressionTreeTest {
Assert.assertEquals(expected, actual, 0.d);
}
}
+
+ @Nonnull
+ private static Matrix matrix(@Nonnull final double[][] x, boolean dense) {
+ if (dense) {
+ return new RowMajorDenseMatrix2d(x, x[0].length);
+ } else {
+ int numRows = x.length;
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ for (int i = 0; i < numRows; i++) {
+ builder.nextRow(x[i]);
+ }
+ return builder.buildMatrix();
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
index 504ea86..65feeeb 100644
--- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
+++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
@@ -18,13 +18,12 @@
*/
package hivemall.smile.tools;
-import static org.junit.Assert.assertEquals;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.data.Attribute;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.smile.vm.StackMachine;
+import hivemall.utils.codec.Base91;
import hivemall.utils.lang.ArrayUtils;
import java.io.BufferedInputStream;
@@ -42,6 +41,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.junit.Assert;
import org.junit.Test;
import smile.data.AttributeDataset;
@@ -49,7 +50,7 @@ import smile.data.parser.ArffParser;
import smile.math.Math;
import smile.validation.CrossValidation;
import smile.validation.LOOCV;
-import smile.validation.Validation;
+import smile.validation.RMSE;
public class TreePredictUDFTest {
private static final boolean DEBUG = false;
@@ -76,8 +77,9 @@ public class TreePredictUDFTest {
int[] trainy = Math.slice(y, loocv.train[i]);
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
- assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]]));
+ DecisionTree tree = new DecisionTree(attrs, new RowMajorDenseMatrix2d(trainx,
+ x[0].length), trainy, 4);
+ Assert.assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]]));
}
}
@@ -103,10 +105,11 @@ public class TreePredictUDFTest {
double[][] testx = Math.slice(datax, cv.test[i]);
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
- RegressionTree tree = new RegressionTree(attrs, trainx, trainy, 20);
+ RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx,
+ trainx[0].length), trainy, 20);
for (int j = 0; j < testx.length; j++) {
- assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0);
+ Assert.assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0);
}
}
}
@@ -142,52 +145,60 @@ public class TreePredictUDFTest {
}
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
- RegressionTree tree = new RegressionTree(attrs, trainx, trainy, 20);
- debugPrint(String.format("RMSE = %.4f\n", Validation.test(tree, testx, testy)));
+ RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx,
+ trainx[0].length), trainy, 20);
+ debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));
for (int i = m; i < n; i++) {
- assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0);
+ Assert.assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0);
}
}
+ private static <T> double rmse(RegressionTree regression, double[][] x, double[] y) {
+ final int n = x.length;
+ final double[] predictions = new double[n];
+ for (int i = 0; i < n; i++) {
+ predictions[i] = regression.predict(x[i]);
+ }
+ return new RMSE().measure(y, predictions);
+ }
+
private static int evalPredict(DecisionTree tree, double[] x) throws HiveException, IOException {
- String opScript = tree.predictOpCodegen(StackMachine.SEP);
- debugPrint(opScript);
+ byte[] b = tree.predictSerCodegen(true);
+ byte[] encoded = Base91.encode(b);
+ Text model = new Text(encoded);
TreePredictUDF udf = new TreePredictUDF();
udf.initialize(new ObjectInspector[] {
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- PrimitiveObjectInspectorFactory.javaIntObjectInspector,
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ PrimitiveObjectInspectorFactory.writableStringObjectInspector,
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
ObjectInspectorUtils.getConstantObjectInspector(
PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)});
DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"),
- new DeferredJavaObject(ModelType.opscode.getId()),
- new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)),
+ new DeferredJavaObject(model), new DeferredJavaObject(ArrayUtils.toList(x)),
new DeferredJavaObject(true)};
- IntWritable result = (IntWritable) udf.evaluate(arguments);
+ Object[] result = (Object[]) udf.evaluate(arguments);
udf.close();
- return result.get();
+ return ((IntWritable) result[0]).get();
}
private static double evalPredict(RegressionTree tree, double[] x) throws HiveException,
IOException {
- String opScript = tree.predictOpCodegen(StackMachine.SEP);
- debugPrint(opScript);
+ byte[] b = tree.predictSerCodegen(true);
+ byte[] encoded = Base91.encode(b);
+ Text model = new Text(encoded);
TreePredictUDF udf = new TreePredictUDF();
udf.initialize(new ObjectInspector[] {
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- PrimitiveObjectInspectorFactory.javaIntObjectInspector,
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ PrimitiveObjectInspectorFactory.writableStringObjectInspector,
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
ObjectInspectorUtils.getConstantObjectInspector(
PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)});
DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"),
- new DeferredJavaObject(ModelType.opscode.getId()),
- new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)),
+ new DeferredJavaObject(model), new DeferredJavaObject(ArrayUtils.toList(x)),
new DeferredJavaObject(false)};
DoubleWritable result = (DoubleWritable) udf.evaluate(arguments);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/vm/StackMachineTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/vm/StackMachineTest.java b/core/src/test/java/hivemall/smile/vm/StackMachineTest.java
deleted file mode 100644
index 4a2dcd8..0000000
--- a/core/src/test/java/hivemall/smile/vm/StackMachineTest.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile.vm;
-
-import static org.junit.Assert.assertEquals;
-import hivemall.utils.io.IOUtils;
-
-import java.io.BufferedInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.net.URL;
-import java.text.ParseException;
-import java.util.ArrayList;
-
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.junit.Assert;
-import org.junit.Test;
-
-public class StackMachineTest {
- private static final boolean DEBUG = false;
-
- @Test
- public void testFindInfinteLoop() throws IOException, ParseException, HiveException,
- VMRuntimeException {
- // Sample of machine code having infinite loop
- ArrayList<String> opScript = new ArrayList<String>();
- opScript.add("push 2.0");
- opScript.add("push 1.0");
- opScript.add("iflt 0");
- opScript.add("push 1");
- opScript.add("call end");
- debugPrint(opScript);
- double[] x = new double[0];
- StackMachine sm = new StackMachine();
- try {
- sm.run(opScript, x);
- Assert.fail("VMRuntimeException is expected");
- } catch (VMRuntimeException ex) {
- assertEquals("There is a infinite loop in the Machine code.", ex.getMessage());
- }
- }
-
- @Test
- public void testLargeOpcodes() throws IOException, ParseException, HiveException,
- VMRuntimeException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/b1a8e588f5750e3b658c/raw/a4074d37400dab2b13a2f43d81f5166188d3461a/vmtest01.txt");
- InputStream is = new BufferedInputStream(url.openStream());
- String opScript = IOUtils.toString(is);
-
- StackMachine sm = new StackMachine();
- sm.compile(opScript);
-
- double[] x1 = new double[] {36, 2, 1, 2, 0, 436, 1, 0, 0, 13, 0, 567, 1, 595, 2, 1};
- sm.eval(x1);
- assertEquals(0.d, sm.getResult().doubleValue(), 0d);
-
- double[] x2 = {31, 2, 1, 2, 0, 354, 1, 0, 0, 30, 0, 502, 1, 9, 2, 2};
- sm.eval(x2);
- assertEquals(1.d, sm.getResult().doubleValue(), 0d);
-
- double[] x3 = {39, 0, 0, 0, 0, 1756, 0, 0, 0, 3, 0, 939, 1, 0, 0, 0};
- sm.eval(x3);
- assertEquals(0.d, sm.getResult().doubleValue(), 0d);
- }
-
- private static void debugPrint(Object msg) {
- if (DEBUG) {
- System.out.println(msg);
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java b/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java
deleted file mode 100644
index 177a345..0000000
--- a/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.util.Random;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class DoubleArray3DTest {
-
- @Test
- public void test() {
- final int size_i = 50, size_j = 50, size_k = 5;
-
- final DoubleArray3D mdarray = new DoubleArray3D();
- mdarray.configure(size_i, size_j, size_k);
-
- final Random rand = new Random(31L);
- final double[][][] data = new double[size_i][size_j][size_j];
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- data[i][j][k] = v;
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
-
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
- }
- }
- }
- }
-
- @Test
- public void testConfigureExpand() {
- int size_i = 50, size_j = 50, size_k = 5;
-
- final DoubleArray3D mdarray = new DoubleArray3D();
- mdarray.configure(size_i, size_j, size_k);
-
- final Random rand = new Random(31L);
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- size_i = 101;
- size_j = 101;
- size_k = 11;
- mdarray.configure(size_i, size_j, size_k);
- Assert.assertEquals(size_i * size_j * size_k, mdarray.getCapacity());
- Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
-
- final double[][][] data = new double[size_i][size_j][size_j];
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- data[i][j][k] = v;
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
- }
- }
- }
- }
-
- @Test
- public void testConfigureShrink() {
- int size_i = 50, size_j = 50, size_k = 5;
-
- final DoubleArray3D mdarray = new DoubleArray3D();
- mdarray.configure(size_i, size_j, size_k);
-
- final Random rand = new Random(31L);
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- int capacity = mdarray.getCapacity();
- size_i = 49;
- size_j = 49;
- size_k = 4;
- mdarray.configure(size_i, size_j, size_k);
- Assert.assertEquals(capacity, mdarray.getCapacity());
- Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
-
- final double[][][] data = new double[size_i][size_j][size_j];
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- data[i][j][k] = v;
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
- }
- }
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java b/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java
deleted file mode 100644
index 72e76e8..0000000
--- a/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class DoubleArrayTest {
-
- @Test
- public void testSparseDoubleArrayToArray() {
- SparseDoubleArray array = new SparseDoubleArray(3);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- Assert.assertEquals(10, array.size());
- Assert.assertEquals(10, array.toArray(false).length);
-
- double[] copied = array.toArray(true);
- Assert.assertEquals(10, copied.length);
- for (int i = 0; i < 10; i++) {
- Assert.assertEquals(10 + i, copied[i], 0.d);
- }
- }
-
- @Test
- public void testSparseDoubleArrayClear() {
- SparseDoubleArray array = new SparseDoubleArray(3);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- array.clear();
- Assert.assertEquals(0, array.size());
- Assert.assertEquals(0, array.get(0), 0.d);
- for (int i = 0; i < 5; i++) {
- array.put(i, 100 + i);
- }
- Assert.assertEquals(5, array.size());
- for (int i = 0; i < 5; i++) {
- Assert.assertEquals(100 + i, array.get(i), 0.d);
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java
deleted file mode 100644
index 8a8a68d..0000000
--- a/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class Int2FloatOpenHashMapTest {
-
- @Test
- public void testSize() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
- map.put(1, 3.f);
- Assert.assertEquals(3.f, map.get(1), 0.d);
- map.put(1, 5.f);
- Assert.assertEquals(5.f, map.get(1), 0.d);
- Assert.assertEquals(1, map.size());
- }
-
- @Test
- public void testDefaultReturnValue() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
- Assert.assertEquals(0, map.size());
- Assert.assertEquals(-1.f, map.get(1), 0.d);
- float ret = Float.MIN_VALUE;
- map.defaultReturnValue(ret);
- Assert.assertEquals(ret, map.get(1), 0.d);
- }
-
- @Test
- public void testPutAndGet() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d);
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Float v = map.get(i);
- Assert.assertEquals(i + 0.1f, v.floatValue(), 0.d);
- }
- }
-
- @Test
- public void testIterator() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(1000);
- Int2FloatOpenHashTable.IMapIterator itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d);
- }
- Assert.assertEquals(numEntries, map.size());
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- int k = itor.getKey();
- Float v = itor.getValue();
- Assert.assertEquals(k + 0.1f, v.floatValue(), 0.d);
- }
- Assert.assertEquals(-1, itor.next());
- }
-
- @Test
- public void testIterator2() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(100);
- map.put(33, 3.16f);
-
- Int2FloatOpenHashTable.IMapIterator itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- Assert.assertNotEquals(-1, itor.next());
- Assert.assertEquals(33, itor.getKey());
- Assert.assertEquals(3.16f, itor.getValue(), 0.d);
- Assert.assertEquals(-1, itor.next());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java
deleted file mode 100644
index 1186bdf..0000000
--- a/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.ObjectUtils;
-
-import java.io.IOException;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class Int2LongOpenHashMapTest {
-
- @Test
- public void testSize() {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
- map.put(1, 3L);
- Assert.assertEquals(3L, map.get(1));
- map.put(1, 5L);
- Assert.assertEquals(5L, map.get(1));
- Assert.assertEquals(1, map.size());
- }
-
- @Test
- public void testDefaultReturnValue() {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
- Assert.assertEquals(0, map.size());
- Assert.assertEquals(-1L, map.get(1));
- long ret = Long.MIN_VALUE;
- map.defaultReturnValue(ret);
- Assert.assertEquals(ret, map.get(1));
- }
-
- @Test
- public void testPutAndGet() {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1L, map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- long v = map.get(i);
- Assert.assertEquals(i, v);
- }
- }
-
- @Test
- public void testSerde() throws IOException, ClassNotFoundException {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1L, map.put(i, i));
- }
-
- byte[] b = ObjectUtils.toCompressedBytes(map);
- map = new Int2LongOpenHashTable(16384);
- ObjectUtils.readCompressedObject(b, map);
-
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- long v = map.get(i);
- Assert.assertEquals(i, v);
- }
- }
-
- @Test
- public void testIterator() {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(1000);
- Int2LongOpenHashTable.IMapIterator itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1L, map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- int k = itor.getKey();
- long v = itor.getValue();
- Assert.assertEquals(k, v);
- }
- Assert.assertEquals(-1, itor.next());
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/IntArrayTest.java b/core/src/test/java/hivemall/utils/collections/IntArrayTest.java
deleted file mode 100644
index 42852ea..0000000
--- a/core/src/test/java/hivemall/utils/collections/IntArrayTest.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class IntArrayTest {
-
- @Test
- public void testFixedIntArrayToArray() {
- FixedIntArray array = new FixedIntArray(11);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- Assert.assertEquals(11, array.size());
- Assert.assertEquals(11, array.toArray(false).length);
-
- int[] copied = array.toArray(true);
- Assert.assertEquals(11, copied.length);
- for (int i = 0; i < 10; i++) {
- Assert.assertEquals(10 + i, copied[i]);
- }
- }
-
- @Test
- public void testSparseIntArrayToArray() {
- SparseIntArray array = new SparseIntArray(3);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- Assert.assertEquals(10, array.size());
- Assert.assertEquals(10, array.toArray(false).length);
-
- int[] copied = array.toArray(true);
- Assert.assertEquals(10, copied.length);
- for (int i = 0; i < 10; i++) {
- Assert.assertEquals(10 + i, copied[i]);
- }
- }
-
- @Test
- public void testSparseIntArrayClear() {
- SparseIntArray array = new SparseIntArray(3);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- array.clear();
- Assert.assertEquals(0, array.size());
- Assert.assertEquals(0, array.get(0));
- for (int i = 0; i < 5; i++) {
- array.put(i, 100 + i);
- }
- Assert.assertEquals(5, array.size());
- for (int i = 0; i < 5; i++) {
- Assert.assertEquals(100 + i, array.get(i));
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java
deleted file mode 100644
index 29a5a81..0000000
--- a/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class IntOpenHashMapTest {
-
- @Test
- public void testSize() {
- IntOpenHashMap<Float> map = new IntOpenHashMap<Float>(16384);
- map.put(1, Float.valueOf(3.f));
- Assert.assertEquals(Float.valueOf(3.f), map.get(1));
- map.put(1, Float.valueOf(5.f));
- Assert.assertEquals(Float.valueOf(5.f), map.get(1));
- Assert.assertEquals(1, map.size());
- }
-
- @Test
- public void testPutAndGet() {
- IntOpenHashMap<Integer> map = new IntOpenHashMap<Integer>(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertNull(map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Integer v = map.get(i);
- Assert.assertEquals(i, v.intValue());
- }
- }
-
- @Test
- public void testIterator() {
- IntOpenHashMap<Integer> map = new IntOpenHashMap<Integer>(1000);
- IntOpenHashMap.IMapIterator<Integer> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertNull(map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- int k = itor.getKey();
- Integer v = itor.getValue();
- Assert.assertEquals(k, v.intValue());
- }
- Assert.assertEquals(-1, itor.next());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java
deleted file mode 100644
index 3babb3d..0000000
--- a/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class IntOpenHashTableTest {
-
- @Test
- public void testSize() {
- IntOpenHashTable<Float> map = new IntOpenHashTable<Float>(16384);
- map.put(1, Float.valueOf(3.f));
- Assert.assertEquals(Float.valueOf(3.f), map.get(1));
- map.put(1, Float.valueOf(5.f));
- Assert.assertEquals(Float.valueOf(5.f), map.get(1));
- Assert.assertEquals(1, map.size());
- }
-
- @Test
- public void testPutAndGet() {
- IntOpenHashTable<Integer> map = new IntOpenHashTable<Integer>(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertNull(map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Integer v = map.get(i);
- Assert.assertEquals(i, v.intValue());
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java
deleted file mode 100644
index e3cc018..0000000
--- a/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.mutable.MutableInt;
-
-import java.util.Map;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class OpenHashMapTest {
-
- @Test
- public void testPutAndGet() {
- Map<Object, Object> map = new OpenHashMap<Object, Object>(16384);
- final int numEntries = 5000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Object v = map.get(Integer.toString(i));
- Assert.assertEquals(i, v);
- }
- map.put(Integer.toString(1), Integer.MAX_VALUE);
- Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
- Assert.assertEquals(numEntries, map.size());
- }
-
- @Test
- public void testIterator() {
- OpenHashMap<String, Integer> map = new OpenHashMap<String, Integer>(1000);
- IMapIterator<String, Integer> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- String k = itor.getKey();
- Integer v = itor.getValue();
- Assert.assertEquals(Integer.valueOf(k), v);
- }
- Assert.assertEquals(-1, itor.next());
- }
-
- @Test
- public void testIteratorGetProbe() {
- OpenHashMap<String, MutableInt> map = new OpenHashMap<String, MutableInt>(100);
- IMapIterator<String, MutableInt> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), new MutableInt(i));
- }
-
- final MutableInt probe = new MutableInt();
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- String k = itor.getKey();
- itor.getValue(probe);
- Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue());
- }
- Assert.assertEquals(-1, itor.next());
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java
deleted file mode 100644
index d5a465c..0000000
--- a/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.ObjectUtils;
-import hivemall.utils.lang.mutable.MutableInt;
-
-import java.io.IOException;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class OpenHashTableTest {
-
- @Test
- public void testPutAndGet() {
- OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
- final int numEntries = 5000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Object v = map.get(Integer.toString(i));
- Assert.assertEquals(i, v);
- }
- map.put(Integer.toString(1), Integer.MAX_VALUE);
- Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
- Assert.assertEquals(numEntries, map.size());
- }
-
- @Test
- public void testIterator() {
- OpenHashTable<String, Integer> map = new OpenHashTable<String, Integer>(1000);
- IMapIterator<String, Integer> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- String k = itor.getKey();
- Integer v = itor.getValue();
- Assert.assertEquals(Integer.valueOf(k), v);
- }
- Assert.assertEquals(-1, itor.next());
- }
-
- @Test
- public void testIteratorGetProbe() {
- OpenHashTable<String, MutableInt> map = new OpenHashTable<String, MutableInt>(100);
- IMapIterator<String, MutableInt> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), new MutableInt(i));
- }
-
- final MutableInt probe = new MutableInt();
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- String k = itor.getKey();
- itor.getValue(probe);
- Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue());
- }
- Assert.assertEquals(-1, itor.next());
- }
-
- @Test
- public void testSerDe() throws IOException, ClassNotFoundException {
- OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
- final int numEntries = 100000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
-
- byte[] serialized = ObjectUtils.toBytes(map);
- map = new OpenHashTable<Object, Object>();
- ObjectUtils.readObject(serialized, map);
-
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Object v = map.get(Integer.toString(i));
- Assert.assertEquals(i, v);
- }
- map.put(Integer.toString(1), Integer.MAX_VALUE);
- Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
- Assert.assertEquals(numEntries, map.size());
- }
-
-
- @Test
- public void testCompressedSerDe() throws IOException, ClassNotFoundException {
- OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
- final int numEntries = 100000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
-
- byte[] serialized = ObjectUtils.toCompressedBytes(map);
- map = new OpenHashTable<Object, Object>();
- ObjectUtils.readCompressedObject(serialized, map);
-
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Object v = map.get(Integer.toString(i));
- Assert.assertEquals(i, v);
- }
- map.put(Integer.toString(1), Integer.MAX_VALUE);
- Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
- Assert.assertEquals(numEntries, map.size());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java b/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java
deleted file mode 100644
index 68d0f6d..0000000
--- a/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.util.Random;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class SparseIntArrayTest {
-
- @Test
- public void testDense() {
- int size = 1000;
- Random rand = new Random(31);
- int[] expected = new int[size];
- IntArray actual = new SparseIntArray(10);
- for (int i = 0; i < size; i++) {
- int r = rand.nextInt(size);
- expected[i] = r;
- actual.put(i, r);
- }
- for (int i = 0; i < size; i++) {
- Assert.assertEquals(expected[i], actual.get(i));
- }
- }
-
- @Test
- public void testSparse() {
- int size = 1000;
- Random rand = new Random(31);
- int[] expected = new int[size];
- SparseIntArray actual = new SparseIntArray(10);
- for (int i = 0; i < size; i++) {
- int key = rand.nextInt(size);
- int v = rand.nextInt();
- expected[key] = v;
- actual.put(key, v);
- }
- for (int i = 0; i < actual.size(); i++) {
- int key = actual.keyAt(i);
- Assert.assertEquals(expected[key], actual.get(key, 0));
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java
new file mode 100644
index 0000000..4fdb43e
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.collections.arrays.DoubleArray3D;
+
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DoubleArray3DTest {
+
+ @Test
+ public void test() {
+ final int size_i = 50, size_j = 50, size_k = 5;
+
+ final DoubleArray3D mdarray = new DoubleArray3D();
+ mdarray.configure(size_i, size_j, size_k);
+
+ final Random rand = new Random(31L);
+ final double[][][] data = new double[size_i][size_j][size_k];
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ data[i][j][k] = v;
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
+
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testConfigureExpand() {
+ int size_i = 50, size_j = 50, size_k = 5;
+
+ final DoubleArray3D mdarray = new DoubleArray3D();
+ mdarray.configure(size_i, size_j, size_k);
+
+ final Random rand = new Random(31L);
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ size_i = 101;
+ size_j = 101;
+ size_k = 11;
+ mdarray.configure(size_i, size_j, size_k);
+ Assert.assertEquals(size_i * size_j * size_k, mdarray.getCapacity());
+ Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
+
+ final double[][][] data = new double[size_i][size_j][size_k];
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ data[i][j][k] = v;
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testConfigureShrink() {
+ int size_i = 50, size_j = 50, size_k = 5;
+
+ final DoubleArray3D mdarray = new DoubleArray3D();
+ mdarray.configure(size_i, size_j, size_k);
+
+ final Random rand = new Random(31L);
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ int capacity = mdarray.getCapacity();
+ size_i = 49;
+ size_j = 49;
+ size_k = 4;
+ mdarray.configure(size_i, size_j, size_k);
+ Assert.assertEquals(capacity, mdarray.getCapacity());
+ Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
+
+ final double[][][] data = new double[size_i][size_j][size_k];
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ data[i][j][k] = v;
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
+ }
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java
new file mode 100644
index 0000000..ab52717
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.collections.arrays.SparseDoubleArray;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DoubleArrayTest {
+
+ @Test
+ public void testSparseDoubleArrayToArray() {
+ SparseDoubleArray array = new SparseDoubleArray(3);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ Assert.assertEquals(10, array.size());
+ Assert.assertEquals(10, array.toArray(false).length);
+
+ double[] copied = array.toArray(true);
+ Assert.assertEquals(10, copied.length);
+ for (int i = 0; i < 10; i++) {
+ Assert.assertEquals(10 + i, copied[i], 0.d);
+ }
+ }
+
+ @Test
+ public void testSparseDoubleArrayClear() {
+ SparseDoubleArray array = new SparseDoubleArray(3);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ array.clear();
+ Assert.assertEquals(0, array.size());
+ Assert.assertEquals(0, array.get(0), 0.d);
+ for (int i = 0; i < 5; i++) {
+ array.put(i, 100 + i);
+ }
+ Assert.assertEquals(5, array.size());
+ for (int i = 0; i < 5; i++) {
+ Assert.assertEquals(100 + i, array.get(i), 0.d);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
new file mode 100644
index 0000000..0ce3912
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.collections.arrays.DenseIntArray;
+import hivemall.utils.collections.arrays.SparseIntArray;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class IntArrayTest {
+
+ @Test
+ public void testFixedIntArrayToArray() {
+ DenseIntArray array = new DenseIntArray(11);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ Assert.assertEquals(11, array.size());
+ Assert.assertEquals(11, array.toArray(false).length);
+
+ int[] copied = array.toArray(true);
+ Assert.assertEquals(11, copied.length);
+ for (int i = 0; i < 10; i++) {
+ Assert.assertEquals(10 + i, copied[i]);
+ }
+ }
+
+ @Test
+ public void testSparseIntArrayToArray() {
+ SparseIntArray array = new SparseIntArray(3);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ Assert.assertEquals(10, array.size());
+ Assert.assertEquals(10, array.toArray(false).length);
+
+ int[] copied = array.toArray(true);
+ Assert.assertEquals(10, copied.length);
+ for (int i = 0; i < 10; i++) {
+ Assert.assertEquals(10 + i, copied[i]);
+ }
+ }
+
+ @Test
+ public void testSparseIntArrayClear() {
+ SparseIntArray array = new SparseIntArray(3);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ array.clear();
+ Assert.assertEquals(0, array.size());
+ Assert.assertEquals(0, array.get(0));
+ for (int i = 0; i < 5; i++) {
+ array.put(i, 100 + i);
+ }
+ Assert.assertEquals(5, array.size());
+ for (int i = 0; i < 5; i++) {
+ Assert.assertEquals(100 + i, array.get(i));
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
new file mode 100644
index 0000000..db3c8eb
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.collections.arrays.IntArray;
+import hivemall.utils.collections.arrays.SparseIntArray;
+
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class SparseIntArrayTest {
+
+ @Test
+ public void testDense() {
+ int size = 1000;
+ Random rand = new Random(31);
+ int[] expected = new int[size];
+ IntArray actual = new SparseIntArray(10);
+ for (int i = 0; i < size; i++) {
+ int r = rand.nextInt(size);
+ expected[i] = r;
+ actual.put(i, r);
+ }
+ for (int i = 0; i < size; i++) {
+ Assert.assertEquals(expected[i], actual.get(i));
+ }
+ }
+
+ @Test
+ public void testSparse() {
+ int size = 1000;
+ Random rand = new Random(31);
+ int[] expected = new int[size];
+ SparseIntArray actual = new SparseIntArray(10);
+ for (int i = 0; i < size; i++) {
+ int key = rand.nextInt(size);
+ int v = rand.nextInt();
+ expected[key] = v;
+ actual.put(key, v);
+ }
+ for (int i = 0; i < actual.size(); i++) {
+ int key = actual.keyAt(i);
+ Assert.assertEquals(expected[key], actual.get(key, 0));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java b/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java
new file mode 100644
index 0000000..c40ea7e
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.lists;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class LongArrayListTest {
+
+ @Test
+ public void testRemoveIndex() {
+ LongArrayList list = new LongArrayList();
+ list.add(0).add(1).add(2).add(3);
+ Assert.assertEquals(1, list.remove(1));
+ Assert.assertEquals(3, list.size());
+ Assert.assertArrayEquals(new long[] {0, 2, 3}, list.toArray());
+ Assert.assertEquals(3, list.remove(2));
+ Assert.assertArrayEquals(new long[] {0, 2}, list.toArray());
+ Assert.assertEquals(0, list.remove(0));
+ Assert.assertArrayEquals(new long[] {2}, list.toArray());
+ list.add(0).add(1);
+ Assert.assertEquals(3, list.size());
+ Assert.assertArrayEquals(new long[] {2, 0, 1}, list.toArray());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java
new file mode 100644
index 0000000..6a2ff96
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class Int2FloatOpenHashMapTest {
+
+ @Test
+ public void testSize() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
+ map.put(1, 3.f);
+ Assert.assertEquals(3.f, map.get(1), 0.d);
+ map.put(1, 5.f);
+ Assert.assertEquals(5.f, map.get(1), 0.d);
+ Assert.assertEquals(1, map.size());
+ }
+
+ @Test
+ public void testDefaultReturnValue() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
+ Assert.assertEquals(0, map.size());
+ Assert.assertEquals(-1.f, map.get(1), 0.d);
+ float ret = Float.MIN_VALUE;
+ map.defaultReturnValue(ret);
+ Assert.assertEquals(ret, map.get(1), 0.d);
+ }
+
+ @Test
+ public void testPutAndGet() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d);
+ }
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Float v = map.get(i);
+ Assert.assertEquals(i + 0.1f, v.floatValue(), 0.d);
+ }
+ }
+
+ @Test
+ public void testIterator() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(1000);
+ Int2FloatOpenHashTable.IMapIterator itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d);
+ }
+ Assert.assertEquals(numEntries, map.size());
+
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ int k = itor.getKey();
+ Float v = itor.getValue();
+ Assert.assertEquals(k + 0.1f, v.floatValue(), 0.d);
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+
+ @Test
+ public void testIterator2() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(100);
+ map.put(33, 3.16f);
+
+ Int2FloatOpenHashTable.IMapIterator itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ Assert.assertNotEquals(-1, itor.next());
+ Assert.assertEquals(33, itor.getKey());
+ Assert.assertEquals(3.16f, itor.getValue(), 0.d);
+ Assert.assertEquals(-1, itor.next());
+ }
+
+}