You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/09 21:32:23 UTC
[11/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75]
Support Sparse Vector Format as the input of RandomForest
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
new file mode 100644
index 0000000..d028d47
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Dense {@code int} matrix stored column by column: {@code data[col][row]}.
+ *
+ * <p>
+ * A column array may be shorter than {@link #numRows()} or {@code null}; such cells have no
+ * backing storage and the read methods report a default value for them.
+ */
+public final class ColumnMajorDenseIntMatrix2d extends ColumnMajorIntMatrix {
+
+    @Nonnull
+    private final int[][] data; // col-row
+
+    @Nonnegative
+    private final int numRows;
+    @Nonnegative
+    private final int numColumns;
+
+    /**
+     * @param data column-major values; the array is kept by reference, not copied
+     * @param numRows logical number of rows; the number of columns is {@code data.length}
+     */
+    public ColumnMajorDenseIntMatrix2d(@Nonnull int[][] data, @Nonnegative int numRows) {
+        super();
+        this.data = data;
+        this.numRows = numRows;
+        this.numColumns = data.length;
+    }
+
+    @Override
+    public boolean isSparse() {
+        return false;
+    }
+
+    // NOTE(review): reports read-only although set/getAndSet/incr below write into the backing
+    // array - confirm which behavior callers rely on.
+    @Override
+    public boolean readOnly() {
+        return true;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    /**
+     * Copies a row into a newly allocated array of length {@link #numColumns()}.
+     */
+    @Override
+    public int[] getRow(final int index) {
+        checkRowIndex(index, numRows);
+
+        int[] row = new int[numColumns];
+        return getRow(index, row);
+    }
+
+    /**
+     * Copies a row into {@code dst}. At most {@code dst.length} columns are written; cells without
+     * backing storage are written as the matrix default value so that a reused buffer never keeps
+     * values from a previous row.
+     *
+     * @return {@code dst}
+     */
+    @Override
+    public int[] getRow(final int index, @Nonnull final int[] dst) {
+        checkRowIndex(index, numRows);
+
+        final int end = Math.min(dst.length, numColumns);
+        for (int j = 0; j < end; j++) {
+            final int[] col = data[j];
+            if (col != null && index < col.length) {
+                dst[j] = col[index];
+            } else {
+                dst[j] = defaultValue;
+            }
+        }
+        return dst;
+    }
+
+    /**
+     * @return the stored value, or {@code defaultValue} when the column is {@code null} or shorter
+     *         than {@code row + 1}
+     */
+    @Override
+    public int get(final int row, final int col, final int defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int[] colData = data[col];
+        if (colData == null || row >= colData.length) {
+            return defaultValue;
+        }
+        return colData[row];
+    }
+
+    @Override
+    public int getAndSet(final int row, final int col, final int value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int[] colData = data[col];
+        checkRowIndex(row, colData.length);
+
+        final int old = colData[row];
+        colData[row] = value;
+        return old;
+    }
+
+    /**
+     * Stores {@code value} at the given cell. Writing zero to a cell without backing storage is a
+     * no-op; writing zero to a stored cell overwrites it.
+     */
+    @Override
+    public void set(final int row, final int col, final int value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int[] colData = data[col];
+        if (value == 0 && (colData == null || row >= colData.length)) {
+            // nothing is stored for this cell, so there is nothing to clear
+            return;
+        }
+
+        checkRowIndex(row, colData.length);
+        colData[row] = value;
+    }
+
+    @Override
+    public void incr(final int row, final int col, final int delta) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int[] colData = data[col];
+        checkRowIndex(row, colData.length);
+
+        colData[row] += delta;
+    }
+
+    /**
+     * Applies {@code procedure} to every stored cell of a column in row order. When
+     * {@code nullOutput} is true, rows without backing storage are also visited with the matrix
+     * default value.
+     */
+    @Override
+    public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkColIndex(col, numColumns);
+
+        final int[] colData = data[col];
+        if (colData == null) {
+            if (nullOutput) {
+                for (int i = 0; i < numRows; i++) {
+                    procedure.apply(i, defaultValue);
+                }
+            }
+            return;
+        }
+
+        int row = 0;
+        for (int len = colData.length; row < len; row++) {
+            procedure.apply(row, colData[row]);
+        }
+        if (nullOutput) {
+            // pad the tail of a short column
+            for (; row < numRows; row++) {
+                procedure.apply(row, defaultValue);
+            }
+        }
+    }
+
+    /**
+     * Applies {@code procedure} to every stored cell of a column whose value is not zero.
+     */
+    @Override
+    public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+        checkColIndex(col, numColumns);
+
+        final int[] colData = data[col];
+        if (colData == null) {
+            return;
+        }
+        int row = 0;
+        for (int len = colData.length; row < len; row++) {
+            final int v = colData[row];
+            if (v != 0) {
+                procedure.apply(row, v);
+            }
+        }
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
new file mode 100644
index 0000000..e0b3b4b
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+
+public abstract class ColumnMajorIntMatrix extends AbstractIntMatrix {
+
+ public ColumnMajorIntMatrix() {
+ super();
+ }
+
+ @Override
+ public void eachInRow(int row, VectorProcedure procedure, boolean nullOutput) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void eachNonZeroInRow(int row, VectorProcedure procedure) {
+ throw new UnsupportedOperationException();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java
new file mode 100644
index 0000000..2bbd3b4
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Long2IntOpenHashTable;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Dictionary-of-Key Sparse Int Matrix.
+ *
+ * <p>
+ * Cells live in a hash table keyed by a {@code long} built from the (row, col) pair. Writes
+ * outside the current bounds enlarge {@link #numRows()} / {@link #numColumns()}.
+ */
+public final class DoKIntMatrix extends AbstractIntMatrix {
+
+    @Nonnull
+    private final Long2IntOpenHashTable elements;
+    @Nonnegative
+    private int numRows;
+    @Nonnegative
+    private int numColumns;
+
+    /** Creates an empty 0 x 0 matrix. */
+    public DoKIntMatrix() {
+        this(0, 0);
+    }
+
+    /** Creates a matrix with the given initial dimensions and a sparsity of 0.05. */
+    public DoKIntMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
+        this(numRows, numCols, 0.05f);
+    }
+
+    /**
+     * @param sparsity value in [0, 1] used only to size the initial hash table
+     */
+    public DoKIntMatrix(@Nonnegative int numRows, @Nonnegative int numCols,
+            @Nonnegative float sparsity) {
+        Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: "
+                + sparsity);
+        // NOTE(review): numRows * numCols is an int product and can overflow for large
+        // dimensions, which makes this capacity hint meaningless - confirm whether to widen it.
+        final int expectedEntries = Math.round(numRows * numCols * sparsity);
+        this.elements = new Long2IntOpenHashTable(Math.max(16384, expectedEntries));
+        this.numRows = numRows;
+        this.numColumns = numCols;
+    }
+
+    private DoKIntMatrix(@Nonnull Long2IntOpenHashTable elements, @Nonnegative int numRows,
+            @Nonnegative int numColumns) {
+        this.elements = elements;
+        this.numRows = numRows;
+        this.numColumns = numColumns;
+    }
+
+    @Override
+    public boolean isSparse() {
+        return true;
+    }
+
+    @Override
+    public boolean readOnly() {
+        return false;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    @Override
+    public int[] getRow(@Nonnegative final int index) {
+        return getRow(index, row());
+    }
+
+    /**
+     * Copies a row into {@code dst}; at most {@code dst.length} columns are written and cells
+     * without an entry are written as the default value.
+     */
+    @Override
+    public int[] getRow(@Nonnegative final int row, @Nonnull final int[] dst) {
+        checkRowIndex(row, numRows);
+
+        final int limit = Math.min(dst.length, numColumns);
+        for (int c = 0; c < limit; c++) {
+            dst[c] = elements.get(index(row, c), defaultValue);
+        }
+        return dst;
+    }
+
+    @Override
+    public int get(@Nonnegative final int row, @Nonnegative final int col, final int defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        return elements.get(index(row, col), defaultValue);
+    }
+
+    @Override
+    public void set(@Nonnegative final int row, @Nonnegative final int col, final int value) {
+        checkIndex(row, col);
+
+        elements.put(index(row, col), value);
+        expandTo(row, col);
+    }
+
+    @Override
+    public int getAndSet(@Nonnegative final int row, @Nonnegative final int col, final int value) {
+        checkIndex(row, col);
+
+        final int previous = elements.put(index(row, col), value);
+        expandTo(row, col);
+        return previous;
+    }
+
+    @Override
+    public void incr(@Nonnegative final int row, @Nonnegative final int col, final int delta) {
+        checkIndex(row, col);
+
+        elements.incr(index(row, col), delta);
+        expandTo(row, col);
+    }
+
+    /** Enlarges the logical dimensions so that (row, col) lies inside the matrix. */
+    private void expandTo(@Nonnegative final int row, @Nonnegative final int col) {
+        this.numRows = Math.max(numRows, row + 1);
+        this.numColumns = Math.max(numColumns, col + 1);
+    }
+
+    /**
+     * Visits the columns of a row in ascending order. Cells without an entry are visited with the
+     * default value only when {@code nullOutput} is true.
+     */
+    @Override
+    public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkRowIndex(row, numRows);
+
+        for (int col = 0; col < numColumns; col++) {
+            final int slot = elements._findKey(index(row, col));
+            if (slot >= 0) {
+                final int v = elements._get(slot);
+                procedure.apply(col, v);
+            } else if (nullOutput) {
+                procedure.apply(col, defaultValue);
+            }
+        }
+    }
+
+    @Override
+    public void eachNonZeroInRow(@Nonnegative final int row,
+            @Nonnull final VectorProcedure procedure) {
+        checkRowIndex(row, numRows);
+
+        for (int col = 0; col < numColumns; col++) {
+            final int v = elements.get(index(row, col), 0);
+            if (v == 0) {
+                continue;
+            }
+            procedure.apply(col, v);
+        }
+    }
+
+    /**
+     * Visits the rows of a column in ascending order. Cells without an entry are visited with the
+     * default value only when {@code nullOutput} is true.
+     */
+    @Override
+    public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkColIndex(col, numColumns);
+
+        for (int row = 0; row < numRows; row++) {
+            final int slot = elements._findKey(index(row, col));
+            if (slot >= 0) {
+                final int v = elements._get(slot);
+                procedure.apply(row, v);
+            } else if (nullOutput) {
+                procedure.apply(row, defaultValue);
+            }
+        }
+    }
+
+    @Override
+    public void eachNonZeroInColumn(@Nonnegative final int col,
+            @Nonnull final VectorProcedure procedure) {
+        checkColIndex(col, numColumns);
+
+        for (int row = 0; row < numRows; row++) {
+            final int v = elements.get(index(row, col), 0);
+            if (v == 0) {
+                continue;
+            }
+            procedure.apply(row, v);
+        }
+    }
+
+    /** Builds the hash-table key for a cell from its (row, col) pair. */
+    @Nonnegative
+    private static long index(@Nonnegative final int row, @Nonnegative final int col) {
+        return Primitives.toLong(row, col);
+    }
+
+    /**
+     * Builds a matrix from a two-dimensional array.
+     *
+     * @param rowMajorInput true if {@code matrix} is indexed as [row][col], false for [col][row]
+     * @param nonZeroOnly true to skip zero values instead of storing them
+     */
+    @Nonnull
+    public static DoKIntMatrix build(@Nonnull final int[][] matrix, boolean rowMajorInput,
+            boolean nonZeroOnly) {
+        return rowMajorInput ? buildFromRowMajorMatrix(matrix, nonZeroOnly)
+                : buildFromColumnMajorMatrix(matrix, nonZeroOnly);
+    }
+
+    @Nonnull
+    private static DoKIntMatrix buildFromRowMajorMatrix(@Nonnull final int[][] rowMajorMatrix,
+            boolean nonZeroOnly) {
+        final int rows = rowMajorMatrix.length;
+        final Long2IntOpenHashTable table = new Long2IntOpenHashTable(rows * 3);
+
+        int columns = 0;
+        for (int r = 0; r < rows; r++) {
+            final int[] cells = rowMajorMatrix[r];
+            if (cells == null) {
+                continue;
+            }
+            columns = Math.max(columns, cells.length);
+            for (int c = 0; c < cells.length; c++) {
+                final int value = cells[c];
+                if (value != 0 || !nonZeroOnly) {
+                    table.put(index(r, c), value);
+                }
+            }
+        }
+
+        return new DoKIntMatrix(table, rows, columns);
+    }
+
+    @Nonnull
+    private static DoKIntMatrix buildFromColumnMajorMatrix(
+            @Nonnull final int[][] columnMajorMatrix, boolean nonZeroOnly) {
+        final int columns = columnMajorMatrix.length;
+        final Long2IntOpenHashTable table = new Long2IntOpenHashTable(columns * 3);
+
+        int rows = 0;
+        for (int c = 0; c < columns; c++) {
+            final int[] cells = columnMajorMatrix[c];
+            if (cells == null) {
+                continue;
+            }
+            rows = Math.max(rows, cells.length);
+            for (int r = 0; r < cells.length; r++) {
+                final int value = cells[r];
+                if (value != 0 || !nonZeroOnly) {
+                    table.put(index(r, c), value);
+                }
+            }
+        }
+
+        return new DoKIntMatrix(table, rows, columns);
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java
new file mode 100644
index 0000000..bcc954e
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Two-dimensional matrix of {@code int} values addressed by non-negative (row, col) indices.
+ */
+public interface IntMatrix {
+
+    /**
+     * @return true if the implementation uses a sparse representation
+     */
+    boolean isSparse();
+
+    /**
+     * @return true if the implementation reports itself as read-only
+     */
+    boolean readOnly();
+
+    /**
+     * Sets the default value of this matrix - presumably the value reported for cells that have no
+     * stored entry; confirm against the implementing class.
+     */
+    void setDefaultValue(int value);
+
+    @Nonnegative
+    int numRows();
+
+    @Nonnegative
+    int numColumns();
+
+    /**
+     * @return an array that can serve as the destination buffer of {@link #getRow(int, int[])}
+     */
+    @Nonnull
+    int[] row();
+
+    /**
+     * @return the values of the given row
+     */
+    @Nonnull
+    int[] getRow(@Nonnegative int index);
+
+    /**
+     * @return returns dst
+     */
+    @Nonnull
+    int[] getRow(@Nonnegative int index, @Nonnull int[] dst);
+
+    /**
+     * @throws IndexOutOfBoundsException
+     */
+    int get(@Nonnegative int row, @Nonnegative int col);
+
+    /**
+     * @param defaultValue value to return when the cell has no stored entry
+     * @throws IndexOutOfBoundsException
+     */
+    int get(@Nonnegative int row, @Nonnegative int col, int defaultValue);
+
+    /**
+     * @throws IndexOutOfBoundsException
+     * @throws UnsupportedOperationException
+     */
+    void set(@Nonnegative int row, @Nonnegative int col, int value);
+
+    /**
+     * @return the value held by the cell before the update
+     * @throws IndexOutOfBoundsException
+     * @throws UnsupportedOperationException
+     */
+    int getAndSet(@Nonnegative int row, @Nonnegative int col, int value);
+
+    /**
+     * @throws IndexOutOfBoundsException
+     * @throws UnsupportedOperationException
+     */
+    void incr(@Nonnegative int row, @Nonnegative int col);
+
+    /**
+     * @param delta amount added to the cell
+     * @throws IndexOutOfBoundsException
+     * @throws UnsupportedOperationException
+     */
+    void incr(@Nonnegative int row, @Nonnegative int col, int delta);
+
+    void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+    /**
+     * @param nullOutput when true, cells without a stored entry are also passed to
+     *        {@code procedure}
+     */
+    void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure, boolean nullOutput);
+
+    void eachNonNullInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+    void eachNonZeroInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+    void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+    /**
+     * @param nullOutput when true, cells without a stored entry are also passed to
+     *        {@code procedure}
+     */
+    void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure,
+            boolean nullOutput);
+
+    void eachNonNullInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+    void eachNonZeroInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
new file mode 100644
index 0000000..d2232b2
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.builders.CSCMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Compressed Sparse Column (CSC) double matrix.
+ *
+ * <p>
+ * The entries of column {@code j} occupy positions {@code columnPointers[j]} (inclusive) to
+ * {@code columnPointers[j + 1]} (exclusive) of {@code rowIndicies} and {@code values}. Lookups use
+ * binary search, so the row indices of each column must be sorted in ascending order.
+ *
+ * @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000
+ */
+public final class CSCMatrix extends ColumnMajorMatrix {
+
+    @Nonnull
+    private final int[] columnPointers;
+    @Nonnull
+    private final int[] rowIndicies;
+    @Nonnull
+    private final double[] values;
+
+    private final int numRows;
+    private final int numColumns;
+    private final int nnz;
+
+    /**
+     * The arrays are kept by reference, not copied.
+     *
+     * @throws IllegalArgumentException presumably thrown by {@code Preconditions.checkArgument}
+     *         when {@code columnPointers} is empty or when {@code rowIndicies} and {@code values}
+     *         differ in length
+     */
+    public CSCMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndicies,
+            @Nonnull double[] values, int numRows, int numColumns) {
+        super();
+        Preconditions.checkArgument(columnPointers.length >= 1,
+            "columnPointers must be greater than 0: " + columnPointers.length);
+        Preconditions.checkArgument(rowIndicies.length == values.length, "#rowIndicies ("
+                + rowIndicies.length + ") must be equals to #values (" + values.length + ")");
+        this.columnPointers = columnPointers;
+        this.rowIndicies = rowIndicies;
+        this.values = values;
+        this.numRows = numRows;
+        this.numColumns = numColumns;
+        this.nnz = values.length;
+    }
+
+    @Override
+    public boolean isSparse() {
+        return true;
+    }
+
+    // NOTE(review): reports read-only although set/getAndSet can overwrite entries that are
+    // already stored - confirm which behavior callers rely on.
+    @Override
+    public boolean readOnly() {
+        return true;
+    }
+
+    @Override
+    public boolean swappable() {
+        return false;
+    }
+
+    @Override
+    public int nnz() {
+        return nnz;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    /**
+     * @return the result of {@code ArrayUtils.count(rowIndicies, row)} - presumably the number of
+     *         stored entries in the given row
+     */
+    @Override
+    public int numColumns(final int row) {
+        checkRowIndex(row, numRows);
+
+        return ArrayUtils.count(rowIndicies, row);
+    }
+
+    /**
+     * Copies a row into a newly allocated array of length {@link #numColumns()}.
+     */
+    @Override
+    public double[] getRow(int index) {
+        checkRowIndex(index, numRows);
+
+        return getRow(index, new double[numColumns]);
+    }
+
+    /**
+     * Copies a row into {@code dst}. The buffer is zero-filled first so that a reused buffer never
+     * keeps values from a previous row; at most {@code dst.length} columns are written.
+     *
+     * @return {@code dst}
+     */
+    @Override
+    public double[] getRow(final int index, @Nonnull final double[] dst) {
+        checkRowIndex(index, numRows);
+
+        Arrays.fill(dst, 0.d);
+
+        final int last = Math.min(dst.length, columnPointers.length - 1);
+        for (int j = 0; j < last; j++) {
+            final int k = Arrays.binarySearch(rowIndicies, columnPointers[j],
+                columnPointers[j + 1], index);
+            if (k >= 0) {
+                dst[j] = values[k];
+            }
+        }
+
+        return dst;
+    }
+
+    /**
+     * Clears {@code row} and then sets every stored entry of the given row into it.
+     */
+    @Override
+    public void getRow(final int index, @Nonnull final Vector row) {
+        checkRowIndex(index, numRows);
+        row.clear();
+
+        for (int j = 0, last = columnPointers.length - 1; j < last; j++) {
+            final int k = Arrays.binarySearch(rowIndicies, columnPointers[j],
+                columnPointers[j + 1], index);
+            if (k >= 0) {
+                double v = values[k];
+                row.set(j, v);
+            }
+        }
+    }
+
+    @Override
+    public double get(final int row, final int col, final double defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        int index = getIndex(row, col);
+        if (index < 0) {
+            return defaultValue;
+        }
+        return values[index];
+    }
+
+    /**
+     * Overwrites an entry that is already stored.
+     *
+     * @return the previous value
+     * @throws UnsupportedOperationException if the cell has no stored entry
+     */
+    @Override
+    public double getAndSet(final int row, final int col, final double value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+
+        double old = values[index];
+        values[index] = value;
+        return old;
+    }
+
+    /**
+     * Overwrites an entry that is already stored.
+     *
+     * @throws UnsupportedOperationException if the cell has no stored entry
+     */
+    @Override
+    public void set(final int row, final int col, final double value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+        values[index] = value;
+    }
+
+    /**
+     * @return position of the cell in {@code values}, or a negative number if it is not stored
+     */
+    private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
+        int leftIn = columnPointers[col];
+        int rightEx = columnPointers[col + 1];
+        final int index = Arrays.binarySearch(rowIndicies, leftIn, rightEx, row);
+        if (index >= values.length) {
+            throw new IndexOutOfBoundsException("Value index " + index + " out of range "
+                    + values.length);
+        }
+        return index;
+    }
+
+    @Override
+    public void swap(final int row1, final int row2) {
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * Visits the stored entries of a column. When {@code nullOutput} is true every row is visited
+     * in ascending order and rows without a stored entry are reported as {@code 0.d}.
+     */
+    @Override
+    public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkColIndex(col, numColumns);
+
+        final int startIn = columnPointers[col];
+        final int endEx = columnPointers[col + 1];
+
+        if (nullOutput) {
+            // i walks the stored entries while row walks every row of the column
+            for (int row = 0, i = startIn; row < numRows; row++) {
+                if (i < endEx && row == rowIndicies[i]) {
+                    double v = values[i++];
+                    procedure.apply(row, v);
+                } else {
+                    procedure.apply(row, 0.d);
+                }
+            }
+        } else {
+            for (int j = startIn; j < endEx; j++) {
+                int row = rowIndicies[j];
+                double v = values[j];
+                procedure.apply(row, v);
+            }
+        }
+    }
+
+    @Override
+    public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+        checkColIndex(col, numColumns);
+
+        final int startIn = columnPointers[col];
+        final int endEx = columnPointers[col + 1];
+        for (int j = startIn; j < endEx; j++) {
+            int row = rowIndicies[j];
+            final double v = values[j];
+            if (v != 0.d) {
+                procedure.apply(row, v);
+            }
+        }
+    }
+
+    /**
+     * Converts this matrix into CSR form by counting the entries of each row and then scattering
+     * the entries column by column.
+     */
+    @Override
+    public CSRMatrix toRowMajorMatrix() {
+        final int[] rowPointers = new int[numRows + 1];
+        final int[] colIndicies = new int[nnz];
+        final double[] csrValues = new double[nnz];
+
+        // compute nnz for each row
+        for (int i = 0; i < rowIndicies.length; i++) {
+            rowPointers[rowIndicies[i]]++;
+        }
+        // turn the counts into start offsets
+        for (int i = 0, sum = 0; i < numRows; i++) {
+            int curr = rowPointers[i];
+            rowPointers[i] = sum;
+            sum += curr;
+        }
+        rowPointers[numRows] = nnz;
+
+        // scatter the entries; rowPointers[row] serves as the write cursor of that row
+        for (int j = 0; j < numColumns; j++) {
+            for (int i = columnPointers[j], last = columnPointers[j + 1]; i < last; i++) {
+                int row = rowIndicies[i];
+                int dst = rowPointers[row];
+
+                colIndicies[dst] = j;
+                csrValues[dst] = values[i];
+
+                rowPointers[row]++;
+            }
+        }
+
+        // shift row pointers back: each cursor now holds the start of the following row
+        for (int i = 0, last = 0; i <= numRows; i++) {
+            int tmp = rowPointers[i];
+            rowPointers[i] = last;
+            last = tmp;
+        }
+
+        return new CSRMatrix(rowPointers, colIndicies, csrValues, numColumns);
+    }
+
+    @Override
+    public CSCMatrixBuilder builder() {
+        return new CSCMatrixBuilder(nnz);
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
new file mode 100644
index 0000000..dd89521
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Read-only CSR double Matrix.
+ *
+ * <p>
+ * The entries of row {@code i} occupy positions {@code rowPointers[i]} (inclusive) to
+ * {@code rowPointers[i + 1]} (exclusive) of {@code columnIndices} and {@code values}.
+ *
+ * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
+ * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html
+ */
+public final class CSRMatrix extends RowMajorMatrix {
+
+    @Nonnull
+    private final int[] rowPointers;
+    @Nonnull
+    private final int[] columnIndices;
+    @Nonnull
+    private final double[] values;
+
+    @Nonnegative
+    private final int numRows;
+    @Nonnegative
+    private final int numColumns;
+    @Nonnegative
+    private final int nnz;
+
+    /**
+     * The arrays are kept by reference, not copied. The number of rows is
+     * {@code rowPointers.length - 1}.
+     */
+    public CSRMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices,
+            @Nonnull double[] values, @Nonnegative int numColumns) {
+        super();
+        Preconditions.checkArgument(rowPointers.length >= 1,
+            "rowPointers must be greather than 0: " + rowPointers.length);
+        Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices ("
+                + columnIndices.length + ") must be equals to #values (" + values.length + ")");
+        this.rowPointers = rowPointers;
+        this.columnIndices = columnIndices;
+        this.values = values;
+        this.numRows = rowPointers.length - 1;
+        this.numColumns = numColumns;
+        this.nnz = values.length;
+    }
+
+    @Override
+    public boolean isSparse() {
+        return true;
+    }
+
+    @Override
+    public boolean readOnly() {
+        return true;
+    }
+
+    @Override
+    public boolean swappable() {
+        return false;
+    }
+
+    @Override
+    public int nnz() {
+        return nnz;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    /**
+     * @return the number of stored entries in the given row
+     */
+    @Override
+    public int numColumns(@Nonnegative final int row) {
+        checkRowIndex(row, numRows);
+
+        return rowPointers[row + 1] - rowPointers[row];
+    }
+
+    /**
+     * Copies the non-zero entries of a row into a new array of length {@link #numColumns()}.
+     */
+    @Override
+    public double[] getRow(@Nonnegative final int index) {
+        final double[] row = new double[numColumns];
+        checkRowIndex(index, numRows);
+
+        for (int k = rowPointers[index], endEx = rowPointers[index + 1]; k < endEx; k++) {
+            final int col = columnIndices[k];
+            final double v = values[k];
+            if (v != 0.d) {
+                row[col] = v;
+            }
+        }
+        return row;
+    }
+
+    /**
+     * Zero-fills {@code dst} and then copies the non-zero entries of a row into it.
+     *
+     * @return {@code dst}
+     */
+    @Override
+    public double[] getRow(@Nonnegative final int index, @Nonnull final double[] dst) {
+        Arrays.fill(dst, 0.d);
+        checkRowIndex(index, numRows);
+
+        for (int k = rowPointers[index], endEx = rowPointers[index + 1]; k < endEx; k++) {
+            final int col = columnIndices[k];
+            final double v = values[k];
+            if (v != 0.d) {
+                checkColIndex(col, numColumns);
+                dst[col] = v;
+            }
+        }
+        return dst;
+    }
+
+    @Override
+    public double get(@Nonnegative final int row, @Nonnegative final int col,
+            final double defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int pos = getIndex(row, col);
+        return (pos < 0) ? defaultValue : values[pos];
+    }
+
+    /**
+     * Overwrites an entry that is already stored.
+     *
+     * @return the previous value
+     * @throws UnsupportedOperationException if the cell has no stored entry
+     */
+    @Override
+    public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+            final double value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int pos = getIndex(row, col);
+        if (pos >= 0) {
+            final double previous = values[pos];
+            values[pos] = value;
+            return previous;
+        }
+        throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                + col);
+    }
+
+    /**
+     * Overwrites an entry that is already stored.
+     *
+     * @throws UnsupportedOperationException if the cell has no stored entry
+     */
+    @Override
+    public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int pos = getIndex(row, col);
+        if (pos < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+        values[pos] = value;
+    }
+
+    /**
+     * @return position of the cell in {@code values}, or a negative number if it is not stored
+     */
+    private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
+        final int found = Arrays.binarySearch(columnIndices, rowPointers[row],
+            rowPointers[row + 1], col);
+        if (found >= 0 && found >= values.length) {
+            throw new IndexOutOfBoundsException("Value index " + found + " out of range "
+                    + values.length);
+        }
+        return found;
+    }
+
+    @Override
+    public void swap(int row1, int row2) {
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * Visits the stored entries of a row. When {@code nullOutput} is true every column is visited
+     * in ascending order and columns without a stored entry are reported as {@code 0.d}.
+     */
+    @Override
+    public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkRowIndex(row, numRows);
+
+        final int startIn = rowPointers[row];
+        final int endEx = rowPointers[row + 1];
+
+        if (!nullOutput) {
+            for (int k = startIn; k < endEx; k++) {
+                procedure.apply(columnIndices[k], values[k]);
+            }
+            return;
+        }
+
+        // cursor walks the stored entries while col walks every column of the row
+        int cursor = startIn;
+        for (int col = 0; col < numColumns; col++) {
+            if (cursor < endEx && col == columnIndices[cursor]) {
+                final double v = values[cursor];
+                cursor++;
+                procedure.apply(col, v);
+            } else {
+                procedure.apply(col, 0.d);
+            }
+        }
+    }
+
+    @Override
+    public void eachNonZeroInRow(@Nonnegative final int row,
+            @Nonnull final VectorProcedure procedure) {
+        checkRowIndex(row, numRows);
+
+        for (int k = rowPointers[row], endEx = rowPointers[row + 1]; k < endEx; k++) {
+            final int col = columnIndices[k];
+            final double v = values[k];
+            if (v == 0.d) {
+                continue;
+            }
+            procedure.apply(col, v);
+        }
+    }
+
+    /**
+     * Passes the column index of every stored entry of a row to {@code procedure}.
+     */
+    @Override
+    public void eachColumnIndexInRow(@Nonnegative final int row,
+            @Nonnull final VectorProcedure procedure) {
+        checkRowIndex(row, numRows);
+
+        for (int k = rowPointers[row], endEx = rowPointers[row + 1]; k < endEx; k++) {
+            procedure.apply(columnIndices[k]);
+        }
+    }
+
+    /**
+     * Converts this matrix into CSC form by counting the entries of each column and then
+     * scattering the entries row by row.
+     */
+    @Nonnull
+    public CSCMatrix toColumnMajorMatrix() {
+        final int[] columnPointers = new int[numColumns + 1];
+        final int[] rowIndicies = new int[nnz];
+        final double[] cscValues = new double[nnz];
+
+        // count the stored entries of each column
+        for (final int col : columnIndices) {
+            columnPointers[col]++;
+        }
+        // turn the counts into start offsets
+        int offset = 0;
+        for (int c = 0; c < numColumns; c++) {
+            final int count = columnPointers[c];
+            columnPointers[c] = offset;
+            offset += count;
+        }
+        columnPointers[numColumns] = nnz;
+
+        // scatter the entries; columnPointers[col] serves as the write cursor of that column
+        for (int r = 0; r < numRows; r++) {
+            final int endEx = rowPointers[r + 1];
+            for (int k = rowPointers[r]; k < endEx; k++) {
+                final int col = columnIndices[k];
+                final int dst = columnPointers[col];
+
+                rowIndicies[dst] = r;
+                cscValues[dst] = values[k];
+
+                columnPointers[col] = dst + 1;
+            }
+        }
+
+        // shift column pointers back: each cursor now holds the start of the following column
+        int previous = 0;
+        for (int c = 0; c <= numColumns; c++) {
+            final int next = columnPointers[c];
+            columnPointers[c] = previous;
+            previous = next;
+        }
+
+        return new CSCMatrix(columnPointers, rowIndicies, cscValues, numRows, numColumns);
+    }
+
+    @Override
+    public CSRMatrixBuilder builder() {
+        return new CSRMatrixBuilder(values.length);
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
new file mode 100644
index 0000000..bcfd152
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import hivemall.annotations.Experimental;
+import hivemall.math.matrix.AbstractMatrix;
+import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.DoKMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Long2DoubleOpenHashTable;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Mutable sparse matrix that keeps its entries in a single {@link Long2DoubleOpenHashTable},
+ * keyed by the long that {@code Primitives.toLong(row, col)} produces (a "dictionary of keys"
+ * layout). The table's default return value is 0, so absent cells read as zero.
+ *
+ * <p>The logical dimensions grow automatically when {@link #set} or {@link #getAndSet} inserts
+ * a new non-zero cell beyond the current bounds. Writing zero removes the cell instead of
+ * storing it, so {@link #nnz()} counts stored entries only.
+ *
+ * <p>Row- and column-wise traversals probe every cell of the row/column, so their cost is
+ * proportional to the dimension rather than to the number of stored entries.
+ */
+@Experimental
+public final class DoKMatrix extends AbstractMatrix {
+
+    @Nonnull
+    private final Long2DoubleOpenHashTable elements;
+    @Nonnegative
+    private int numRows;
+    @Nonnegative
+    private int numColumns;
+    @Nonnegative
+    private int nnz;
+
+    /** Creates an empty 0x0 matrix. */
+    public DoKMatrix() {
+        this(0, 0);
+    }
+
+    /** Creates a matrix with the given dimensions and a sparsity estimate of 0.05. */
+    public DoKMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
+        this(numRows, numCols, 0.05f);
+    }
+
+    /**
+     * @param numRows initial number of rows
+     * @param numCols initial number of columns
+     * @param sparsity expected fraction of stored cells, in [0, 1]; only used to size the
+     *        backing hash table
+     * @throws IllegalArgumentException presumably, via {@code Preconditions.checkArgument},
+     *         when {@code sparsity} is outside [0, 1]
+     */
+    public DoKMatrix(@Nonnegative int numRows, @Nonnegative int numCols, @Nonnegative float sparsity) {
+        super();
+        Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: "
+                + sparsity);
+        // NOTE(review): numRows * numCols is an int multiplication and can overflow for large
+        // dimensions, which silently yields a wrong (possibly minimum) capacity. Left as-is
+        // because the capacity is only a sizing hint and widening it could request a huge table.
+        int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity));
+        this.elements = new Long2DoubleOpenHashTable(initialCapacity);
+        elements.defaultReturnValue(0.d);
+        this.numRows = numRows;
+        this.numColumns = numCols;
+        this.nnz = 0;
+    }
+
+    /**
+     * Creates an empty 0x0 matrix whose backing table is sized for at least
+     * {@code max(initSize, 16384)} entries.
+     */
+    public DoKMatrix(@Nonnegative int initSize) {
+        super();
+        int initialCapacity = Math.max(initSize, 16384);
+        this.elements = new Long2DoubleOpenHashTable(initialCapacity);
+        elements.defaultReturnValue(0.d);
+        this.numRows = 0;
+        this.numColumns = 0;
+        this.nnz = 0;
+    }
+
+    @Override
+    public boolean isSparse() {
+        return true;
+    }
+
+    @Override
+    public boolean isRowMajorMatrix() {
+        return false;
+    }
+
+    @Override
+    public boolean isColumnMajorMatrix() {
+        return false;
+    }
+
+    @Override
+    public boolean readOnly() {
+        return false;
+    }
+
+    @Override
+    public boolean swappable() {
+        return true;
+    }
+
+    @Override
+    public int nnz() {
+        return nnz;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    /**
+     * Counts the cells of {@code row} that are present in the table by probing every column
+     * in {@code [0, numColumns)}. The row index itself is not range-checked here.
+     */
+    @Override
+    public int numColumns(@Nonnegative final int row) {
+        int count = 0;
+        for (int j = 0; j < numColumns; j++) {
+            long index = index(row, j);
+            if (elements.containsKey(index)) {
+                count++;
+            }
+        }
+        return count;
+    }
+
+    @Override
+    public double[] getRow(@Nonnegative final int index) {
+        double[] dst = row();
+        return getRow(index, dst);
+    }
+
+    /**
+     * Copies the first {@code min(dst.length, numColumns)} cells of {@code row} into
+     * {@code dst}; absent cells are written as the table's default return value (0).
+     *
+     * @return {@code dst}
+     */
+    @Override
+    public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) {
+        checkRowIndex(row, numRows);
+
+        final int end = Math.min(dst.length, numColumns);
+        for (int col = 0; col < end; col++) {
+            long k = index(row, col);
+            double v = elements.get(k);
+            dst[col] = v;
+        }
+
+        return dst;
+    }
+
+    /** Clears {@code row} and then sets every non-zero cell of row {@code index} on it. */
+    @Override
+    public void getRow(@Nonnegative final int index, @Nonnull final Vector row) {
+        checkRowIndex(index, numRows);
+        row.clear();
+
+        for (int col = 0; col < numColumns; col++) {
+            long k = index(index, col);
+            final double v = elements.get(k, 0.d);
+            if (v != 0.d) {
+                row.set(col, v);
+            }
+        }
+    }
+
+    /**
+     * Returns the stored value at (row, col), or {@code defaultValue} when the cell is absent.
+     * Both indices are validated against the current dimensions.
+     */
+    @Override
+    public double get(@Nonnegative final int row, @Nonnegative final int col,
+            final double defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        long index = index(row, col);
+        return elements.get(index, defaultValue);
+    }
+
+    /**
+     * Stores {@code value} at (row, col). A non-zero value written to a previously absent cell
+     * increments {@code nnz} and grows the dimensions to cover the cell. Writing zero removes
+     * an existing entry (so the old value can no longer be read back) and is a no-op for an
+     * absent cell.
+     */
+    @Override
+    public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+        checkIndex(row, col);
+
+        final long index = index(row, col);
+        if (value == 0.d) {
+            // zero is represented by absence; drop any stale non-zero entry
+            removeEntry(index);
+            return;
+        }
+
+        if (elements.put(index, value, 0.d) == 0.d) {
+            nnz++;
+            this.numRows = Math.max(numRows, row + 1);
+            this.numColumns = Math.max(numColumns, col + 1);
+        }
+    }
+
+    /**
+     * Same as {@link #set(int, int, double)} but returns the value previously stored at
+     * (row, col), or 0 when the cell was absent.
+     */
+    @Override
+    public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+            final double value) {
+        checkIndex(row, col);
+
+        final long index = index(row, col);
+        if (value == 0.d) {
+            // never store an explicit zero: it would be counted in nnz and reported by
+            // the each* traversals as a stored entry
+            return removeEntry(index);
+        }
+
+        final double old = elements.put(index, value, 0.d);
+        if (old == 0.d) {
+            nnz++;
+            this.numRows = Math.max(numRows, row + 1);
+            this.numColumns = Math.max(numColumns, col + 1);
+        }
+        return old;
+    }
+
+    /**
+     * Exchanges the contents of two rows cell by cell. Where only one of the two cells is
+     * present, the entry is moved rather than swapped, so no explicit zero is created.
+     */
+    @Override
+    public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
+        checkRowIndex(row1, numRows);
+        checkRowIndex(row2, numRows);
+
+        for (int j = 0; j < numColumns; j++) {
+            final long i1 = index(row1, j);
+            final long i2 = index(row2, j);
+
+            final int k1 = elements._findKey(i1);
+            final int k2 = elements._findKey(i2);
+
+            if (k1 >= 0) {
+                if (k2 >= 0) {
+                    double v1 = elements._get(k1);
+                    double v2 = elements._set(k2, v1);
+                    elements._set(k1, v2);
+                } else {// k1>=0 and k2<0
+                    double v1 = elements._remove(k1);
+                    elements.put(i2, v1);
+                }
+            } else if (k2 >= 0) {// k2>=0 and k1 < 0
+                double v2 = elements._remove(k2);
+                elements.put(i1, v2);
+            }
+            // k1<0 and k2<0: both cells absent, nothing to do
+        }
+    }
+
+    /**
+     * Visits the columns of {@code row} in ascending order. Stored cells are always reported;
+     * absent cells are reported as 0 only when {@code nullOutput} is true.
+     */
+    @Override
+    public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkRowIndex(row, numRows);
+
+        for (int col = 0; col < numColumns; col++) {
+            long i = index(row, col);
+            final int key = elements._findKey(i);
+            if (key < 0) {
+                if (nullOutput) {
+                    procedure.apply(col, 0.d);
+                }
+            } else {
+                double v = elements._get(key);
+                procedure.apply(col, v);
+            }
+        }
+    }
+
+    /** Visits, in ascending column order, every cell of {@code row} whose value is not 0. */
+    @Override
+    public void eachNonZeroInRow(@Nonnegative final int row,
+            @Nonnull final VectorProcedure procedure) {
+        checkRowIndex(row, numRows);
+
+        for (int col = 0; col < numColumns; col++) {
+            long i = index(row, col);
+            final double v = elements.get(i, 0.d);
+            if (v != 0.d) {
+                procedure.apply(col, v);
+            }
+        }
+    }
+
+    /** Reports, in ascending order, the column index of every stored cell of {@code row}. */
+    @Override
+    public void eachColumnIndexInRow(@Nonnegative final int row,
+            @Nonnull final VectorProcedure procedure) {
+        checkRowIndex(row, numRows);
+
+        for (int col = 0; col < numColumns; col++) {
+            long i = index(row, col);
+            final int key = elements._findKey(i);
+            // same "found" test as the other traversals (a negative slot means absent)
+            if (key >= 0) {
+                procedure.apply(col);
+            }
+        }
+    }
+
+    /**
+     * Visits the rows of {@code col} in ascending order. Stored cells are always reported;
+     * absent cells are reported as 0 only when {@code nullOutput} is true.
+     */
+    @Override
+    public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkColIndex(col, numColumns);
+
+        for (int row = 0; row < numRows; row++) {
+            long i = index(row, col);
+            final int key = elements._findKey(i);
+            if (key < 0) {
+                if (nullOutput) {
+                    procedure.apply(row, 0.d);
+                }
+            } else {
+                double v = elements._get(key);
+                procedure.apply(row, v);
+            }
+        }
+    }
+
+    /** Visits, in ascending row order, every cell of {@code col} whose value is not 0. */
+    @Override
+    public void eachNonZeroInColumn(@Nonnegative final int col,
+            @Nonnull final VectorProcedure procedure) {
+        checkColIndex(col, numColumns);
+
+        for (int row = 0; row < numRows; row++) {
+            long i = index(row, col);
+            final double v = elements.get(i, 0.d);
+            if (v != 0.d) {
+                procedure.apply(row, v);
+            }
+        }
+    }
+
+    /** @throws UnsupportedOperationException always; conversion is not implemented */
+    @Override
+    public RowMajorMatrix toRowMajorMatrix() {
+        throw new UnsupportedOperationException("Not yet supported");
+    }
+
+    /** @throws UnsupportedOperationException always; conversion is not implemented */
+    @Override
+    public ColumnMajorMatrix toColumnMajorMatrix() {
+        throw new UnsupportedOperationException("Not yet supported");
+    }
+
+    /** Creates a new builder, passing the current number of table entries as its size argument. */
+    @Override
+    public DoKMatrixBuilder builder() {
+        return new DoKMatrixBuilder(elements.size());
+    }
+
+    /**
+     * Removes the entry stored under {@code index}, if any, and keeps {@code nnz} in step.
+     *
+     * @return the removed value, or 0 when no entry was stored
+     */
+    private double removeEntry(final long index) {
+        final int key = elements._findKey(index);
+        if (key < 0) {
+            return 0.d;
+        }
+        nnz--;
+        return elements._remove(key);
+    }
+
+    /** Builds the hash-table key for the cell at (row, col). */
+    @Nonnegative
+    private static long index(@Nonnegative final int row, @Nonnegative final int col) {
+        return Primitives.toLong(row, col);
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/CommonsMathRandom.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/CommonsMathRandom.java b/core/src/main/java/hivemall/math/random/CommonsMathRandom.java
new file mode 100644
index 0000000..e0b7554
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/CommonsMathRandom.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.math3.random.MersenneTwister;
+import org.apache.commons.math3.random.RandomGenerator;
+
+/**
+ * {@link PRNG} adapter that forwards every call to an Apache Commons Math
+ * {@link RandomGenerator}.
+ */
+public final class CommonsMathRandom implements PRNG {
+
+    @Nonnull
+    private final RandomGenerator rng;
+
+    /** Wraps a {@link MersenneTwister} created with its no-argument constructor. */
+    public CommonsMathRandom() {
+        this(new MersenneTwister());
+    }
+
+    /** Wraps a {@link MersenneTwister} created from the given seed. */
+    public CommonsMathRandom(long seed) {
+        this(new MersenneTwister(seed));
+    }
+
+    /** Wraps the supplied generator as-is. */
+    public CommonsMathRandom(@Nonnull RandomGenerator rng) {
+        this.rng = rng;
+    }
+
+    @Override
+    public int nextInt() {
+        return rng.nextInt();
+    }
+
+    @Override
+    public int nextInt(final int n) {
+        return rng.nextInt(n);
+    }
+
+    @Override
+    public long nextLong() {
+        return rng.nextLong();
+    }
+
+    @Override
+    public double nextDouble() {
+        return rng.nextDouble();
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/JavaRandom.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/JavaRandom.java b/core/src/main/java/hivemall/math/random/JavaRandom.java
new file mode 100644
index 0000000..f0ed4c7
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/JavaRandom.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import java.util.Random;
+
+import javax.annotation.Nonnull;
+
+/**
+ * {@link PRNG} adapter that forwards every call to a {@link java.util.Random}.
+ */
+public final class JavaRandom implements PRNG {
+
+    private final Random rand;
+
+    /** Wraps a {@link Random} created with its no-argument constructor. */
+    public JavaRandom() {
+        this(new Random());
+    }
+
+    /** Wraps a {@link Random} created from the given seed. */
+    public JavaRandom(long seed) {
+        this(new Random(seed));
+    }
+
+    /** Wraps the supplied {@link Random} (or subclass) as-is. */
+    public JavaRandom(@Nonnull Random rand) {
+        this.rand = rand;
+    }
+
+    @Override
+    public int nextInt() {
+        return rand.nextInt();
+    }
+
+    @Override
+    public int nextInt(int n) {
+        return rand.nextInt(n);
+    }
+
+    @Override
+    public long nextLong() {
+        return rand.nextLong();
+    }
+
+    @Override
+    public double nextDouble() {
+        return rand.nextDouble();
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/PRNG.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/PRNG.java b/core/src/main/java/hivemall/math/random/PRNG.java
new file mode 100644
index 0000000..d42dcfb
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/PRNG.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import javax.annotation.Nonnegative;
+
+/**
+ * Minimal pseudorandom number generator abstraction shared by the adapters in this package.
+ *
+ * @see <a href="https://en.wikipedia.org/wiki/Pseudorandom_number_generator">Pseudorandom
+ *      number generator (Wikipedia)</a>
+ */
+public interface PRNG {
+
+    /**
+     * Returns a random integer in [0, n).
+     */
+    int nextInt(@Nonnegative int n);
+
+    /** Returns the next {@code int} produced by the generator. */
+    int nextInt();
+
+    /** Returns the next {@code long} produced by the generator. */
+    long nextLong();
+
+    /** Returns the next {@code double} produced by the generator. */
+    double nextDouble();
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java b/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java
new file mode 100644
index 0000000..8843f7e
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import hivemall.utils.lang.Primitives;
+
+import java.security.SecureRandom;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Static factory for {@link PRNG} instances, selected by {@link PRNGType}. The overloads
+ * without a type argument use {@link PRNGType#smile}.
+ */
+public final class RandomNumberGeneratorFactory {
+
+    private RandomNumberGeneratorFactory() {}
+
+    /** Creates an unseeded generator of the default type. */
+    @Nonnull
+    public static PRNG createPRNG() {
+        return createPRNG(PRNGType.smile);
+    }
+
+    /** Creates a generator of the default type from the given seed. */
+    @Nonnull
+    public static PRNG createPRNG(long seed) {
+        return createPRNG(PRNGType.smile, seed);
+    }
+
+    /**
+     * Creates an unseeded generator of the requested type.
+     *
+     * @throws IllegalStateException if {@code type} is not handled by the switch
+     */
+    @Nonnull
+    public static PRNG createPRNG(@Nonnull PRNGType type) {
+        switch (type) {
+            case java:
+                return new JavaRandom();
+            case secure:
+                return new JavaRandom(new SecureRandom());
+            case smile:
+                return new SmileRandom();
+            case smileMT:
+                return new SmileRandom(new smile.math.random.MersenneTwister());
+            case smileMT64:
+                return new SmileRandom(new smile.math.random.MersenneTwister64());
+            case commonsMath3MT:
+                return new CommonsMathRandom(new org.apache.commons.math3.random.MersenneTwister());
+            default:
+                throw new IllegalStateException("Unexpected type: " + type);
+        }
+    }
+
+    /**
+     * Creates a generator of the requested type from the given seed.
+     *
+     * @throws IllegalStateException if {@code type} is not handled by the switch
+     */
+    @Nonnull
+    public static PRNG createPRNG(@Nonnull PRNGType type, long seed) {
+        switch (type) {
+            case java:
+                return new JavaRandom(seed);
+            case secure:
+                // NOTE(review): whether a SecureRandom built from seed bytes produces a
+                // reproducible sequence depends on the JCA provider - confirm before relying on it
+                return new JavaRandom(new SecureRandom(Primitives.toBytes(seed)));
+            case smile:
+                return new SmileRandom(seed);
+            case smileMT:
+                // this generator is constructed from an int, so the long seed is reduced first
+                return new SmileRandom(new smile.math.random.MersenneTwister(
+                    Primitives.hashCode(seed)));
+            case smileMT64:
+                return new SmileRandom(new smile.math.random.MersenneTwister64(seed));
+            case commonsMath3MT:
+                return new CommonsMathRandom(new org.apache.commons.math3.random.MersenneTwister(
+                    seed));
+            default:
+                throw new IllegalStateException("Unexpected type: " + type);
+        }
+    }
+
+    /** Generator implementations this factory can create. */
+    public enum PRNGType {
+        java, secure, smile, smileMT, smileMT64, commonsMath3MT;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/SmileRandom.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/SmileRandom.java b/core/src/main/java/hivemall/math/random/SmileRandom.java
new file mode 100644
index 0000000..1edc56c
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/SmileRandom.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import javax.annotation.Nonnull;
+
+import smile.math.random.RandomNumberGenerator;
+import smile.math.random.UniversalGenerator;
+
+/**
+ * {@link PRNG} adapter that forwards every call to a Smile {@link RandomNumberGenerator}.
+ */
+public final class SmileRandom implements PRNG {
+
+    // assigned exactly once in every constructor, so it is final like the delegates of the
+    // sibling adapters
+    @Nonnull
+    private final RandomNumberGenerator rng;
+
+    /** Wraps a {@link UniversalGenerator} created with its no-argument constructor. */
+    public SmileRandom() {
+        this.rng = new UniversalGenerator();
+    }
+
+    /** Wraps a {@link UniversalGenerator} created from the given seed. */
+    public SmileRandom(long seed) {
+        this.rng = new UniversalGenerator(seed);
+    }
+
+    /** Wraps the supplied generator as-is. */
+    public SmileRandom(@Nonnull RandomNumberGenerator rng) {
+        this.rng = rng;
+    }
+
+    @Override
+    public int nextInt(int n) {
+        return rng.nextInt(n);
+    }
+
+    @Override
+    public int nextInt() {
+        return rng.nextInt();
+    }
+
+    @Override
+    public long nextLong() {
+        return rng.nextLong();
+    }
+
+    @Override
+    public double nextDouble() {
+        return rng.nextDouble();
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/AbstractVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/AbstractVector.java b/core/src/main/java/hivemall/math/vector/AbstractVector.java
new file mode 100644
index 0000000..88bed7b
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/AbstractVector.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import javax.annotation.Nonnegative;
+
+/**
+ * Base class for {@link Vector} implementations: supplies the single-argument {@code get}
+ * and shared index validation helpers.
+ */
+public abstract class AbstractVector implements Vector {
+
+    public AbstractVector() {}
+
+    /** Returns the element at {@code index}, using 0 as the default value. */
+    @Override
+    public double get(@Nonnegative final int index) {
+        return get(index, 0.d);
+    }
+
+    /**
+     * Rejects negative indices.
+     *
+     * @throws IndexOutOfBoundsException if {@code index < 0}
+     */
+    protected static final void checkIndex(final int index) {
+        if (index >= 0) {
+            return;
+        }
+        throw new IndexOutOfBoundsException("Invalid index " + index);
+    }
+
+    /**
+     * Rejects indices outside {@code [0, size)}.
+     *
+     * @throws IndexOutOfBoundsException if {@code index < 0} or {@code index >= size}
+     */
+    protected static final void checkIndex(final int index, final int size) {
+        final boolean inRange = index >= 0 && index < size;
+        if (!inRange) {
+            throw new IndexOutOfBoundsException("Index " + index + " out of bounds " + size);
+        }
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/DenseVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/DenseVector.java b/core/src/main/java/hivemall/math/vector/DenseVector.java
new file mode 100644
index 0000000..bd39af1
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/DenseVector.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * {@link Vector} backed by a fixed-length {@code double[]}.
+ */
+public final class DenseVector extends AbstractVector {
+
+    @Nonnull
+    private final double[] values;
+    private final int size;
+
+    /** Creates a vector of {@code size} elements backed by a newly allocated array. */
+    public DenseVector(@Nonnegative int size) {
+        this(new double[size]);
+    }
+
+    /** Wraps the given array directly; no copy is made. */
+    public DenseVector(@Nonnull double[] values) {
+        super();
+        this.values = values;
+        this.size = values.length;
+    }
+
+    /**
+     * Returns the element at {@code index}, or {@code defaultValue} when {@code index} is at
+     * or beyond the vector's size. A negative index is rejected by {@code checkIndex}.
+     */
+    @Override
+    public double get(@Nonnegative final int index, final double defaultValue) {
+        checkIndex(index);
+        return (index < size) ? values[index] : defaultValue;
+    }
+
+    /** Overwrites the element at {@code index}, which must lie in {@code [0, size)}. */
+    @Override
+    public void set(@Nonnegative final int index, final double value) {
+        checkIndex(index, size);
+
+        values[index] = value;
+    }
+
+    /** Adds {@code delta} to the element at {@code index}, which must lie in {@code [0, size)}. */
+    @Override
+    public void incr(@Nonnegative final int index, final double delta) {
+        checkIndex(index, size);
+
+        values[index] += delta;
+    }
+
+    /** Passes every (index, value) pair to {@code procedure}, zeros included, in index order. */
+    @Override
+    public void each(@Nonnull final VectorProcedure procedure) {
+        int i = 0;
+        for (final double v : values) {
+            procedure.apply(i, v);
+            i++;
+        }
+    }
+
+    @Override
+    public int size() {
+        return size;
+    }
+
+    /** Resets every element to 0. */
+    @Override
+    public void clear() {
+        Arrays.fill(values, 0.d);
+    }
+
+    /** Returns the backing array itself, not a copy. */
+    @Override
+    public double[] toArray() {
+        return values;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/SparseVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/SparseVector.java b/core/src/main/java/hivemall/math/vector/SparseVector.java
new file mode 100644
index 0000000..072b544
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/SparseVector.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import hivemall.utils.collections.arrays.SparseDoubleArray;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * {@link Vector} that delegates every operation to a {@link SparseDoubleArray}.
+ */
+public final class SparseVector extends AbstractVector {
+
+    @Nonnull
+    private final SparseDoubleArray values;
+
+    /** Creates a vector backed by a new {@link SparseDoubleArray}. */
+    public SparseVector() {
+        this(new SparseDoubleArray());
+    }
+
+    /** Wraps the given array directly; no copy is made. */
+    public SparseVector(@Nonnull SparseDoubleArray values) {
+        super();
+        this.values = values;
+    }
+
+    @Override
+    public double get(@Nonnegative final int index, final double defaultValue) {
+        return values.get(index, defaultValue);
+    }
+
+    @Override
+    public void set(@Nonnegative final int index, final double value) {
+        values.put(index, value);
+    }
+
+    @Override
+    public void incr(@Nonnegative final int index, final double delta) {
+        values.increment(index, delta);
+    }
+
+    @Override
+    public void each(@Nonnull final VectorProcedure procedure) {
+        values.each(procedure);
+    }
+
+    /** Returns whatever the backing array reports as its size. */
+    @Override
+    public int size() {
+        return values.size();
+    }
+
+    @Override
+    public void clear() {
+        values.clear();
+    }
+
+    /** Returns the result of the backing array's own {@code toArray()}. */
+    @Override
+    public double[] toArray() {
+        return values.toArray();
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/Vector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/Vector.java b/core/src/main/java/hivemall/math/vector/Vector.java
new file mode 100644
index 0000000..2e5107d
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/Vector.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Indexed collection of {@code double} values with read, write, increment and traversal
+ * operations.
+ */
+public interface Vector {
+
+    /** Returns the value at {@code index}. */
+    double get(@Nonnegative int index);
+
+    /** Returns the value at {@code index}, with {@code defaultValue} as the fallback. */
+    double get(@Nonnegative int index, double defaultValue);
+
+    /**
+     * Stores {@code value} at {@code index}.
+     *
+     * @throws UnsupportedOperationException presumably when the implementation does not
+     *         support writes; no condition is stated in the original contract
+     */
+    void set(@Nonnegative int index, double value);
+
+    /** Adds {@code delta} to the value at {@code index}. */
+    void incr(@Nonnegative int index, double delta);
+
+    /** Feeds the vector's elements to {@code procedure}. */
+    void each(@Nonnull VectorProcedure procedure);
+
+    int size();
+
+    void clear();
+
+    /** Returns the vector's contents as an array; never {@code null}. */
+    @Nonnull
+    double[] toArray();
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/VectorProcedure.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/VectorProcedure.java b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
new file mode 100644
index 0000000..266c531
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import javax.annotation.Nonnegative;
+
+/**
+ * Callback passed to vector and matrix traversals. All three {@code apply} overloads are
+ * no-ops by default, so a subclass overrides only the one it needs.
+ */
+public abstract class VectorProcedure {
+
+    public VectorProcedure() {}
+
+    /** Receives an index only; does nothing unless overridden. */
+    public void apply(@Nonnegative int i) {
+        // no-op by default
+    }
+
+    /** Receives an index with an {@code int} value; does nothing unless overridden. */
+    public void apply(@Nonnegative int i, int value) {
+        // no-op by default
+    }
+
+    /** Receives an index with a {@code double} value; does nothing unless overridden. */
+    public void apply(@Nonnegative int i, double value) {
+        // no-op by default
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java b/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java
deleted file mode 100644
index d2deda1..0000000
--- a/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import hivemall.utils.collections.DoubleArrayList;
-import hivemall.utils.collections.IntArrayList;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-/**
- * Compressed Sparse Row Matrix.
- *
- * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
- * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html
- */
-public final class CSRMatrixBuilder extends MatrixBuilder {
-
- @Nonnull
- private final IntArrayList rowPointers;
- @Nonnull
- private final IntArrayList columnIndices;
- @Nonnull
- private final DoubleArrayList values;
-
- private int maxNumColumns;
-
- public CSRMatrixBuilder(int initSize) {
- super();
- this.rowPointers = new IntArrayList(initSize + 1);
- rowPointers.add(0);
- this.columnIndices = new IntArrayList(initSize);
- this.values = new DoubleArrayList(initSize);
- this.maxNumColumns = 0;
- }
-
- @Override
- public CSRMatrixBuilder nextRow() {
- int ptr = values.size();
- rowPointers.add(ptr);
- return this;
- }
-
- @Override
- public CSRMatrixBuilder nextColumn(@Nonnegative int col, double value) {
- if (value == 0.d) {
- return this;
- }
-
- columnIndices.add(col);
- values.add(value);
- this.maxNumColumns = Math.max(col + 1, maxNumColumns);
- return this;
- }
-
- @Override
- public Matrix buildMatrix(boolean readOnly) {
- if (!readOnly) {
- throw new UnsupportedOperationException("Only readOnly matrix is supported");
- }
-
- ReadOnlyCSRMatrix matrix = new ReadOnlyCSRMatrix(rowPointers.toArray(true),
- columnIndices.toArray(true), values.toArray(true), maxNumColumns);
- return matrix;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java b/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
deleted file mode 100644
index f70616e..0000000
--- a/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import hivemall.utils.collections.SparseDoubleArray;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-public final class DenseMatrixBuilder extends MatrixBuilder {
-
- @Nonnull
- private final List<double[]> rows;
- private int maxNumColumns;
-
- @Nonnull
- private final SparseDoubleArray rowProbe;
-
- public DenseMatrixBuilder(int initSize) {
- super();
- this.rows = new ArrayList<double[]>(initSize);
- this.maxNumColumns = 0;
- this.rowProbe = new SparseDoubleArray(32);
- }
-
- @Override
- public MatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
- if (value == 0.d) {
- return this;
- }
- rowProbe.put(col, value);
- return this;
- }
-
- @Override
- public MatrixBuilder nextRow() {
- double[] row = rowProbe.toArray();
- rowProbe.clear();
- nextRow(row);
- return this;
- }
-
- @Override
- public void nextRow(@Nonnull double[] row) {
- rows.add(row);
- this.maxNumColumns = Math.max(row.length, maxNumColumns);
- }
-
- @Override
- public Matrix buildMatrix(boolean readOnly) {
- if (!readOnly) {
- throw new UnsupportedOperationException("Only readOnly matrix is supported");
- }
-
- int numRows = rows.size();
- double[][] data = rows.toArray(new double[numRows][]);
- return new ReadOnlyDenseMatrix2d(data, maxNumColumns);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/Matrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/Matrix.java b/core/src/main/java/hivemall/matrix/Matrix.java
deleted file mode 100644
index 8bbb6c5..0000000
--- a/core/src/main/java/hivemall/matrix/Matrix.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import javax.annotation.Nonnegative;
-
-public abstract class Matrix {
-
- private double defaultValue;
-
- public Matrix() {
- this.defaultValue = 0.d;
- }
-
- public abstract boolean readOnly();
-
- public void setDefaultValue(double value) {
- this.defaultValue = value;
- }
-
- @Nonnegative
- public abstract int numRows();
-
- @Nonnegative
- public abstract int numColumns();
-
- @Nonnegative
- public abstract int numColumns(@Nonnegative int row);
-
- /**
- * @throws IndexOutOfBoundsException
- */
- public final double get(@Nonnegative final int row, @Nonnegative final int col) {
- return get(row, col, defaultValue);
- }
-
- /**
- * @throws IndexOutOfBoundsException
- */
- public abstract double get(@Nonnegative int row, @Nonnegative int col, double defaultValue);
-
- /**
- * @throws IndexOutOfBoundsException
- * @throws UnsupportedOperationException
- */
- public abstract void set(@Nonnegative int row, @Nonnegative int col, double value);
-
- /**
- * @throws IndexOutOfBoundsException
- * @throws UnsupportedOperationException
- */
- public abstract double getAndSet(@Nonnegative int row, @Nonnegative final int col, double value);
-
- protected static final void checkRowIndex(final int row, final int numRows) {
- if (row < 0 || row >= numRows) {
- throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
- }
- }
-
- protected static final void checkColIndex(final int col, final int numColumns) {
- if (col < 0 || col >= numColumns) {
- throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
- }
- }
-
- protected static final void checkIndex(final int row, final int col, final int numRows,
- final int numColumns) {
- if (row < 0 || row >= numRows) {
- throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
- }
- if (col < 0 || col >= numColumns) {
- throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/MatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/MatrixBuilder.java b/core/src/main/java/hivemall/matrix/MatrixBuilder.java
deleted file mode 100644
index e4d6233..0000000
--- a/core/src/main/java/hivemall/matrix/MatrixBuilder.java
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-public abstract class MatrixBuilder {
-
- public MatrixBuilder() {}
-
- public void nextRow(@Nonnull final double[] row) {
- for (int col = 0; col < row.length; col++) {
- nextColumn(col, row[col]);
- }
- nextRow();
- }
-
- public void nextRow(@Nonnull final String[] row) {
- for (String col : row) {
- if (col == null) {
- continue;
- }
- nextColumn(col);
- }
- nextRow();
- }
-
- @Nonnull
- public abstract MatrixBuilder nextRow();
-
- @Nonnull
- public abstract MatrixBuilder nextColumn(@Nonnegative int col, double value);
-
- /**
- * @throws IllegalArgumentException
- * @throws NumberFormatException
- */
- @Nonnull
- public MatrixBuilder nextColumn(@Nonnull final String col) {
- final int pos = col.indexOf(':');
- if (pos == 0) {
- throw new IllegalArgumentException("Invalid feature value representation: " + col);
- }
-
- final String feature;
- final double value;
- if (pos > 0) {
- feature = col.substring(0, pos);
- String s2 = col.substring(pos + 1);
- value = Double.parseDouble(s2);
- } else {
- feature = col;
- value = 1.d;
- }
-
- if (feature.indexOf(':') != -1) {
- throw new IllegalArgumentException("Invaliad feature format `<index>:<value>`: " + col);
- }
-
- int colIndex = Integer.parseInt(feature);
- if (colIndex < 0) {
- throw new IllegalArgumentException("Col index MUST be greather than or equals to 0: "
- + colIndex);
- }
-
- return nextColumn(colIndex, value);
- }
-
- @Nonnull
- public abstract Matrix buildMatrix(boolean readOnly);
-
-}