You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/09 21:32:15 UTC
[03/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75]
Support Sparse Vector Format as the input of RandomForest
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java b/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java
new file mode 100644
index 0000000..1fb3a08
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.sampling;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Vitter's reservoir sampling implementation that randomly chooses k items from a list containing n
+ * items.
+ *
+ * @link http://en.wikipedia.org/wiki/Reservoir_sampling
+ * @link http://portal.acm.org/citation.cfm?id=3165
+ */
+public final class ReservoirSampler<T> {
+
+ private final T[] samples;
+ private final int numSamples;
+ private int position;
+
+ private final Random rand;
+
+ @SuppressWarnings("unchecked")
+ public ReservoirSampler(int sampleSize) {
+ if (sampleSize <= 0) {
+ throw new IllegalArgumentException("sampleSize must be greater than 1: " + sampleSize);
+ }
+ this.samples = (T[]) new Object[sampleSize];
+ this.numSamples = sampleSize;
+ this.position = 0;
+ this.rand = new Random();
+ }
+
+ @SuppressWarnings("unchecked")
+ public ReservoirSampler(int sampleSize, long seed) {
+ this.samples = (T[]) new Object[sampleSize];
+ this.numSamples = sampleSize;
+ this.position = 0;
+ this.rand = new Random(seed);
+ }
+
+ public ReservoirSampler(T[] samples) {
+ this.samples = samples;
+ this.numSamples = samples.length;
+ this.position = 0;
+ this.rand = new Random();
+ }
+
+ public ReservoirSampler(T[] samples, long seed) {
+ this.samples = samples;
+ this.numSamples = samples.length;
+ this.position = 0;
+ this.rand = new Random(seed);
+ }
+
+ public T[] getSample() {
+ return samples;
+ }
+
+ public List<T> getSamplesAsList() {
+ return Arrays.asList(samples);
+ }
+
+ public void add(T item) {
+ if (item == null) {
+ return;
+ }
+ if (position < numSamples) {// reservoir not yet full, just append
+ samples[position] = item;
+ } else {// find a item to replace
+ int replaceIndex = rand.nextInt(position + 1);
+ if (replaceIndex < numSamples) {
+ samples[replaceIndex] = item;
+ }
+ }
+ position++;
+ }
+
+ public void clear() {
+ Arrays.fill(samples, null);
+ this.position = 0;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/IntIterator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/stream/IntIterator.java b/core/src/main/java/hivemall/utils/stream/IntIterator.java
new file mode 100644
index 0000000..794d81e
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/stream/IntIterator.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.stream;
+
/**
 * Forward-only iteration over primitive {@code int} values.
 */
public interface IntIterator {

    /**
     * Tells whether {@link #next()} can supply another value.
     *
     * @return {@code true} while values remain
     */
    boolean hasNext();

    /**
     * Supplies the current value and moves past it.
     *
     * @return the next {@code int} in the sequence
     */
    int next();

}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/IntStream.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/stream/IntStream.java b/core/src/main/java/hivemall/utils/stream/IntStream.java
new file mode 100644
index 0000000..4130177
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/stream/IntStream.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.stream;
+
+import javax.annotation.Nonnull;
+
+public interface IntStream {
+
+ @Nonnull
+ IntIterator iterator();
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/StreamUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/stream/StreamUtils.java b/core/src/main/java/hivemall/utils/stream/StreamUtils.java
new file mode 100644
index 0000000..7bd7b63
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/stream/StreamUtils.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.stream;
+
+import hivemall.utils.io.DeflaterOutputStream;
+import hivemall.utils.io.FastByteArrayInputStream;
+import hivemall.utils.io.FastMultiByteArrayOutputStream;
+import hivemall.utils.io.IOUtils;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.NoSuchElementException;
+import java.util.zip.Deflater;
+import java.util.zip.Inflater;
+import java.util.zip.InflaterInputStream;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class StreamUtils {
+
+ private StreamUtils() {}
+
+ @Nonnull
+ public static IntStream toCompressedIntStream(@Nonnull final int[] src) {
+ return toCompressedIntStream(src, Deflater.DEFAULT_COMPRESSION);
+ }
+
+ @Nonnull
+ public static IntStream toCompressedIntStream(@Nonnull final int[] src, final int level) {
+ FastMultiByteArrayOutputStream bos = new FastMultiByteArrayOutputStream(16384);
+ Deflater deflater = new Deflater(level, true);
+ DeflaterOutputStream defos = new DeflaterOutputStream(bos, deflater, 8192);
+ DataOutputStream dos = new DataOutputStream(defos);
+
+ final int count = src.length;
+ final byte[] compressed;
+ try {
+ for (int i = 0; i < count; i++) {
+ dos.writeInt(src[i]);
+ }
+ defos.finish();
+ compressed = bos.toByteArray_clear();
+ } catch (IOException e) {
+ throw new IllegalStateException("Failed to compress int[]", e);
+ } finally {
+ IOUtils.closeQuietly(dos);
+ }
+
+ return new InflateIntStream(compressed, count);
+ }
+
+ @Nonnull
+ public static IntStream toArrayIntStream(@Nonnull int[] array) {
+ return new ArrayIntStream(array);
+ }
+
+ static final class ArrayIntStream implements IntStream {
+
+ @Nonnull
+ private final int[] array;
+
+ ArrayIntStream(@Nonnull int[] array) {
+ this.array = array;
+ }
+
+ @Override
+ public ArrayIntIterator iterator() {
+ return new ArrayIntIterator(array);
+ }
+
+ }
+
+ static final class ArrayIntIterator implements IntIterator {
+
+ @Nonnull
+ private final int[] array;
+ @Nonnegative
+ private final int count;
+ @Nonnegative
+ private int index;
+
+ ArrayIntIterator(@Nonnull int[] array) {
+ this.array = array;
+ this.count = array.length;
+ this.index = 0;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return index < count;
+ }
+
+ @Override
+ public int next() {
+ if (index < count) {// hasNext()
+ return array[index++];
+ }
+ throw new NoSuchElementException();
+ }
+
+ }
+
+ static final class InflateIntStream implements IntStream {
+
+ @Nonnull
+ private final byte[] compressed;
+ @Nonnegative
+ private final int count;
+
+ InflateIntStream(@Nonnull byte[] compressed, @Nonnegative int count) {
+ this.compressed = compressed;
+ this.count = count;
+ }
+
+ @Override
+ public InflatedIntIterator iterator() {
+ FastByteArrayInputStream bis = new FastByteArrayInputStream(compressed);
+ InflaterInputStream infis = new InflaterInputStream(bis, new Inflater(true), 512);
+ DataInputStream in = new DataInputStream(infis);
+ return new InflatedIntIterator(in, count);
+ }
+
+ }
+
+ static final class InflatedIntIterator implements IntIterator {
+
+ @Nonnull
+ private final DataInputStream in;
+ @Nonnegative
+ private final int count;
+ @Nonnegative
+ private int index;
+
+ InflatedIntIterator(@Nonnull DataInputStream in, @Nonnegative int count) {
+ this.in = in;
+ this.count = count;
+ this.index = 0;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return index < count;
+ }
+
+ @Override
+ public int next() {
+ if (index < count) {// hasNext()
+ final int v;
+ try {
+ v = in.readInt();
+ } catch (IOException e) {
+ throw new IllegalStateException("Invalid input at " + index, e);
+ }
+ index++;
+ return v;
+ }
+ throw new NoSuchElementException();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
index a65a69a..076387f 100644
--- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
+++ b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
@@ -19,7 +19,7 @@
package hivemall.fm;
import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.collections.Int2LongOpenHashTable;
+import hivemall.utils.collections.maps.Int2LongOpenHashTable;
import java.io.IOException;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
new file mode 100644
index 0000000..decd7df
--- /dev/null
+++ b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
@@ -0,0 +1,644 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.CSCMatrixBuilder;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.ColumnMajorDenseMatrixBuilder;
+import hivemall.math.matrix.builders.DoKMatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.dense.ColumnMajorDenseMatrix2d;
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
+import hivemall.math.matrix.sparse.CSCMatrix;
+import hivemall.math.matrix.sparse.CSRMatrix;
+import hivemall.math.matrix.sparse.DoKMatrix;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class MatrixBuilderTest {
+
+ @Test
+ public void testReadOnlyCSRMatrix() {
+ Matrix matrix = csrMatrix();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(2, matrix.numColumns(1));
+ Assert.assertEquals(4, matrix.numColumns(2));
+ Assert.assertEquals(2, matrix.numColumns(3));
+ Assert.assertEquals(1, matrix.numColumns(4));
+ Assert.assertEquals(1, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyCSRMatrixFromLibSVM() {
+ Matrix matrix = csrMatrixFromLibSVM();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(2, matrix.numColumns(1));
+ Assert.assertEquals(4, matrix.numColumns(2));
+ Assert.assertEquals(2, matrix.numColumns(3));
+ Assert.assertEquals(1, matrix.numColumns(4));
+ Assert.assertEquals(1, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyCSRMatrixNoRow() {
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ Matrix matrix = builder.buildMatrix();
+ Assert.assertEquals(0, matrix.numRows());
+ Assert.assertEquals(0, matrix.numColumns());
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyCSRMatrixGetFail1() {
+ Matrix matrix = csrMatrix();
+ matrix.get(7, 5);
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyCSRMatrixGetFail2() {
+ Matrix matrix = csrMatrix();
+ matrix.get(6, 7);
+ }
+
+ @Test
+ public void testCSCMatrixFromLibSVM() {
+ CSCMatrix matrix = cscMatrixFromLibSVM();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(2, matrix.numColumns(1));
+ Assert.assertEquals(4, matrix.numColumns(2));
+ Assert.assertEquals(2, matrix.numColumns(3));
+ Assert.assertEquals(1, matrix.numColumns(4));
+ Assert.assertEquals(1, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testCSC2CSR() {
+ CSCMatrix csc = cscMatrixFromLibSVM();
+ RowMajorMatrix csr = csc.toRowMajorMatrix();
+ Assert.assertTrue(csr instanceof CSRMatrix);
+ Assert.assertEquals(6, csr.numRows());
+ Assert.assertEquals(6, csr.numColumns());
+ Assert.assertEquals(4, csr.numColumns(0));
+ Assert.assertEquals(2, csr.numColumns(1));
+ Assert.assertEquals(4, csr.numColumns(2));
+ Assert.assertEquals(2, csr.numColumns(3));
+ Assert.assertEquals(1, csr.numColumns(4));
+ Assert.assertEquals(1, csr.numColumns(5));
+
+ Assert.assertEquals(11d, csr.get(0, 0), 0.d);
+ Assert.assertEquals(12d, csr.get(0, 1), 0.d);
+ Assert.assertEquals(13d, csr.get(0, 2), 0.d);
+ Assert.assertEquals(14d, csr.get(0, 3), 0.d);
+ Assert.assertEquals(22d, csr.get(1, 1), 0.d);
+ Assert.assertEquals(23d, csr.get(1, 2), 0.d);
+ Assert.assertEquals(33d, csr.get(2, 2), 0.d);
+ Assert.assertEquals(34d, csr.get(2, 3), 0.d);
+ Assert.assertEquals(35d, csr.get(2, 4), 0.d);
+ Assert.assertEquals(36d, csr.get(2, 5), 0.d);
+ Assert.assertEquals(44d, csr.get(3, 3), 0.d);
+ Assert.assertEquals(45d, csr.get(3, 4), 0.d);
+ Assert.assertEquals(56d, csr.get(4, 5), 0.d);
+ Assert.assertEquals(66d, csr.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, csr.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, csr.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, csr.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testCSC2CSR2CSR() {
+ CSCMatrix csc = cscMatrixFromLibSVM();
+ CSCMatrix csc2 = csc.toRowMajorMatrix().toColumnMajorMatrix();
+ Assert.assertEquals(csc.nnz(), csc2.nnz());
+ Assert.assertEquals(6, csc2.numRows());
+ Assert.assertEquals(6, csc2.numColumns());
+ Assert.assertEquals(4, csc2.numColumns(0));
+ Assert.assertEquals(2, csc2.numColumns(1));
+ Assert.assertEquals(4, csc2.numColumns(2));
+ Assert.assertEquals(2, csc2.numColumns(3));
+ Assert.assertEquals(1, csc2.numColumns(4));
+ Assert.assertEquals(1, csc2.numColumns(5));
+
+ Assert.assertEquals(11d, csc2.get(0, 0), 0.d);
+ Assert.assertEquals(12d, csc2.get(0, 1), 0.d);
+ Assert.assertEquals(13d, csc2.get(0, 2), 0.d);
+ Assert.assertEquals(14d, csc2.get(0, 3), 0.d);
+ Assert.assertEquals(22d, csc2.get(1, 1), 0.d);
+ Assert.assertEquals(23d, csc2.get(1, 2), 0.d);
+ Assert.assertEquals(33d, csc2.get(2, 2), 0.d);
+ Assert.assertEquals(34d, csc2.get(2, 3), 0.d);
+ Assert.assertEquals(35d, csc2.get(2, 4), 0.d);
+ Assert.assertEquals(36d, csc2.get(2, 5), 0.d);
+ Assert.assertEquals(44d, csc2.get(3, 3), 0.d);
+ Assert.assertEquals(45d, csc2.get(3, 4), 0.d);
+ Assert.assertEquals(56d, csc2.get(4, 5), 0.d);
+ Assert.assertEquals(66d, csc2.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, csc2.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, csc2.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, csc2.get(5, 4, Double.NaN), 0.d);
+ }
+
+
+ @Test
+ public void testDoKMatrixFromLibSVM() {
+ Matrix matrix = dokMatrixFromLibSVM();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(2, matrix.numColumns(1));
+ Assert.assertEquals(4, matrix.numColumns(2));
+ Assert.assertEquals(2, matrix.numColumns(3));
+ Assert.assertEquals(1, matrix.numColumns(4));
+ Assert.assertEquals(1, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyDenseMatrix2d() {
+ Matrix matrix = rowMajorDenseMatrix();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(3, matrix.numColumns(1));
+ Assert.assertEquals(6, matrix.numColumns(2));
+ Assert.assertEquals(5, matrix.numColumns(3));
+ Assert.assertEquals(6, matrix.numColumns(4));
+ Assert.assertEquals(6, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyDenseMatrix2dSparseInput() {
+ Matrix matrix = denseMatrixSparseInput();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(3, matrix.numColumns(1));
+ Assert.assertEquals(6, matrix.numColumns(2));
+ Assert.assertEquals(5, matrix.numColumns(3));
+ Assert.assertEquals(6, matrix.numColumns(4));
+ Assert.assertEquals(6, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyDenseMatrix2dFromLibSVM() {
+ Matrix matrix = denseMatrixFromLibSVM();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(3, matrix.numColumns(1));
+ Assert.assertEquals(6, matrix.numColumns(2));
+ Assert.assertEquals(5, matrix.numColumns(3));
+ Assert.assertEquals(6, matrix.numColumns(4));
+ Assert.assertEquals(6, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyDenseMatrix2dNoRow() {
+ Matrix matrix = new RowMajorDenseMatrixBuilder(1024).buildMatrix();
+ Assert.assertEquals(0, matrix.numRows());
+ Assert.assertEquals(0, matrix.numColumns());
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyDenseMatrix2dFailOutOfBound1() {
+ Matrix matrix = rowMajorDenseMatrix();
+ matrix.get(7, 5);
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyDenseMatrix2dFailOutOfBound2() {
+ Matrix matrix = rowMajorDenseMatrix();
+ matrix.get(6, 7);
+ }
+
+ @Test
+ public void testColumnMajorDenseMatrix2d() {
+ ColumnMajorDenseMatrix2d colMatrix = columnMajorDenseMatrix();
+
+ Assert.assertEquals(6, colMatrix.numRows());
+ Assert.assertEquals(6, colMatrix.numColumns());
+ Assert.assertEquals(4, colMatrix.numColumns(0));
+ Assert.assertEquals(2, colMatrix.numColumns(1));
+ Assert.assertEquals(4, colMatrix.numColumns(2));
+ Assert.assertEquals(2, colMatrix.numColumns(3));
+ Assert.assertEquals(1, colMatrix.numColumns(4));
+ Assert.assertEquals(1, colMatrix.numColumns(5));
+
+ Assert.assertEquals(11d, colMatrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, colMatrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, colMatrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, colMatrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, colMatrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, colMatrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, colMatrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, colMatrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, colMatrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, colMatrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, colMatrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, colMatrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, colMatrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, colMatrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, colMatrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, colMatrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testDenseMatrixColumnMajor2RowMajor() {
+ ColumnMajorDenseMatrix2d colMatrix = columnMajorDenseMatrix();
+ RowMajorDenseMatrix2d rowMatrix = colMatrix.toRowMajorMatrix();
+
+ Assert.assertEquals(6, rowMatrix.numRows());
+ Assert.assertEquals(6, rowMatrix.numColumns());
+ Assert.assertEquals(4, rowMatrix.numColumns(0));
+ Assert.assertEquals(3, rowMatrix.numColumns(1));
+ Assert.assertEquals(6, rowMatrix.numColumns(2));
+ Assert.assertEquals(5, rowMatrix.numColumns(3));
+ Assert.assertEquals(6, rowMatrix.numColumns(4));
+ Assert.assertEquals(6, rowMatrix.numColumns(5));
+
+ Assert.assertEquals(11d, rowMatrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, rowMatrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, rowMatrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, rowMatrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, rowMatrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, rowMatrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, rowMatrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, rowMatrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, rowMatrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, rowMatrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, rowMatrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, rowMatrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, rowMatrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, rowMatrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, rowMatrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, rowMatrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, rowMatrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, rowMatrix.get(1, 0), 0.d);
+
+ // convert back to column major matrix
+
+ colMatrix = rowMatrix.toColumnMajorMatrix();
+
+ Assert.assertEquals(6, colMatrix.numRows());
+ Assert.assertEquals(6, colMatrix.numColumns());
+ Assert.assertEquals(4, colMatrix.numColumns(0));
+ Assert.assertEquals(2, colMatrix.numColumns(1));
+ Assert.assertEquals(4, colMatrix.numColumns(2));
+ Assert.assertEquals(2, colMatrix.numColumns(3));
+ Assert.assertEquals(1, colMatrix.numColumns(4));
+ Assert.assertEquals(1, colMatrix.numColumns(5));
+
+ Assert.assertEquals(11d, colMatrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, colMatrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, colMatrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, colMatrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, colMatrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, colMatrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, colMatrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, colMatrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, colMatrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, colMatrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, colMatrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, colMatrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, colMatrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, colMatrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, colMatrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, colMatrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d);
+ }
+
+    /**
+     * A row without any column must still count as a row, and must not shift the entries of the
+     * rows that follow it.
+     */
+    @Test
+    public void testCSRMatrixNullRow() {
+        /*
+         11 12 13 14
+          0 22 23  0
+          0  0  0  0   <- empty row (nextRow() without any nextColumn())
+          0  0  0 66
+        */
+        CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+        builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
+        builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
+        builder.nextRow();
+        builder.nextColumn(3, 66).nextRow();
+        Matrix matrix = builder.buildMatrix();
+        Assert.assertEquals(4, matrix.numRows());
+
+        // rows before the empty row are intact
+        Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+        Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+        // the empty row holds nothing
+        Assert.assertEquals(0.d, matrix.get(2, 0), 0.d);
+        Assert.assertEquals(0.d, matrix.get(2, 3), 0.d);
+        // the row after the empty row keeps its own entry (not shifted up into row 2)
+        Assert.assertEquals(66d, matrix.get(3, 3), 0.d);
+        Assert.assertEquals(0.d, matrix.get(3, 0), 0.d);
+    }
+
+ private static CSRMatrix csrMatrix() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
+ builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
+ builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow();
+ builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
+ builder.nextColumn(5, 56).nextRow();
+ builder.nextColumn(5, 66).nextRow();
+ return builder.buildMatrix();
+ }
+
+ private static CSRMatrix csrMatrixFromLibSVM() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
+ builder.nextRow(new String[] {"1:22", "2:23"});
+ builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
+ builder.nextRow(new String[] {"3:44", "4:45"});
+ builder.nextRow(new String[] {"5:56"});
+ builder.nextRow(new String[] {"5:66"});
+ return builder.buildMatrix();
+ }
+
+ private static CSCMatrix cscMatrixFromLibSVM() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ CSCMatrixBuilder builder = new CSCMatrixBuilder(1024);
+ builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
+ builder.nextRow(new String[] {"1:22", "2:23"});
+ builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
+ builder.nextRow(new String[] {"3:44", "4:45"});
+ builder.nextRow(new String[] {"5:56"});
+ builder.nextRow(new String[] {"5:66"});
+ return builder.buildMatrix();
+ }
+
+
+ private static DoKMatrix dokMatrixFromLibSVM() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ DoKMatrixBuilder builder = new DoKMatrixBuilder(1024);
+ builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
+ builder.nextRow(new String[] {"1:22", "2:23"});
+ builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
+ builder.nextRow(new String[] {"3:44", "4:45"});
+ builder.nextRow(new String[] {"5:56"});
+ builder.nextRow(new String[] {"5:66"});
+ return builder.buildMatrix();
+ }
+
+ private static RowMajorDenseMatrix2d rowMajorDenseMatrix() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024);
+ builder.nextRow(new double[] {11, 12, 13, 14});
+ builder.nextRow(new double[] {0, 22, 23});
+ builder.nextRow(new double[] {0, 0, 33, 34, 35, 36});
+ builder.nextRow(new double[] {0, 0, 0, 44, 45});
+ builder.nextRow(new double[] {0, 0, 0, 0, 0, 56});
+ builder.nextRow(new double[] {0, 0, 0, 0, 0, 66});
+ return builder.buildMatrix();
+ }
+
+ private static ColumnMajorDenseMatrix2d columnMajorDenseMatrix() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ ColumnMajorDenseMatrixBuilder builder = new ColumnMajorDenseMatrixBuilder(1024);
+ builder.nextRow(new double[] {11, 12, 13, 14});
+ builder.nextRow(new double[] {0, 22, 23});
+ builder.nextRow(new double[] {0, 0, 33, 34, 35, 36});
+ builder.nextRow(new double[] {0, 0, 0, 44, 45});
+ builder.nextRow(new double[] {0, 0, 0, 0, 0, 56});
+ builder.nextRow(new double[] {0, 0, 0, 0, 0, 66});
+ return builder.buildMatrix();
+ }
+
+ private static RowMajorDenseMatrix2d denseMatrixSparseInput() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024);
+ builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
+ builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
+ builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow();
+ builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
+ builder.nextColumn(5, 56).nextRow();
+ builder.nextColumn(5, 66).nextRow();
+ return builder.buildMatrix();
+ }
+
+ private static RowMajorDenseMatrix2d denseMatrixFromLibSVM() {
+ RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024);
+ builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
+ builder.nextRow(new String[] {"1:22", "2:23"});
+ builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
+ builder.nextRow(new String[] {"3:44", "4:45"});
+ builder.nextRow(new String[] {"5:56"});
+ builder.nextRow(new String[] {"5:66"});
+ return builder.buildMatrix();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java b/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java
new file mode 100644
index 0000000..f6a52fe
--- /dev/null
+++ b/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.matrix.ints.ColumnMajorDenseIntMatrix2d;
+import hivemall.math.matrix.ints.DoKIntMatrix;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for the int matrix implementations ({@link DoKIntMatrix},
+ * {@link ColumnMajorDenseIntMatrix2d}) against one shared 6x6 upper-triangular fixture, supplied
+ * either row by row ({@link #rowMajorData()}) or column by column ({@link #columnMajorData()}).
+ */
+public class IntMatrixTest {
+
+    @Test
+    public void testDoKMatrixRowMajor() {
+        DoKIntMatrix matrix = DoKIntMatrix.build(rowMajorData(), true, true);
+
+        Assert.assertEquals(6, matrix.numRows());
+        Assert.assertEquals(6, matrix.numColumns());
+
+        Assert.assertEquals(11, matrix.get(0, 0));
+        Assert.assertEquals(12, matrix.get(0, 1));
+        Assert.assertEquals(13, matrix.get(0, 2));
+        Assert.assertEquals(14, matrix.get(0, 3));
+        Assert.assertEquals(22, matrix.get(1, 1));
+        Assert.assertEquals(23, matrix.get(1, 2));
+        Assert.assertEquals(33, matrix.get(2, 2));
+        Assert.assertEquals(34, matrix.get(2, 3));
+        Assert.assertEquals(35, matrix.get(2, 4));
+        Assert.assertEquals(36, matrix.get(2, 5));
+        Assert.assertEquals(44, matrix.get(3, 3));
+        Assert.assertEquals(45, matrix.get(3, 4));
+        Assert.assertEquals(56, matrix.get(4, 5));
+        Assert.assertEquals(66, matrix.get(5, 5));
+
+        Assert.assertEquals(0, matrix.get(5, 4));
+        Assert.assertEquals(0, matrix.get(1, 0));
+        Assert.assertEquals(0, matrix.get(1, 3));
+        Assert.assertEquals(-1, matrix.get(1, 0, -1));
+    }
+
+    @Test
+    public void testDoKMatrixColumnMajor() {
+        DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, true);
+
+        Assert.assertEquals(6, matrix.numRows());
+        Assert.assertEquals(6, matrix.numColumns());
+
+        Assert.assertEquals(11, matrix.get(0, 0));
+        Assert.assertEquals(12, matrix.get(0, 1));
+        Assert.assertEquals(13, matrix.get(0, 2));
+        Assert.assertEquals(14, matrix.get(0, 3));
+        Assert.assertEquals(22, matrix.get(1, 1));
+        Assert.assertEquals(23, matrix.get(1, 2));
+        Assert.assertEquals(33, matrix.get(2, 2));
+        Assert.assertEquals(34, matrix.get(2, 3));
+        Assert.assertEquals(35, matrix.get(2, 4));
+        Assert.assertEquals(36, matrix.get(2, 5));
+        Assert.assertEquals(44, matrix.get(3, 3));
+        Assert.assertEquals(45, matrix.get(3, 4));
+        Assert.assertEquals(56, matrix.get(4, 5));
+        Assert.assertEquals(66, matrix.get(5, 5));
+
+        Assert.assertEquals(0, matrix.get(5, 4));
+        Assert.assertEquals(0, matrix.get(1, 0));
+        Assert.assertEquals(0, matrix.get(1, 3));
+        Assert.assertEquals(-1, matrix.get(1, 0, -1));
+    }
+
+    @Test
+    public void testDoKMatrixColumnMajorNonZeroOnlyFalse() {
+        DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, false);
+
+        Assert.assertEquals(6, matrix.numRows());
+        Assert.assertEquals(6, matrix.numColumns());
+
+        Assert.assertEquals(0, matrix.get(5, 4));
+        Assert.assertEquals(0, matrix.get(1, 0));
+        Assert.assertEquals(0, matrix.get(1, 3));
+        // (1, 3) is an explicit zero in columnMajorData(); (1, 0) is not supplied at all
+        Assert.assertEquals(0, matrix.get(1, 3, -1));
+        Assert.assertEquals(-1, matrix.get(1, 0, -1));
+
+        matrix.setDefaultValue(-1);
+        Assert.assertEquals(-1, matrix.get(5, 4));
+        Assert.assertEquals(-1, matrix.get(1, 0));
+        Assert.assertEquals(0, matrix.get(1, 3));
+        Assert.assertEquals(0, matrix.get(1, 0, 0));
+    }
+
+    @Test
+    public void testColumnMajorDenseMatrix() {
+        ColumnMajorDenseIntMatrix2d matrix = new ColumnMajorDenseIntMatrix2d(columnMajorData(), 6);
+        Assert.assertEquals(6, matrix.numRows());
+        Assert.assertEquals(6, matrix.numColumns());
+
+        Assert.assertEquals(11, matrix.get(0, 0));
+        Assert.assertEquals(12, matrix.get(0, 1));
+        Assert.assertEquals(13, matrix.get(0, 2));
+        Assert.assertEquals(14, matrix.get(0, 3));
+        Assert.assertEquals(22, matrix.get(1, 1));
+        Assert.assertEquals(23, matrix.get(1, 2));
+        Assert.assertEquals(33, matrix.get(2, 2));
+        Assert.assertEquals(34, matrix.get(2, 3));
+        Assert.assertEquals(35, matrix.get(2, 4));
+        Assert.assertEquals(36, matrix.get(2, 5));
+        Assert.assertEquals(44, matrix.get(3, 3));
+        Assert.assertEquals(45, matrix.get(3, 4));
+        Assert.assertEquals(56, matrix.get(4, 5));
+        Assert.assertEquals(66, matrix.get(5, 5));
+
+        Assert.assertEquals(0, matrix.get(5, 4));
+        Assert.assertEquals(0, matrix.get(1, 0));
+        Assert.assertEquals(0, matrix.get(1, 3));
+        Assert.assertEquals(-1, matrix.get(1, 0, -1));
+    }
+
+    /**
+     * Column-wise iteration counts over a column-major dense matrix must be the same whatever the
+     * default value is.
+     */
+    @Test
+    public void testColumnMajorDenseMatrixEachColumn() {
+        ColumnMajorDenseIntMatrix2d matrix = new ColumnMajorDenseIntMatrix2d(columnMajorData(), 6);
+
+        final MutableInt count = new MutableInt(0);
+        final VectorProcedure counter = countingProcedure(count);
+
+        for (int defaultValue : new int[] {-1, 0}) {
+            matrix.setDefaultValue(defaultValue);
+            final String msg = "defaultValue=" + defaultValue;
+
+            // false: only the cells supplied in columnMajorData(), explicit zeros included
+            count.setValue(0);
+            for (int j = 0; j < 6; j++) {
+                matrix.eachInColumn(j, counter, false);
+            }
+            Assert.assertEquals(msg, 1 + 2 + 3 + 4 + 4 + 6, count.getValue());
+
+            // true: every cell of the 6x6 matrix
+            count.setValue(0);
+            for (int j = 0; j < 6; j++) {
+                matrix.eachInColumn(j, counter, true);
+            }
+            Assert.assertEquals(msg, 6 * 6, count.getValue());
+
+            // non-zero cells only
+            count.setValue(0);
+            for (int j = 0; j < 6; j++) {
+                matrix.eachNonZeroInColumn(j, counter);
+            }
+            Assert.assertEquals(msg, 1 + 2 + 3 + 3 + 2 + 3, count.getValue());
+        }
+    }
+
+    /**
+     * Column-wise iteration counts over a DoK matrix built from {@link #columnMajorData()} with
+     * explicit zeros kept must be the same whatever the default value is.
+     */
+    @Test
+    public void testDoKMatrixColumnMajorNonZeroOnlyFalseEachColumn() {
+        DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, false);
+
+        final MutableInt count = new MutableInt(0);
+        final VectorProcedure counter = countingProcedure(count);
+
+        for (int defaultValue : new int[] {-1, 0}) {
+            matrix.setDefaultValue(defaultValue);
+            final String msg = "defaultValue=" + defaultValue;
+
+            // false: only the cells supplied in columnMajorData(), explicit zeros included
+            count.setValue(0);
+            for (int j = 0; j < 6; j++) {
+                matrix.eachInColumn(j, counter, false);
+            }
+            Assert.assertEquals(msg, 1 + 2 + 3 + 4 + 4 + 6, count.getValue());
+
+            // true: every cell of the 6x6 matrix
+            count.setValue(0);
+            for (int j = 0; j < 6; j++) {
+                matrix.eachInColumn(j, counter, true);
+            }
+            Assert.assertEquals(msg, 6 * 6, count.getValue());
+
+            // non-zero cells only
+            count.setValue(0);
+            for (int j = 0; j < 6; j++) {
+                matrix.eachNonZeroInColumn(j, counter);
+            }
+            Assert.assertEquals(msg, 1 + 2 + 3 + 3 + 2 + 3, count.getValue());
+        }
+    }
+
+    /**
+     * Row-wise iteration over a DoK matrix built from {@link #rowMajorData()} with explicit zeros
+     * kept. (Despite the "EachColumn" suffix, this test visits rows.)
+     */
+    @Test
+    public void testDoKMatrixRowMajorNonZeroOnlyFalseEachColumn() {
+        DoKIntMatrix matrix = DoKIntMatrix.build(rowMajorData(), true, false);
+        matrix.setDefaultValue(-1);
+
+        final MutableInt count = new MutableInt(0);
+        final VectorProcedure counter = countingProcedure(count);
+
+        // false: only the cells supplied in rowMajorData(), explicit zeros included
+        for (int i = 0; i < 6; i++) {
+            matrix.eachInRow(i, counter, false);
+        }
+        Assert.assertEquals(4 + 3 + 6 + 5 + 6 + 6, count.getValue());
+
+        // true: every cell of the 6x6 matrix
+        count.setValue(0);
+        for (int i = 0; i < 6; i++) {
+            matrix.eachInRow(i, counter, true);
+        }
+        Assert.assertEquals(6 * 6, count.getValue());
+
+        // non-zero cells only
+        count.setValue(0);
+        for (int i = 0; i < 6; i++) {
+            matrix.eachNonZeroInRow(i, counter);
+        }
+        Assert.assertEquals(4 + 2 + 4 + 2 + 1 + 1, count.getValue());
+    }
+
+    /**
+     * Creates a procedure that adds one to {@code count} for every cell it is applied to,
+     * regardless of the cell's index or value. The index parameter is deliberately not named
+     * {@code i} so it cannot shadow a caller's loop variable.
+     */
+    private static VectorProcedure countingProcedure(final MutableInt count) {
+        return new VectorProcedure() {
+            @Override
+            public void apply(int index, int value) {
+                count.addValue(1);
+            }
+        };
+    }
+
+    /**
+     * The fixture supplied row by row. Each row starts at column 0 and trailing zeros are left
+     * out, so the rows have different lengths (4, 3, 6, 5, 6, 6).
+     */
+    private static int[][] rowMajorData() {
+        /*
+        11 12 13 14  0  0
+         0 22 23  0  0  0
+         0  0 33 34 35 36
+         0  0  0 44 45  0
+         0  0  0  0  0 56
+         0  0  0  0  0 66
+        */
+        int[][] data = new int[6][];
+        data[0] = new int[] {11, 12, 13, 14};
+        data[1] = new int[] {0, 22, 23};
+        data[2] = new int[] {0, 0, 33, 34, 35, 36};
+        data[3] = new int[] {0, 0, 0, 44, 45};
+        data[4] = new int[] {0, 0, 0, 0, 0, 56};
+        data[5] = new int[] {0, 0, 0, 0, 0, 66};
+        return data;
+    }
+
+    /**
+     * The fixture supplied column by column. Each column starts at row 0 and trailing zeros are
+     * left out, so the columns have different lengths (1, 2, 3, 4, 4, 6).
+     */
+    private static int[][] columnMajorData() {
+        /*
+        11 12 13 14  0  0
+         0 22 23  0  0  0
+         0  0 33 34 35 36
+         0  0  0 44 45  0
+         0  0  0  0  0 56
+         0  0  0  0  0 66
+        */
+        int[][] data = new int[6][];
+        data[0] = new int[] {11};
+        data[1] = new int[] {12, 22};
+        data[2] = new int[] {13, 23, 33};
+        data[3] = new int[] {14, 0, 34, 44};
+        data[4] = new int[] {0, 0, 35, 45};
+        data[5] = new int[] {0, 0, 36, 0, 56, 66};
+        return data;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java
deleted file mode 100644
index 5545631..0000000
--- a/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java
+++ /dev/null
@@ -1,329 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class MatrixBuilderTest {
-
- @Test
- public void testReadOnlyCSRMatrix() {
- Matrix matrix = csrMatrix();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(2, matrix.numColumns(1));
- Assert.assertEquals(4, matrix.numColumns(2));
- Assert.assertEquals(2, matrix.numColumns(3));
- Assert.assertEquals(1, matrix.numColumns(4));
- Assert.assertEquals(1, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
- Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
-
- matrix.setDefaultValue(Double.NaN);
- Assert.assertEquals(Double.NaN, matrix.get(5, 4), 0.d);
- }
-
- @Test
- public void testReadOnlyCSRMatrixFromLibSVM() {
- Matrix matrix = csrMatrixFromLibSVM();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(2, matrix.numColumns(1));
- Assert.assertEquals(4, matrix.numColumns(2));
- Assert.assertEquals(2, matrix.numColumns(3));
- Assert.assertEquals(1, matrix.numColumns(4));
- Assert.assertEquals(1, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
- Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
-
- matrix.setDefaultValue(Double.NaN);
- Assert.assertEquals(Double.NaN, matrix.get(5, 4), 0.d);
- }
-
-
- @Test
- public void testReadOnlyCSRMatrixNoRow() {
- CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
- Matrix matrix = builder.buildMatrix(true);
- Assert.assertEquals(0, matrix.numRows());
- Assert.assertEquals(0, matrix.numColumns());
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testReadOnlyCSRMatrixGetFail1() {
- Matrix matrix = csrMatrix();
- matrix.get(7, 5);
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testReadOnlyCSRMatrixGetFail2() {
- Matrix matrix = csrMatrix();
- matrix.get(6, 7);
- }
-
- @Test
- public void testReadOnlyDenseMatrix2d() {
- Matrix matrix = denseMatrix();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(3, matrix.numColumns(1));
- Assert.assertEquals(6, matrix.numColumns(2));
- Assert.assertEquals(5, matrix.numColumns(3));
- Assert.assertEquals(6, matrix.numColumns(4));
- Assert.assertEquals(6, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- }
-
- @Test
- public void testReadOnlyDenseMatrix2dSparseInput() {
- Matrix matrix = denseMatrixSparseInput();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(3, matrix.numColumns(1));
- Assert.assertEquals(6, matrix.numColumns(2));
- Assert.assertEquals(5, matrix.numColumns(3));
- Assert.assertEquals(6, matrix.numColumns(4));
- Assert.assertEquals(6, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- }
-
- @Test
- public void testReadOnlyDenseMatrix2dFromLibSVM() {
- Matrix matrix = denseMatrixFromLibSVM();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(3, matrix.numColumns(1));
- Assert.assertEquals(6, matrix.numColumns(2));
- Assert.assertEquals(5, matrix.numColumns(3));
- Assert.assertEquals(6, matrix.numColumns(4));
- Assert.assertEquals(6, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- }
-
- @Test
- public void testReadOnlyDenseMatrix2dNoRow() {
- Matrix matrix = new DenseMatrixBuilder(1024).buildMatrix(true);
- Assert.assertEquals(0, matrix.numRows());
- Assert.assertEquals(0, matrix.numColumns());
- }
-
- @Test(expected = UnsupportedOperationException.class)
- public void testReadOnlyDenseMatrix2dFailToChangeDefaultValue() {
- Matrix matrix = denseMatrix();
- matrix.setDefaultValue(Double.NaN);
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testReadOnlyDenseMatrix2dFailOutOfBound1() {
- Matrix matrix = denseMatrix();
- matrix.get(7, 5);
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testReadOnlyDenseMatrix2dFailOutOfBound2() {
- Matrix matrix = denseMatrix();
- matrix.get(6, 7);
- }
-
- private static Matrix csrMatrix() {
- /*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
- */
- CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
- builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
- builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
- builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow();
- builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
- builder.nextColumn(5, 56).nextRow();
- builder.nextColumn(5, 66).nextRow();
- return builder.buildMatrix(true);
- }
-
- private static Matrix csrMatrixFromLibSVM() {
- /*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
- */
- CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
- builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
- builder.nextRow(new String[] {"1:22", "2:23"});
- builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
- builder.nextRow(new String[] {"3:44", "4:45"});
- builder.nextRow(new String[] {"5:56"});
- builder.nextRow(new String[] {"5:66"});
- return builder.buildMatrix(true);
- }
-
- private static Matrix denseMatrix() {
- /*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
- */
- DenseMatrixBuilder builder = new DenseMatrixBuilder(1024);
- builder.nextRow(new double[] {11, 12, 13, 14});
- builder.nextRow(new double[] {0, 22, 23});
- builder.nextRow(new double[] {0, 0, 33, 34, 35, 36});
- builder.nextRow(new double[] {0, 0, 0, 44, 45});
- builder.nextRow(new double[] {0, 0, 0, 0, 0, 56});
- builder.nextRow(new double[] {0, 0, 0, 0, 0, 66});
- return builder.buildMatrix(true);
- }
-
- private static Matrix denseMatrixSparseInput() {
- /*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
- */
- DenseMatrixBuilder builder = new DenseMatrixBuilder(1024);
- builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
- builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
- builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow();
- builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
- builder.nextColumn(5, 56).nextRow();
- builder.nextColumn(5, 66).nextRow();
- return builder.buildMatrix(true);
- }
-
- private static Matrix denseMatrixFromLibSVM() {
- DenseMatrixBuilder builder = new DenseMatrixBuilder(1024);
- builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
- builder.nextRow(new String[] {"1:22", "2:23"});
- builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
- builder.nextRow(new String[] {"3:44", "4:45"});
- builder.nextRow(new String[] {"5:56"});
- builder.nextRow(new String[] {"5:66"});
- return builder.buildMatrix(true);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
index 3c6116c..bb6de6b 100644
--- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
@@ -19,13 +19,13 @@
package hivemall.smile.classification;
import static org.junit.Assert.assertEquals;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
+import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.smile.classification.DecisionTree.Node;
import hivemall.smile.data.Attribute;
-import hivemall.smile.tools.TreePredictUDF;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.smile.vm.StackMachine;
-import hivemall.utils.lang.ArrayUtils;
import java.io.BufferedInputStream;
import java.io.IOException;
@@ -33,14 +33,9 @@ import java.io.InputStream;
import java.net.URL;
import java.text.ParseException;
+import javax.annotation.Nonnull;
+
import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.IntWritable;
import org.junit.Assert;
import org.junit.Test;
@@ -52,85 +47,76 @@ import smile.validation.LOOCV;
public class DecisionTreeTest {
private static final boolean DEBUG = false;
- /**
- * Test of learn method, of class DecisionTree.
- *
- * @throws ParseException
- * @throws IOException
- */
@Test
public void testWeather() throws IOException, ParseException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff");
- InputStream is = new BufferedInputStream(url.openStream());
-
- ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
-
- AttributeDataset weather = arffParser.parse(is);
- double[][] x = weather.toArray(new double[weather.size()][]);
- int[] y = weather.toArray(new int[weather.size()]);
-
- int n = x.length;
- LOOCV loocv = new LOOCV(n);
- int error = 0;
- for (int i = 0; i < n; i++) {
- double[][] trainx = Math.slice(x, loocv.train[i]);
- int[] trainy = Math.slice(y, loocv.train[i]);
+ int responseIndex = 4;
+ int numLeafs = 3;
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(weather.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 3);
- if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]]))
- error++;
- }
+ // dense matrix
+ int error = run(
+ "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff",
+ responseIndex, numLeafs, true);
+ assertEquals(5, error);
- debugPrint("Decision Tree error = " + error);
+ // sparse matrix
+ error = run(
+ "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff",
+ responseIndex, numLeafs, false);
assertEquals(5, error);
}
@Test
public void testIris() throws IOException, ParseException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
- InputStream is = new BufferedInputStream(url.openStream());
-
- ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
-
- AttributeDataset iris = arffParser.parse(is);
- double[][] x = iris.toArray(new double[iris.size()][]);
- int[] y = iris.toArray(new int[iris.size()]);
-
- int n = x.length;
- LOOCV loocv = new LOOCV(n);
- int error = 0;
- for (int i = 0; i < n; i++) {
- double[][] trainx = Math.slice(x, loocv.train[i]);
- int[] trainy = Math.slice(y, loocv.train[i]);
-
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- smile.math.Random rand = new smile.math.Random(i);
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, Integer.MAX_VALUE, rand);
- if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]]))
- error++;
- }
+ int responseIndex = 4;
+ int numLeafs = Integer.MAX_VALUE;
+ int error = run(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs, true);
+ assertEquals(8, error);
- debugPrint("Decision Tree error = " + error);
+ // sparse
+ error = run(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs, false);
assertEquals(8, error);
}
@Test
+ public void testIrisSparseDenseEquals() throws IOException, ParseException {
+ int responseIndex = 4;
+ int numLeafs = Integer.MAX_VALUE;
+ runAndCompareSparseAndDense(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs);
+ }
+
+ @Test
public void testIrisDepth4() throws IOException, ParseException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+ int responseIndex = 4;
+ int numLeafs = 4;
+ int error = run(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs, true);
+ assertEquals(7, error);
+
+ // sparse
+ error = run(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs, false);
+ assertEquals(7, error);
+ }
+
+ private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense)
+ throws IOException, ParseException {
+ URL url = new URL(datasetUrl);
InputStream is = new BufferedInputStream(url.openStream());
ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
+ arffParser.setResponseIndex(responseIndex);
- AttributeDataset iris = arffParser.parse(is);
- double[][] x = iris.toArray(new double[iris.size()][]);
- int[] y = iris.toArray(new int[iris.size()]);
+ AttributeDataset ds = arffParser.parse(is);
+ double[][] x = ds.toArray(new double[ds.size()][]);
+ int[] y = ds.toArray(new int[ds.size()]);
int n = x.length;
LOOCV loocv = new LOOCV(n);
@@ -139,52 +125,29 @@ public class DecisionTreeTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
- if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]]))
+ Attribute[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
+ DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), trainy, numLeafs,
+ RandomNumberGeneratorFactory.createPRNG(i));
+ if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) {
error++;
+ }
}
debugPrint("Decision Tree error = " + error);
- assertEquals(7, error);
+ return error;
}
- @Test
- public void testIrisStackmachine() throws IOException, ParseException, HiveException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+ private static void runAndCompareSparseAndDense(String datasetUrl, int responseIndex,
+ int numLeafs) throws IOException, ParseException {
+ URL url = new URL(datasetUrl);
InputStream is = new BufferedInputStream(url.openStream());
ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
- AttributeDataset iris = arffParser.parse(is);
- double[][] x = iris.toArray(new double[iris.size()][]);
- int[] y = iris.toArray(new int[iris.size()]);
-
- int n = x.length;
- LOOCV loocv = new LOOCV(n);
- for (int i = 0; i < n; i++) {
- double[][] trainx = Math.slice(x, loocv.train[i]);
- int[] trainy = Math.slice(y, loocv.train[i]);
-
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
- assertEquals(tree.predict(x[loocv.test[i]]),
- predictByStackMachine(tree, x[loocv.test[i]]));
- }
- }
-
- @Test
- public void testIrisJavascript() throws IOException, ParseException, HiveException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
- InputStream is = new BufferedInputStream(url.openStream());
+ arffParser.setResponseIndex(responseIndex);
- ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
- AttributeDataset iris = arffParser.parse(is);
- double[][] x = iris.toArray(new double[iris.size()][]);
- int[] y = iris.toArray(new int[iris.size()]);
+ AttributeDataset ds = arffParser.parse(is);
+ double[][] x = ds.toArray(new double[ds.size()][]);
+ int[] y = ds.toArray(new int[ds.size()]);
int n = x.length;
LOOCV loocv = new LOOCV(n);
@@ -192,10 +155,12 @@ public class DecisionTreeTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
- assertEquals(tree.predict(x[loocv.test[i]]),
- predictByJavascript(tree, x[loocv.test[i]]));
+ Attribute[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
+ DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), trainy, numLeafs,
+ RandomNumberGeneratorFactory.createPRNG(i));
+ DecisionTree stree = new DecisionTree(attrs, matrix(trainx, false), trainy, numLeafs,
+ RandomNumberGeneratorFactory.createPRNG(i));
+ Assert.assertEquals(dtree.predict(x[loocv.test[i]]), stree.predict(x[loocv.test[i]]));
}
}
@@ -218,7 +183,7 @@ public class DecisionTreeTest {
int[] trainy = Math.slice(y, loocv.train[i]);
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
+ DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);
byte[] b = tree.predictSerCodegen(false);
Node node = DecisionTree.deserializeNode(b, b.length, false);
@@ -245,7 +210,7 @@ public class DecisionTreeTest {
int[] trainy = Math.slice(y, loocv.train[i]);
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
+ DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);
byte[] b1 = tree.predictSerCodegen(true);
byte[] b2 = tree.predictSerCodegen(false);
@@ -256,52 +221,18 @@ public class DecisionTreeTest {
}
}
- private static int predictByStackMachine(DecisionTree tree, double[] x) throws HiveException,
- IOException {
- String script = tree.predictOpCodegen(StackMachine.SEP);
- debugPrint(script);
-
- TreePredictUDF udf = new TreePredictUDF();
- udf.initialize(new ObjectInspector[] {
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- PrimitiveObjectInspectorFactory.javaIntObjectInspector,
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
- ObjectInspectorUtils.getConstantObjectInspector(
- PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)});
- DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"),
- new DeferredJavaObject(ModelType.opscode.getId()), new DeferredJavaObject(script),
- new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(true)};
-
- IntWritable result = (IntWritable) udf.evaluate(arguments);
- result = (IntWritable) udf.evaluate(arguments);
- udf.close();
- return result.get();
- }
-
- private static int predictByJavascript(DecisionTree tree, double[] x) throws HiveException,
- IOException {
- String script = tree.predictJsCodegen();
- debugPrint(script);
-
- TreePredictUDF udf = new TreePredictUDF();
- udf.initialize(new ObjectInspector[] {
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- PrimitiveObjectInspectorFactory.javaIntObjectInspector,
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
- ObjectInspectorUtils.getConstantObjectInspector(
- PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)});
-
- DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"),
- new DeferredJavaObject(ModelType.javascript.getId()),
- new DeferredJavaObject(script), new DeferredJavaObject(ArrayUtils.toList(x)),
- new DeferredJavaObject(true)};
-
- IntWritable result = (IntWritable) udf.evaluate(arguments);
- result = (IntWritable) udf.evaluate(arguments);
- udf.close();
- return result.get();
+ @Nonnull
+ private static Matrix matrix(@Nonnull final double[][] x, boolean dense) {
+ if (dense) {
+ return new RowMajorDenseMatrix2d(x, x[0].length);
+ } else {
+ int numRows = x.length;
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ for (int i = 0; i < numRows; i++) {
+ builder.nextRow(x[i]);
+ }
+ return builder.buildMatrix();
+ }
}
private static void debugPrint(String msg) {