You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/10/16 11:35:23 UTC
[2/2] incubator-hivemall git commit: Close #121: [HIVEMALL-151]
Support Matrix conversion from DoK to CSR/CSC matrix
Close #121: [HIVEMALL-151] Support Matrix conversion from DoK to CSR/CSC matrix
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/fdf70214
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/fdf70214
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/fdf70214
Branch: refs/heads/master
Commit: fdf70214359f3ce2b1371edf630be89ba9499745
Parents: d4f4ab9
Author: Makoto Yui <my...@apache.org>
Authored: Mon Oct 16 20:35:00 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Mon Oct 16 20:35:00 2017 +0900
----------------------------------------------------------------------
.travis.yml | 3 +-
.../anomaly/SingularSpectrumTransform.java | 6 +-
.../java/hivemall/ftvec/AddFeatureIndexUDF.java | 2 +-
.../java/hivemall/ftvec/FeatureIndexUDF.java | 2 +-
.../ftvec/trans/AddFieldIndicesUDF.java | 6 +-
.../hivemall/ftvec/trans/FFMFeaturesUDF.java | 10 +-
.../hivemall/math/matrix/AbstractMatrix.java | 5 +
.../math/matrix/ColumnMajorFloatMatrix.java | 32 ++
.../java/hivemall/math/matrix/FloatMatrix.java | 73 ++++
.../main/java/hivemall/math/matrix/Matrix.java | 2 +
.../java/hivemall/math/matrix/MatrixUtils.java | 264 +++++++++++-
.../math/matrix/RowMajorFloatMatrix.java | 32 ++
.../math/matrix/builders/CSCMatrixBuilder.java | 8 +-
.../hivemall/math/matrix/sparse/CSCMatrix.java | 47 ++-
.../hivemall/math/matrix/sparse/CSRMatrix.java | 6 +-
.../math/matrix/sparse/DoKFloatMatrix.java | 368 -----------------
.../hivemall/math/matrix/sparse/DoKMatrix.java | 37 +-
.../matrix/sparse/floats/CSCFloatMatrix.java | 317 +++++++++++++++
.../matrix/sparse/floats/CSRFloatMatrix.java | 293 ++++++++++++++
.../matrix/sparse/floats/DoKFloatMatrix.java | 401 +++++++++++++++++++
.../hivemall/math/vector/AbstractVector.java | 10 +
.../hivemall/math/vector/DenseFloatVector.java | 107 +++++
.../hivemall/math/vector/SparseFloatVector.java | 86 ++++
.../main/java/hivemall/math/vector/Vector.java | 7 +
.../hivemall/math/vector/VectorProcedure.java | 4 +
.../main/java/hivemall/recommend/SlimUDTF.java | 18 +-
.../hivemall/smile/utils/SmileExtUtils.java | 12 +-
.../collections/arrays/SparseFloatArray.java | 9 +
.../hivemall/math/matrix/MatrixUtilsTest.java | 132 ++++++
.../math/matrix/sparse/DoKFloatMatrixTest.java | 43 --
.../math/matrix/sparse/DoKMatrixTest.java | 43 ++
.../sparse/floats/DoKFloatMatrixTest.java | 60 +++
docs/gitbook/getting_started/installation.md | 2 +-
pom.xml | 4 +-
resources/ddl/define-all-as-permanent.hive | 4 +
resources/ddl/define-all.hive | 4 +
resources/ddl/define-all.spark | 4 +
resources/ddl/define-udfs.td.hql | 3 +-
.../ftvec/AddFeatureIndexUDFWrapper.java | 2 +-
39 files changed, 1992 insertions(+), 476 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/.travis.yml
----------------------------------------------------------------------
diff --git a/.travis.yml b/.travis.yml
index c64c5ff..c98fe0c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -18,9 +18,10 @@ env:
language: java
jdk:
- - openjdk7
+# - openjdk7
# - oraclejdk7
- oraclejdk8
+# - oraclejdk9
branches:
only:
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
index 34d85aa..1936da4 100644
--- a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
+++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
@@ -186,14 +186,14 @@ final class SingularSpectrumTransform implements SingularSpectrumTransformInterf
for (int i = 0; i < k; i++) {
map.put(eigvals[i], i);
}
- Iterator<Integer> indicies = map.values().iterator();
+ Iterator<Integer> indices = map.values().iterator();
double s = 0.d;
for (int i = 0; i < r; i++) {
- if (!indicies.hasNext()) {
+ if (!indices.hasNext()) {
throw new IllegalStateException("Should not happen");
}
- double v = eigvecs.getEntry(0, indicies.next().intValue());
+ double v = eigvecs.getEntry(0, indices.next().intValue());
s += v * v;
}
return 1.d - Math.sqrt(s);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java b/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java
index 105dd2a..21b3514 100644
--- a/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java
+++ b/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java
@@ -37,7 +37,7 @@ import org.apache.hadoop.io.Text;
*/
@Description(
name = "add_feature_index",
- value = "_FUNC_(ARRAY[DOUBLE]: dense feature vector) - Returns a feature vector with feature indicies")
+ value = "_FUNC_(ARRAY[DOUBLE]: dense feature vector) - Returns a feature vector with feature indices")
@UDFType(deterministic = true, stateful = false)
public final class AddFeatureIndexUDF extends UDF {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java b/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java
index 9ffe6c6..9fdbc01 100644
--- a/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java
+++ b/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java
@@ -32,7 +32,7 @@ import org.apache.hadoop.io.IntWritable;
@Description(
name = "feature_index",
- value = "_FUNC_(feature_vector in array<string>) - Returns feature indicies in array<index>")
+ value = "_FUNC_(feature_vector in array<string>) - Returns feature indices in array<index>")
@UDFType(deterministic = true, stateful = false)
public final class FeatureIndexUDF extends UDF {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
index 53b998c..99cf785 100644
--- a/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
+++ b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
@@ -37,8 +37,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-@Description(name = "add_field_indicies", value = "_FUNC_(array<string> features) "
- + "- Returns arrays of string that field indicies (<field>:<feature>)* are argumented")
+@Description(name = "add_field_indices", value = "_FUNC_(array<string> features) "
+ + "- Returns arrays of string that field indices (<field>:<feature>)* are argumented")
@UDFType(deterministic = true, stateful = false)
public final class AddFieldIndicesUDF extends GenericUDF {
@@ -82,7 +82,7 @@ public final class AddFieldIndicesUDF extends GenericUDF {
@Override
public String getDisplayString(String[] args) {
- return "add_field_indicies( " + Arrays.toString(args) + " )";
+ return "add_field_indices( " + Arrays.toString(args) + " )";
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java
index eead738..a0acd36 100644
--- a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java
+++ b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java
@@ -60,7 +60,7 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
private boolean _mhash = true;
private int _numFeatures = Feature.DEFAULT_NUM_FEATURES;
private int _numFields = Feature.DEFAULT_NUM_FIELDS;
- private boolean _emitIndicies = false;
+ private boolean _emitIndices = false;
@Override
protected Options getOptions() {
@@ -72,7 +72,7 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
opts.addOption("hash", "feature_hashing", true,
"The number of bits for feature hashing in range [18,31] [default:21]");
opts.addOption("fields", "num_fields", true, "The number of fields [default:1024]");
- opts.addOption("emit_indicies", false, "Emit indicies for fields [default: false]");
+ opts.addOption("emit_indices", false, "Emit indices for fields [default: false]");
return opts;
}
@@ -100,7 +100,7 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
}
this._numFields = numFields;
- this._emitIndicies = cl.hasOption("emit_indicies");
+ this._emitIndices = cl.hasOption("emit_indices");
return cl;
}
@@ -189,14 +189,14 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
// categorical feature representation
final String fv;
if (_mhash) {
- int field = _emitIndicies ? i : MurmurHash3.murmurhash3(_featureNames[i],
+ int field = _emitIndices ? i : MurmurHash3.murmurhash3(_featureNames[i],
_numFields);
// +NUM_FIELD to avoid conflict to quantitative features
int index = MurmurHash3.murmurhash3(feature, _numFeatures) + _numFields;
fv = builder.append(field).append(':').append(index).append(":1").toString();
StringUtils.clear(builder);
} else {
- if (_emitIndicies) {
+ if (_emitIndices) {
builder.append(i);
} else {
builder.append(featureName);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
index 2ee27f7..fe3c543 100644
--- a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
@@ -102,4 +102,9 @@ public abstract class AbstractMatrix implements Matrix {
eachInColumn(col, procedure, false);
}
+ @Override
+ public void eachNonZeroCell(VectorProcedure procedure) {
+ throw new UnsupportedOperationException("Not yet supported");
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/ColumnMajorFloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ColumnMajorFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/ColumnMajorFloatMatrix.java
new file mode 100644
index 0000000..6067ed3
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ColumnMajorFloatMatrix.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+public abstract class ColumnMajorFloatMatrix extends ColumnMajorMatrix implements FloatMatrix {
+
+ public ColumnMajorFloatMatrix() {
+ super();
+ }
+
+ @Override
+ public ColumnMajorFloatMatrix toColumnMajorMatrix() {
+ return this;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/FloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/FloatMatrix.java b/core/src/main/java/hivemall/math/matrix/FloatMatrix.java
new file mode 100644
index 0000000..f1af65f
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/FloatMatrix.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public interface FloatMatrix extends Matrix {
+
+ /**
+ * @throws IndexOutOfBoundsException
+ */
+ public float get(@Nonnegative final int row, @Nonnegative final int col,
+ final float defaultValue);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ * @throws UnsupportedOperationException
+ */
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final float value);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ * @throws UnsupportedOperationException
+ */
+ public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value);
+
+ /**
+ * @return returns dst
+ */
+ @Nonnull
+ public float[] getRow(@Nonnegative int index, @Nonnull float[] dst);
+
+ @Override
+ default double get(@Nonnegative final int row, @Nonnegative final int col,
+ final double defaultValue) {
+ return get(row, col, (float) defaultValue);
+ }
+
+ @Override
+ default void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+ set(row, col, (float) value);
+ }
+
+ @Override
+ default double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+ final double value) {
+ return getAndSet(row, col, (float) value);
+ }
+
+ @Override
+ public RowMajorFloatMatrix toRowMajorMatrix();
+
+ @Override
+ public ColumnMajorFloatMatrix toColumnMajorMatrix();
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/Matrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/Matrix.java b/core/src/main/java/hivemall/math/matrix/Matrix.java
index 8a4782a..338a4c2 100644
--- a/core/src/main/java/hivemall/math/matrix/Matrix.java
+++ b/core/src/main/java/hivemall/math/matrix/Matrix.java
@@ -115,6 +115,8 @@ public interface Matrix {
public void eachNonZeroInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+ public void eachNonZeroCell(@Nonnull final VectorProcedure procedure);
+
@Nonnull
public RowMajorMatrix toRowMajorMatrix();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/MatrixUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/MatrixUtils.java b/core/src/main/java/hivemall/math/matrix/MatrixUtils.java
index 90ce78f..cd137ed 100644
--- a/core/src/main/java/hivemall/math/matrix/MatrixUtils.java
+++ b/core/src/main/java/hivemall/math/matrix/MatrixUtils.java
@@ -20,10 +20,17 @@ package hivemall.math.matrix;
import hivemall.math.matrix.builders.MatrixBuilder;
import hivemall.math.matrix.ints.IntMatrix;
+import hivemall.math.matrix.sparse.CSCMatrix;
+import hivemall.math.matrix.sparse.CSRMatrix;
+import hivemall.math.matrix.sparse.floats.CSCFloatMatrix;
+import hivemall.math.matrix.sparse.floats.CSRFloatMatrix;
import hivemall.math.vector.VectorProcedure;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.mutable.MutableInt;
+import java.util.Arrays;
+import java.util.Comparator;
+
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
@@ -34,7 +41,7 @@ public final class MatrixUtils {
@Nonnull
public static Matrix shuffle(@Nonnull final Matrix m, @Nonnull final int[] indices) {
Preconditions.checkArgument(m.numRows() <= indices.length, "m.numRow() `" + m.numRows()
- + "` MUST be equals to or less than |swapIndicies| `" + indices.length + "`");
+ + "` MUST be equals to or less than |swapIndices| `" + indices.length + "`");
final MatrixBuilder builder = m.builder();
final VectorProcedure proc = new VectorProcedure() {
@@ -70,4 +77,259 @@ public final class MatrixUtils {
return which.getValue();
}
+ /**
+ * @param data non-zero entries
+ */
+ @Nonnull
+ public static CSRMatrix coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols,
+ @Nonnull final double[] data, @Nonnegative final int numRows,
+ @Nonnegative final int numCols, final boolean sortColumns) {
+ final int nnz = data.length;
+ Preconditions.checkArgument(rows.length == nnz);
+ Preconditions.checkArgument(cols.length == nnz);
+
+ final int[] rowPointers = new int[numRows + 1];
+ final int[] colIndices = new int[nnz];
+ final double[] values = new double[nnz];
+
+ coo2csr(rows, cols, data, rowPointers, colIndices, values, numRows, numCols, nnz);
+
+ if (sortColumns) {
+ sortIndices(rowPointers, colIndices, values);
+ }
+ return new CSRMatrix(rowPointers, colIndices, values, numCols);
+ }
+
+ /**
+ * @param data non-zero entries
+ */
+ @Nonnull
+ public static CSRFloatMatrix coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols,
+ @Nonnull final float[] data, @Nonnegative final int numRows,
+ @Nonnegative final int numCols, final boolean sortColumns) {
+ final int nnz = data.length;
+ Preconditions.checkArgument(rows.length == nnz);
+ Preconditions.checkArgument(cols.length == nnz);
+
+ final int[] rowPointers = new int[numRows + 1];
+ final int[] colIndices = new int[nnz];
+ final float[] values = new float[nnz];
+
+ coo2csr(rows, cols, data, rowPointers, colIndices, values, numRows, numCols, nnz);
+
+ if (sortColumns) {
+ sortIndices(rowPointers, colIndices, values);
+ }
+ return new CSRFloatMatrix(rowPointers, colIndices, values, numCols);
+ }
+
+ @Nonnull
+ public static CSCMatrix coo2csc(@Nonnull final int[] rows, @Nonnull final int[] cols,
+ @Nonnull final double[] data, @Nonnegative final int numRows,
+ @Nonnegative final int numCols, final boolean sortRows) {
+ final int nnz = data.length;
+ Preconditions.checkArgument(rows.length == nnz);
+ Preconditions.checkArgument(cols.length == nnz);
+
+ final int[] columnPointers = new int[numCols + 1];
+ final int[] rowIndices = new int[nnz];
+ final double[] values = new double[nnz];
+
+ coo2csr(cols, rows, data, columnPointers, rowIndices, values, numCols, numRows, nnz);
+
+ if (sortRows) {
+ sortIndices(columnPointers, rowIndices, values);
+ }
+ return new CSCMatrix(columnPointers, rowIndices, values, numRows, numCols);
+ }
+
+ @Nonnull
+ public static CSCFloatMatrix coo2csc(@Nonnull final int[] rows, @Nonnull final int[] cols,
+ @Nonnull final float[] data, @Nonnegative final int numRows,
+ @Nonnegative final int numCols, final boolean sortRows) {
+ final int nnz = data.length;
+ Preconditions.checkArgument(rows.length == nnz);
+ Preconditions.checkArgument(cols.length == nnz);
+
+ final int[] columnPointers = new int[numCols + 1];
+ final int[] rowIndices = new int[nnz];
+ final float[] values = new float[nnz];
+
+ coo2csr(cols, rows, data, columnPointers, rowIndices, values, numCols, numRows, nnz);
+
+ if (sortRows) {
+ sortIndices(columnPointers, rowIndices, values);
+ }
+
+ return new CSCFloatMatrix(columnPointers, rowIndices, values, numRows, numCols);
+ }
+
+    private static void coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols,
+            @Nonnull final double[] data, @Nonnull final int[] rowPointers,
+            @Nonnull final int[] colIndices, @Nonnull final double[] values,
+            @Nonnegative final int numRows, @Nonnegative final int numCols, final int nnz) {
+        // COO -> CSR: count nnz per row (NOTE(review): numCols is unused; cols not range-checked)
+        for (int n = 0; n < nnz; n++) {
+            rowPointers[rows[n]]++;
+        }
+        for (int i = 0, sum = 0; i < numRows; i++) { // exclusive prefix sum: start of row i
+            int curr = rowPointers[i];
+            rowPointers[i] = sum;
+            sum += curr;
+        }
+        rowPointers[numRows] = nnz;
+
+        // scatter cols/data into colIndices/values; rowPointers[row] serves as row's write cursor
+        for (int n = 0; n < nnz; n++) {
+            int row = rows[n];
+            int dst = rowPointers[row];
+
+            colIndices[dst] = cols[n];
+            values[dst] = data[n];
+
+            rowPointers[row]++;
+        }
+        // each cursor now marks its row's end; shift right by one to recover the row starts
+        for (int i = 0, last = 0; i <= numRows; i++) {
+            int tmp = rowPointers[i];
+            rowPointers[i] = last;
+            last = tmp;
+        }
+    }
+
+    private static void coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols,
+            @Nonnull final float[] data, @Nonnull final int[] rowPointers,
+            @Nonnull final int[] colIndices, @Nonnull final float[] values,
+            @Nonnegative final int numRows, @Nonnegative final int numCols, final int nnz) {
+        // COO -> CSR: count nnz per row (NOTE(review): numCols is unused; cols not range-checked)
+        for (int n = 0; n < nnz; n++) {
+            rowPointers[rows[n]]++;
+        }
+        for (int i = 0, sum = 0; i < numRows; i++) { // exclusive prefix sum: start of row i
+            int curr = rowPointers[i];
+            rowPointers[i] = sum;
+            sum += curr;
+        }
+        rowPointers[numRows] = nnz;
+
+        // scatter cols/data into colIndices/values; rowPointers[row] serves as row's write cursor
+        for (int n = 0; n < nnz; n++) {
+            int row = rows[n];
+            int dst = rowPointers[row];
+
+            colIndices[dst] = cols[n];
+            values[dst] = data[n];
+
+            rowPointers[row]++;
+        }
+        // each cursor now marks its row's end; shift right by one to recover the row starts
+        for (int i = 0, last = 0; i <= numRows; i++) {
+            int tmp = rowPointers[i];
+            rowPointers[i] = last;
+            last = tmp;
+        }
+    }
+
+    private static void sortIndices(@Nonnull final int[] majorAxisPointers,
+            @Nonnull final int[] minorAxisIndices, @Nonnull final double[] values) {
+        final int numRows = majorAxisPointers.length - 1; // rows (CSR) or columns (CSC)
+        if (numRows < 1) { // nothing to sort; a single row/column must still be sorted
+            return;
+        }
+
+        for (int i = 0; i < numRows; i++) {
+            final int rowStart = majorAxisPointers[i];
+            final int rowEnd = majorAxisPointers[i + 1];
+
+            final int numCols = rowEnd - rowStart;
+            if (numCols == 0) {
+                continue;
+            } else if (numCols < 0) {
+                throw new IllegalArgumentException(
+                        "numCols SHOULD be greater than or equal to zero. numCols = rowEnd - rowStart = "
+                                + rowEnd + " - " + rowStart + " = " + numCols + " at i=" + i);
+            }
+
+            final IntDoublePair[] pairs = new IntDoublePair[numCols];
+            for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) {
+                pairs[n] = new IntDoublePair(minorAxisIndices[jj], values[jj]);
+            }
+
+            Arrays.sort(pairs, new Comparator<IntDoublePair>() { // stable: ties keep input order
+                @Override
+                public int compare(IntDoublePair x, IntDoublePair y) {
+                    return Integer.compare(x.key, y.key);
+                }
+            });
+
+            for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) {
+                IntDoublePair tmp = pairs[n];
+                minorAxisIndices[jj] = tmp.key;
+                values[jj] = tmp.value;
+            }
+        }
+    }
+
+    private static void sortIndices(@Nonnull final int[] majorAxisPointers,
+            @Nonnull final int[] minorAxisIndices, @Nonnull final float[] values) {
+        final int numRows = majorAxisPointers.length - 1; // rows (CSR) or columns (CSC)
+        if (numRows < 1) { // nothing to sort; a single row/column must still be sorted
+            return;
+        }
+
+        for (int i = 0; i < numRows; i++) {
+            final int rowStart = majorAxisPointers[i];
+            final int rowEnd = majorAxisPointers[i + 1];
+
+            final int numCols = rowEnd - rowStart;
+            if (numCols == 0) {
+                continue;
+            } else if (numCols < 0) {
+                throw new IllegalArgumentException(
+                        "numCols SHOULD be greater than or equal to zero. numCols = rowEnd - rowStart = "
+                                + rowEnd + " - " + rowStart + " = " + numCols + " at i=" + i);
+            }
+
+            final IntFloatPair[] pairs = new IntFloatPair[numCols];
+            for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) {
+                pairs[n] = new IntFloatPair(minorAxisIndices[jj], values[jj]);
+            }
+
+            Arrays.sort(pairs, new Comparator<IntFloatPair>() { // stable: ties keep input order
+                @Override
+                public int compare(IntFloatPair x, IntFloatPair y) {
+                    return Integer.compare(x.key, y.key);
+                }
+            });
+
+            for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) {
+                IntFloatPair tmp = pairs[n];
+                minorAxisIndices[jj] = tmp.key;
+                values[jj] = tmp.value;
+            }
+        }
+    }
+
+ private static final class IntDoublePair {
+
+ final int key;
+ final double value;
+
+ IntDoublePair(int key, double value) {
+ this.key = key;
+ this.value = value;
+ }
+ }
+
+ private static final class IntFloatPair {
+
+ final int key;
+ final float value;
+
+ IntFloatPair(int key, float value) {
+ this.key = key;
+ this.value = value;
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/RowMajorFloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/RowMajorFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/RowMajorFloatMatrix.java
new file mode 100644
index 0000000..90f7bbf
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/RowMajorFloatMatrix.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+public abstract class RowMajorFloatMatrix extends RowMajorMatrix implements FloatMatrix {
+
+ public RowMajorFloatMatrix() {
+ super();
+ }
+
+ @Override
+ public RowMajorFloatMatrix toRowMajorMatrix() {
+ return this;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
index df2bff7..5c546d5 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
@@ -70,19 +70,19 @@ public final class CSCMatrixBuilder extends MatrixBuilder {
}
final int[] columnIndices = cols.toArray(true);
- final int[] rowsIndicies = rows.toArray(true);
+ final int[] rowsIndices = rows.toArray(true);
final double[] valuesArray = values.toArray(true);
// convert to column major
final int nnz = valuesArray.length;
SortObj[] sortObjs = new SortObj[nnz];
for (int i = 0; i < nnz; i++) {
- sortObjs[i] = new SortObj(columnIndices[i], rowsIndicies[i], valuesArray[i]);
+ sortObjs[i] = new SortObj(columnIndices[i], rowsIndices[i], valuesArray[i]);
}
Arrays.sort(sortObjs);
for (int i = 0; i < nnz; i++) {
columnIndices[i] = sortObjs[i].columnIndex;
- rowsIndicies[i] = sortObjs[i].rowsIndex;
+ rowsIndices[i] = sortObjs[i].rowsIndex;
valuesArray[i] = sortObjs[i].value;
}
sortObjs = null;
@@ -98,7 +98,7 @@ public final class CSCMatrixBuilder extends MatrixBuilder {
}
columnPointers[maxNumColumns] = nnz; // nnz
- return new CSCMatrix(columnPointers, rowsIndicies, valuesArray, row, maxNumColumns);
+ return new CSCMatrix(columnPointers, rowsIndices, valuesArray, row, maxNumColumns);
}
private static final class SortObj implements Comparable<SortObj> {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
index f8eb02f..14bb4f9 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
@@ -31,7 +31,7 @@ import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
- * Compressed Sparse Column matrix optimized for colum major access.
+ * Compressed Sparse Column matrix optimized for column major access.
*
* @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000
*/
@@ -40,7 +40,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
@Nonnull
private final int[] columnPointers;
@Nonnull
- private final int[] rowIndicies;
+ private final int[] rowIndices;
@Nonnull
private final double[] values;
@@ -48,15 +48,15 @@ public final class CSCMatrix extends ColumnMajorMatrix {
private final int numColumns;
private final int nnz;
- public CSCMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndicies,
+ public CSCMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndices,
@Nonnull double[] values, int numRows, int numColumns) {
super();
Preconditions.checkArgument(columnPointers.length >= 1,
"rowPointers must be greather than 0: " + columnPointers.length);
- Preconditions.checkArgument(rowIndicies.length == values.length, "#rowIndicies ("
- + rowIndicies.length + ") must be equals to #values (" + values.length + ")");
+ Preconditions.checkArgument(rowIndices.length == values.length, "#rowIndices ("
+ + rowIndices.length + ") must be equals to #values (" + values.length + ")");
this.columnPointers = columnPointers;
- this.rowIndicies = rowIndicies;
+ this.rowIndices = rowIndices;
this.values = values;
this.numRows = numRows;
this.numColumns = numColumns;
@@ -97,18 +97,18 @@ public final class CSCMatrix extends ColumnMajorMatrix {
public int numColumns(final int row) {
checkRowIndex(row, numRows);
- return ArrayUtils.count(rowIndicies, row);
+ return ArrayUtils.count(rowIndices, row);
}
@Override
- public double[] getRow(int index) {
+ public double[] getRow(final int index) {
checkRowIndex(index, numRows);
final double[] row = new double[numColumns];
final int numCols = columnPointers.length - 1;
for (int j = 0; j < numCols; j++) {
- final int k = Arrays.binarySearch(rowIndicies, columnPointers[j],
+ final int k = Arrays.binarySearch(rowIndices, columnPointers[j],
columnPointers[j + 1], index);
if (k >= 0) {
row[j] = values[k];
@@ -124,12 +124,17 @@ public final class CSCMatrix extends ColumnMajorMatrix {
final int last = Math.min(dst.length, columnPointers.length - 1);
for (int j = 0; j < last; j++) {
- final int k = Arrays.binarySearch(rowIndicies, columnPointers[j],
+ final int k = Arrays.binarySearch(rowIndices, columnPointers[j],
columnPointers[j + 1], index);
if (k >= 0) {
dst[j] = values[k];
+ } else { // no stored entry: write an explicit zero so no stale caller data remains in dst
+ dst[j] = 0.d;
}
}
+ for (int j = last; j < dst.length; j++) { // zero any dst slots beyond the matrix's columns
+ dst[j] = 0.d;
+ }
return dst;
}
@@ -140,7 +145,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
row.clear();
for (int j = 0, last = columnPointers.length - 1; j < last; j++) {
- final int k = Arrays.binarySearch(rowIndicies, columnPointers[j],
+ final int k = Arrays.binarySearch(rowIndices, columnPointers[j],
columnPointers[j + 1], index);
if (k >= 0) {
double v = values[k];
@@ -190,7 +195,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
int leftIn = columnPointers[col];
int rightEx = columnPointers[col + 1];
- final int index = Arrays.binarySearch(rowIndicies, leftIn, rightEx, row);
+ final int index = Arrays.binarySearch(rowIndices, leftIn, rightEx, row);
if (index >= 0 && index >= values.length) {
throw new IndexOutOfBoundsException("Value index " + index + " out of range "
+ values.length);
@@ -213,7 +218,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
if (nullOutput) {
for (int row = 0, i = startIn; row < numRows; row++) {
- if (i < endEx && row == rowIndicies[i]) {
+ if (i < endEx && row == rowIndices[i]) { // merge walk: relies on sorted row indices per column
double v = values[i++];
procedure.apply(row, v);
} else {
@@ -222,7 +227,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
}
} else {
for (int j = startIn; j < endEx; j++) {
- int row = rowIndicies[j];
+ int row = rowIndices[j];
double v = values[j];
procedure.apply(row, v);
}
@@ -236,7 +241,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
final int startIn = columnPointers[col];
final int endEx = columnPointers[col + 1];
for (int j = startIn; j < endEx; j++) {
- int row = rowIndicies[j];
+ int row = rowIndices[j];
final double v = values[j];
if (v != 0.d) {
procedure.apply(row, v);
@@ -247,12 +252,12 @@ public final class CSCMatrix extends ColumnMajorMatrix {
@Override
public CSRMatrix toRowMajorMatrix() {
final int[] rowPointers = new int[numRows + 1];
- final int[] colIndicies = new int[nnz];
+ final int[] colIndices = new int[nnz]; // column index of each value, in CSR (row-major) order
final double[] csrValues = new double[nnz];
// compute nnz per for each row
- for (int i = 0; i < rowIndicies.length; i++) {
- rowPointers[rowIndicies[i]]++;
+ for (int i = 0; i < rowIndices.length; i++) {
+ rowPointers[rowIndices[i]]++;
}
for (int i = 0, sum = 0; i < numRows; i++) {
int curr = rowPointers[i];
@@ -263,10 +268,10 @@ public final class CSCMatrix extends ColumnMajorMatrix {
for (int j = 0; j < numColumns; j++) {
for (int i = columnPointers[j], last = columnPointers[j + 1]; i < last; i++) {
- int col = rowIndicies[i];
+ int col = rowIndices[i]; // NOTE(review): despite its name, this holds a row index
int dst = rowPointers[col];
- colIndicies[dst] = j;
+ colIndices[dst] = j;
csrValues[dst] = values[i];
rowPointers[col]++;
@@ -280,7 +285,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
last = tmp;
}
- return new CSRMatrix(rowPointers, colIndicies, csrValues, numColumns);
+ return new CSRMatrix(rowPointers, colIndices, csrValues, numColumns);
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
index 805bbd1..c1fa6e4 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
@@ -238,7 +238,7 @@ public final class CSRMatrix extends RowMajorMatrix {
@Nonnull
public CSCMatrix toColumnMajorMatrix() {
final int[] columnPointers = new int[numColumns + 1];
- final int[] rowIndicies = new int[nnz];
+ final int[] rowIndices = new int[nnz]; // row index of each value, in CSC (column-major) order
final double[] cscValues = new double[nnz];
// compute nnz per for each column
@@ -257,7 +257,7 @@ public final class CSRMatrix extends RowMajorMatrix {
int col = columnIndices[j];
int dst = columnPointers[col];
- rowIndicies[dst] = i;
+ rowIndices[dst] = i; // NOTE(review): i is presumably the enclosing row loop's index (not shown here)
cscValues[dst] = values[j];
columnPointers[col]++;
@@ -271,7 +271,7 @@ public final class CSRMatrix extends RowMajorMatrix {
last = tmp;
}
- return new CSCMatrix(columnPointers, rowIndicies, cscValues, numRows, numColumns);
+ return new CSCMatrix(columnPointers, rowIndices, cscValues, numRows, numColumns);
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
deleted file mode 100644
index 16b4b64..0000000
--- a/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
+++ /dev/null
@@ -1,368 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.math.matrix.sparse;
-
-import hivemall.annotations.Experimental;
-import hivemall.math.matrix.AbstractMatrix;
-import hivemall.math.matrix.ColumnMajorMatrix;
-import hivemall.math.matrix.RowMajorMatrix;
-import hivemall.math.matrix.builders.DoKMatrixBuilder;
-import hivemall.math.vector.Vector;
-import hivemall.math.vector.VectorProcedure;
-import hivemall.utils.collections.maps.Long2FloatOpenHashTable;
-import hivemall.utils.collections.maps.Long2FloatOpenHashTable.IMapIterator;
-import hivemall.utils.lang.Preconditions;
-import hivemall.utils.lang.Primitives;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-/**
- * Dictionary Of Keys based sparse matrix.
- *
- * This is an efficient structure for constructing a sparse matrix incrementally.
- */
-@Experimental
-public final class DoKFloatMatrix extends AbstractMatrix {
-
- @Nonnull
- private final Long2FloatOpenHashTable elements;
- @Nonnegative
- private int numRows;
- @Nonnegative
- private int numColumns;
- @Nonnegative
- private int nnz;
-
- public DoKFloatMatrix() {
- this(0, 0);
- }
-
- public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
- this(numRows, numCols, 0.05f);
- }
-
- public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols,
- @Nonnegative float sparsity) {
- super();
- Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: "
- + sparsity);
- int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity));
- this.elements = new Long2FloatOpenHashTable(initialCapacity);
- elements.defaultReturnValue(0.f);
- this.numRows = numRows;
- this.numColumns = numCols;
- this.nnz = 0;
- }
-
- public DoKFloatMatrix(@Nonnegative int initSize) {
- super();
- int initialCapacity = Math.max(initSize, 16384);
- this.elements = new Long2FloatOpenHashTable(initialCapacity);
- elements.defaultReturnValue(0.f);
- this.numRows = 0;
- this.numColumns = 0;
- this.nnz = 0;
- }
-
- @Override
- public boolean isSparse() {
- return true;
- }
-
- @Override
- public boolean isRowMajorMatrix() {
- return false;
- }
-
- @Override
- public boolean isColumnMajorMatrix() {
- return false;
- }
-
- @Override
- public boolean readOnly() {
- return false;
- }
-
- @Override
- public boolean swappable() {
- return true;
- }
-
- @Override
- public int nnz() {
- return nnz;
- }
-
- @Override
- public int numRows() {
- return numRows;
- }
-
- @Override
- public int numColumns() {
- return numColumns;
- }
-
- @Override
- public int numColumns(@Nonnegative final int row) {
- int count = 0;
- for (int j = 0; j < numColumns; j++) {
- long index = index(row, j);
- if (elements.containsKey(index)) {
- count++;
- }
- }
- return count;
- }
-
- @Override
- public double[] getRow(@Nonnegative final int index) {
- double[] dst = row();
- return getRow(index, dst);
- }
-
- @Override
- public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) {
- checkRowIndex(row, numRows);
-
- final int end = Math.min(dst.length, numColumns);
- for (int col = 0; col < end; col++) {
- long k = index(row, col);
- float v = elements.get(k);
- dst[col] = v;
- }
-
- return dst;
- }
-
- @Override
- public void getRow(@Nonnegative final int index, @Nonnull final Vector row) {
- checkRowIndex(index, numRows);
- row.clear();
-
- for (int col = 0; col < numColumns; col++) {
- long k = index(index, col);
- final float v = elements.get(k, 0.f);
- if (v != 0.f) {
- row.set(col, v);
- }
- }
- }
-
- @Override
- public double get(@Nonnegative final int row, @Nonnegative final int col,
- final double defaultValue) {
- return get(row, col, (float) defaultValue);
- }
-
- public float get(@Nonnegative final int row, @Nonnegative final int col,
- final float defaultValue) {
- long index = index(row, col);
- return elements.get(index, defaultValue);
- }
-
- @Override
- public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
- set(row, col, (float) value);
- }
-
- public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) {
- checkIndex(row, col);
-
- final long index = index(row, col);
- if (value == 0.f && elements.containsKey(index) == false) {
- return;
- }
-
- if (elements.put(index, value, 0.f) == 0.f) {
- nnz++;
- this.numRows = Math.max(numRows, row + 1);
- this.numColumns = Math.max(numColumns, col + 1);
- }
- }
-
- @Override
- public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
- final double value) {
- return getAndSet(row, col, (float) value);
- }
-
- public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) {
- checkIndex(row, col);
-
- final long index = index(row, col);
- if (value == 0.f && elements.containsKey(index) == false) {
- return 0.f;
- }
-
- final float old = elements.put(index, value, 0.f);
- if (old == 0.f) {
- nnz++;
- this.numRows = Math.max(numRows, row + 1);
- this.numColumns = Math.max(numColumns, col + 1);
- }
- return old;
- }
-
- @Override
- public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
- checkRowIndex(row1, numRows);
- checkRowIndex(row2, numRows);
-
- for (int j = 0; j < numColumns; j++) {
- final long i1 = index(row1, j);
- final long i2 = index(row2, j);
-
- final int k1 = elements._findKey(i1);
- final int k2 = elements._findKey(i2);
-
- if (k1 >= 0) {
- if (k2 >= 0) {
- float v1 = elements._get(k1);
- float v2 = elements._set(k2, v1);
- elements._set(k1, v2);
- } else {// k1>=0 and k2<0
- float v1 = elements._remove(k1);
- elements.put(i2, v1);
- }
- } else if (k2 >= 0) {// k2>=0 and k1 < 0
- float v2 = elements._remove(k2);
- elements.put(i1, v2);
- } else {//k1<0 and k2<0
- continue;
- }
- }
- }
-
- @Override
- public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
- final boolean nullOutput) {
- checkRowIndex(row, numRows);
-
- for (int col = 0; col < numColumns; col++) {
- long i = index(row, col);
- final int key = elements._findKey(i);
- if (key < 0) {
- if (nullOutput) {
- procedure.apply(col, 0.d);
- }
- } else {
- float v = elements._get(key);
- procedure.apply(col, v);
- }
- }
- }
-
- @Override
- public void eachNonZeroInRow(@Nonnegative final int row,
- @Nonnull final VectorProcedure procedure) {
- checkRowIndex(row, numRows);
-
- for (int col = 0; col < numColumns; col++) {
- long i = index(row, col);
- final float v = elements.get(i, 0.f);
- if (v != 0.f) {
- procedure.apply(col, v);
- }
- }
- }
-
- @Override
- public void eachColumnIndexInRow(int row, VectorProcedure procedure) {
- checkRowIndex(row, numRows);
-
- for (int col = 0; col < numColumns; col++) {
- long i = index(row, col);
- final int key = elements._findKey(i);
- if (key != -1) {
- procedure.apply(col);
- }
- }
- }
-
- @Override
- public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
- final boolean nullOutput) {
- checkColIndex(col, numColumns);
-
- for (int row = 0; row < numRows; row++) {
- long i = index(row, col);
- final int key = elements._findKey(i);
- if (key < 0) {
- if (nullOutput) {
- procedure.apply(row, 0.d);
- }
- } else {
- float v = elements._get(key);
- procedure.apply(row, v);
- }
- }
- }
-
- @Override
- public void eachNonZeroInColumn(@Nonnegative final int col,
- @Nonnull final VectorProcedure procedure) {
- checkColIndex(col, numColumns);
-
- for (int row = 0; row < numRows; row++) {
- long i = index(row, col);
- final float v = elements.get(i, 0.f);
- if (v != 0.f) {
- procedure.apply(row, v);
- }
- }
- }
-
- public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
- if (nnz == 0) {
- return;
- }
- final IMapIterator itor = elements.entries();
- while (itor.next() != -1) {
- long k = itor.getKey();
- int row = Primitives.getHigh(k);
- int col = Primitives.getLow(k);
- float value = itor.getValue();
- procedure.apply(row, col, value);
- }
- }
-
- @Override
- public RowMajorMatrix toRowMajorMatrix() {
- throw new UnsupportedOperationException("Not yet supported");
- }
-
- @Override
- public ColumnMajorMatrix toColumnMajorMatrix() {
- throw new UnsupportedOperationException("Not yet supported");
- }
-
- @Override
- public DoKMatrixBuilder builder() {
- return new DoKMatrixBuilder(elements.size());
- }
-
- @Nonnegative
- private static long index(@Nonnegative final int row, @Nonnegative final int col) {
- return Primitives.toLong(row, col);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
index 054d62a..6dc0502 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
@@ -21,6 +21,7 @@ package hivemall.math.matrix.sparse;
import hivemall.annotations.Experimental;
import hivemall.math.matrix.AbstractMatrix;
import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.MatrixUtils;
import hivemall.math.matrix.RowMajorMatrix;
import hivemall.math.matrix.builders.DoKMatrixBuilder;
import hivemall.math.vector.Vector;
@@ -333,12 +334,44 @@ public final class DoKMatrix extends AbstractMatrix {
@Override
public RowMajorMatrix toRowMajorMatrix() {
- throw new UnsupportedOperationException("Not yet supported");
+ final int nnz = elements.size(); // sized from the hash table itself: one COO triple per stored key
+ final int[] rows = new int[nnz];
+ final int[] cols = new int[nnz];
+ final double[] data = new double[nnz];
+
+ final IMapIterator itor = elements.entries();
+ for (int i = 0; i < nnz; i++) {
+ if (itor.next() == -1) { // table yielded fewer entries than size() reported
+ throw new IllegalStateException("itor.next() returns -1 where i=" + i);
+ }
+ long k = itor.getKey();
+ rows[i] = Primitives.getHigh(k); // key packs (row, col): presumably row in the high half
+ cols[i] = Primitives.getLow(k); // ...and col in the low half
+ data[i] = itor.getValue();
+ }
+
+ return MatrixUtils.coo2csr(rows, cols, data, numRows, numColumns, true); // NOTE(review): confirm meaning of the 'true' flag in MatrixUtils
}
@Override
public ColumnMajorMatrix toColumnMajorMatrix() {
- throw new UnsupportedOperationException("Not yet supported");
+ final int nnz = elements.size(); // sized from the hash table itself: one COO triple per stored key
+ final int[] rows = new int[nnz];
+ final int[] cols = new int[nnz];
+ final double[] data = new double[nnz];
+
+ final IMapIterator itor = elements.entries();
+ for (int i = 0; i < nnz; i++) {
+ if (itor.next() == -1) { // table yielded fewer entries than size() reported
+ throw new IllegalStateException("itor.next() returns -1 where i=" + i);
+ }
+ long k = itor.getKey();
+ rows[i] = Primitives.getHigh(k); // key packs (row, col): presumably row in the high half
+ cols[i] = Primitives.getLow(k); // ...and col in the low half
+ data[i] = itor.getValue();
+ }
+
+ return MatrixUtils.coo2csc(rows, cols, data, numRows, numColumns, true); // NOTE(review): confirm meaning of the 'true' flag in MatrixUtils
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/floats/CSCFloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/floats/CSCFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/floats/CSCFloatMatrix.java
new file mode 100644
index 0000000..3aa1dc9
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/floats/CSCFloatMatrix.java
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse.floats;
+
+import hivemall.math.matrix.ColumnMajorFloatMatrix;
+import hivemall.math.matrix.builders.CSCMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Compressed Sparse Column matrix of {@code float} values, optimized for column major access.
+ *
+ * The stored entries of column {@code j} occupy positions {@code columnPointers[j]} (inclusive)
+ * to {@code columnPointers[j + 1]} (exclusive) of {@code rowIndices} and {@code values}. Row
+ * lookups use binary search, so row indices must be sorted in ascending order within each
+ * column. {@code set}/{@code getAndSet} only succeed for positions that already hold a stored
+ * entry.
+ *
+ * @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000
+ */
+public final class CSCFloatMatrix extends ColumnMajorFloatMatrix {
+
+    @Nonnull
+    private final int[] columnPointers;
+    @Nonnull
+    private final int[] rowIndices;
+    @Nonnull
+    private final float[] values;
+
+    private final int numRows;
+    private final int numColumns;
+    private final int nnz;
+
+    /**
+     * @param columnPointers start offset of each column in {@code rowIndices}/{@code values},
+     *        followed by the end offset of the last column (expected length: numColumns + 1)
+     * @param rowIndices row index of each stored value, sorted within each column
+     * @param values stored values; must have the same length as {@code rowIndices}
+     * @param numRows number of rows
+     * @param numColumns number of columns
+     */
+    public CSCFloatMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndices,
+            @Nonnull float[] values, int numRows, int numColumns) {
+        super();
+        Preconditions.checkArgument(columnPointers.length >= 1,
+            "#columnPointers must be greater than 0: " + columnPointers.length);
+        Preconditions.checkArgument(rowIndices.length == values.length, "#rowIndices ("
+                + rowIndices.length + ") must be equal to #values (" + values.length + ")");
+        this.columnPointers = columnPointers;
+        this.rowIndices = rowIndices;
+        this.values = values;
+        this.numRows = numRows;
+        this.numColumns = numColumns;
+        this.nnz = values.length;
+    }
+
+    @Override
+    public boolean isSparse() {
+        return true;
+    }
+
+    @Override
+    public boolean readOnly() {
+        return true;
+    }
+
+    @Override
+    public boolean swappable() {
+        return false;
+    }
+
+    @Override
+    public int nnz() {
+        return nnz;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    /** Counts the stored entries in the given row by scanning all row indices. */
+    @Override
+    public int numColumns(final int row) {
+        checkRowIndex(row, numRows);
+
+        return ArrayUtils.count(rowIndices, row);
+    }
+
+    @Override
+    public double[] getRow(final int index) {
+        checkRowIndex(index, numRows);
+
+        final double[] row = new double[numColumns];
+
+        final int numCols = columnPointers.length - 1;
+        for (int j = 0; j < numCols; j++) {
+            final int k = Arrays.binarySearch(rowIndices, columnPointers[j],
+                columnPointers[j + 1], index);
+            if (k >= 0) {
+                row[j] = values[k];
+            }
+        }
+
+        return row;
+    }
+
+    @Override
+    public double[] getRow(final int index, @Nonnull final double[] dst) {
+        checkRowIndex(index, numRows);
+
+        final int last = Math.min(dst.length, columnPointers.length - 1);
+        for (int j = 0; j < last; j++) {
+            final int k = Arrays.binarySearch(rowIndices, columnPointers[j],
+                columnPointers[j + 1], index);
+            if (k >= 0) {
+                dst[j] = values[k];
+            } else {
+                dst[j] = 0.d;
+            }
+        }
+        // clear any remaining slots so no stale caller data survives in dst
+        for (int j = last; j < dst.length; j++) {
+            dst[j] = 0.d;
+        }
+
+        return dst;
+    }
+
+    @Override
+    public float[] getRow(final int index, @Nonnull final float[] dst) {
+        checkRowIndex(index, numRows);
+
+        final int last = Math.min(dst.length, columnPointers.length - 1);
+        for (int j = 0; j < last; j++) {
+            final int k = Arrays.binarySearch(rowIndices, columnPointers[j],
+                columnPointers[j + 1], index);
+            if (k >= 0) {
+                dst[j] = values[k];
+            } else {
+                dst[j] = 0.f;
+            }
+        }
+        // clear any remaining slots so no stale caller data survives in dst
+        for (int j = last; j < dst.length; j++) {
+            dst[j] = 0.f;
+        }
+
+        return dst;
+    }
+
+    @Override
+    public void getRow(final int index, @Nonnull final Vector row) {
+        checkRowIndex(index, numRows);
+        row.clear();
+
+        for (int j = 0, last = columnPointers.length - 1; j < last; j++) {
+            final int k = Arrays.binarySearch(rowIndices, columnPointers[j],
+                columnPointers[j + 1], index);
+            if (k >= 0) {
+                float v = values[k];
+                row.set(j, v);
+            }
+        }
+    }
+
+    @Override
+    public float get(final int row, final int col, final float defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        int index = getIndex(row, col);
+        if (index < 0) {
+            return defaultValue;
+        }
+        return values[index];
+    }
+
+    @Override
+    public float getAndSet(final int row, final int col, final float value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+
+        float old = values[index];
+        values[index] = value;
+        return old;
+    }
+
+    @Override
+    public void set(final int row, final int col, final float value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+        values[index] = value;
+    }
+
+    /**
+     * Returns the position of (row, col) in {@code values}, or a negative value when the cell has
+     * no stored entry (as returned by {@link Arrays#binarySearch(int[], int, int, int)}).
+     */
+    private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
+        int leftIn = columnPointers[col];
+        int rightEx = columnPointers[col + 1];
+        final int index = Arrays.binarySearch(rowIndices, leftIn, rightEx, row);
+        if (index >= 0 && index >= values.length) {
+            throw new IndexOutOfBoundsException("Value index " + index + " out of range "
+                    + values.length);
+        }
+        return index;
+    }
+
+    @Override
+    public void swap(final int row1, final int row2) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkColIndex(col, numColumns);
+
+        final int startIn = columnPointers[col];
+        final int endEx = columnPointers[col + 1];
+
+        if (nullOutput) {
+            // merge walk over all rows; relies on row indices being sorted within the column
+            for (int row = 0, i = startIn; row < numRows; row++) {
+                if (i < endEx && row == rowIndices[i]) {
+                    float v = values[i++];
+                    procedure.apply(row, v);
+                } else {
+                    procedure.apply(row, 0.f);
+                }
+            }
+        } else {
+            for (int j = startIn; j < endEx; j++) {
+                int row = rowIndices[j];
+                float v = values[j];
+                procedure.apply(row, v);
+            }
+        }
+    }
+
+    @Override
+    public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+        checkColIndex(col, numColumns);
+
+        final int startIn = columnPointers[col];
+        final int endEx = columnPointers[col + 1];
+        for (int j = startIn; j < endEx; j++) {
+            int row = rowIndices[j];
+            final float v = values[j];
+            if (v != 0.f) {
+                procedure.apply(row, v);
+            }
+        }
+    }
+
+    /**
+     * Converts this matrix into an equivalent CSR matrix (a change of storage layout, not a
+     * transpose). Column indices of the result are sorted within each row.
+     */
+    @Override
+    public CSRFloatMatrix toRowMajorMatrix() {
+        final int[] rowPointers = new int[numRows + 1];
+        final int[] colIndices = new int[nnz];
+        final float[] csrValues = new float[nnz];
+
+        // count the stored entries of each row
+        for (int i = 0; i < rowIndices.length; i++) {
+            rowPointers[rowIndices[i]]++;
+        }
+        // exclusive prefix sum: rowPointers[i] becomes the start offset of row i
+        for (int i = 0, sum = 0; i < numRows; i++) {
+            int curr = rowPointers[i];
+            rowPointers[i] = sum;
+            sum += curr;
+        }
+        rowPointers[numRows] = nnz;
+
+        // scatter column by column, so column indices come out ascending within each row
+        for (int j = 0; j < numColumns; j++) {
+            for (int i = columnPointers[j], last = columnPointers[j + 1]; i < last; i++) {
+                int row = rowIndices[i];
+                int dst = rowPointers[row];
+
+                colIndices[dst] = j;
+                csrValues[dst] = values[i];
+
+                rowPointers[row]++;
+            }
+        }
+
+        // the scatter advanced rowPointers[i] to the start of row i + 1; shift them back by one
+        for (int i = 0, last = 0; i <= numRows; i++) {
+            int tmp = rowPointers[i];
+            rowPointers[i] = last;
+            last = tmp;
+        }
+
+        return new CSRFloatMatrix(rowPointers, colIndices, csrValues, numColumns);
+    }
+
+    @Override
+    public CSCMatrixBuilder builder() {
+        return new CSCMatrixBuilder(nnz);
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/floats/CSRFloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/floats/CSRFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/floats/CSRFloatMatrix.java
new file mode 100644
index 0000000..3dd44de
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/floats/CSRFloatMatrix.java
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse.floats;
+
+import hivemall.math.matrix.RowMajorFloatMatrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Compressed Sparse Row matrix of {@code float} values, optimized for row major access.
+ *
+ * The stored entries of row {@code i} occupy positions {@code rowPointers[i]} (inclusive) to
+ * {@code rowPointers[i + 1]} (exclusive) of {@code columnIndices} and {@code values}. Column
+ * lookups use binary search, so column indices must be sorted in ascending order within each
+ * row. {@code set}/{@code getAndSet} only succeed for positions that already hold a stored
+ * entry.
+ *
+ * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
+ * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html
+ */
+public final class CSRFloatMatrix extends RowMajorFloatMatrix {
+
+    @Nonnull
+    private final int[] rowPointers;
+    @Nonnull
+    private final int[] columnIndices;
+    @Nonnull
+    private final float[] values;
+
+    @Nonnegative
+    private final int numRows;
+    @Nonnegative
+    private final int numColumns;
+    @Nonnegative
+    private final int nnz;
+
+    /**
+     * @param rowPointers start offset of each row in {@code columnIndices}/{@code values},
+     *        followed by the end offset of the last row; the number of rows is
+     *        {@code rowPointers.length - 1}
+     * @param columnIndices column index of each stored value, sorted within each row
+     * @param values stored values; must have the same length as {@code columnIndices}
+     * @param numColumns number of columns
+     */
+    public CSRFloatMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices,
+            @Nonnull float[] values, @Nonnegative int numColumns) {
+        super();
+        Preconditions.checkArgument(rowPointers.length >= 1,
+            "#rowPointers must be greater than 0: " + rowPointers.length);
+        Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices ("
+                + columnIndices.length + ") must be equal to #values (" + values.length + ")");
+        this.rowPointers = rowPointers;
+        this.columnIndices = columnIndices;
+        this.values = values;
+        this.numRows = rowPointers.length - 1;
+        this.numColumns = numColumns;
+        this.nnz = values.length;
+    }
+
+    @Override
+    public boolean isSparse() {
+        return true;
+    }
+
+    @Override
+    public boolean readOnly() {
+        return true;
+    }
+
+    @Override
+    public boolean swappable() {
+        return false;
+    }
+
+    @Override
+    public int nnz() {
+        return nnz;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    /** Returns the number of stored entries (explicit zeros included) in the given row. */
+    @Override
+    public int numColumns(@Nonnegative final int row) {
+        checkRowIndex(row, numRows);
+
+        int columns = rowPointers[row + 1] - rowPointers[row];
+        return columns;
+    }
+
+    @Override
+    public double[] getRow(@Nonnegative final int index) {
+        final double[] row = new double[numColumns];
+        eachNonZeroInRow(index, new VectorProcedure() {
+            public void apply(int col, float value) {
+                row[col] = value;
+            }
+        });
+        return row;
+    }
+
+    @Override
+    public double[] getRow(@Nonnegative final int index, @Nonnull final double[] dst) {
+        Arrays.fill(dst, 0.d);
+        eachNonZeroInRow(index, new VectorProcedure() {
+            public void apply(int col, float value) {
+                checkColIndex(col, numColumns);
+                if (col < dst.length) { // a dst shorter than numColumns keeps only leading columns
+                    dst[col] = value;
+                }
+            }
+        });
+        return dst;
+    }
+
+    @Override
+    public float[] getRow(@Nonnegative final int index, @Nonnull final float[] dst) {
+        Arrays.fill(dst, 0.f);
+        eachNonZeroInRow(index, new VectorProcedure() {
+            public void apply(int col, float value) {
+                checkColIndex(col, numColumns);
+                if (col < dst.length) { // a dst shorter than numColumns keeps only leading columns
+                    dst[col] = value;
+                }
+            }
+        });
+        return dst;
+    }
+
+    @Override
+    public float get(@Nonnegative final int row, @Nonnegative final int col,
+            final float defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            return defaultValue;
+        }
+        return values[index];
+    }
+
+    @Override
+    public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+
+        float old = values[index];
+        values[index] = value;
+        return old;
+    }
+
+    @Override
+    public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+        values[index] = value;
+    }
+
+    /**
+     * Returns the position of (row, col) in {@code values}, or a negative value when the cell has
+     * no stored entry (as returned by {@link Arrays#binarySearch(int[], int, int, int)}).
+     */
+    private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
+        int leftIn = rowPointers[row];
+        int rightEx = rowPointers[row + 1];
+        final int index = Arrays.binarySearch(columnIndices, leftIn, rightEx, col);
+        if (index >= 0 && index >= values.length) {
+            throw new IndexOutOfBoundsException("Value index " + index + " out of range "
+                    + values.length);
+        }
+        return index;
+    }
+
+    @Override
+    public void swap(int row1, int row2) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+            final boolean nullOutput) {
+        checkRowIndex(row, numRows);
+
+        final int startIn = rowPointers[row];
+        final int endEx = rowPointers[row + 1];
+
+        if (nullOutput) {
+            // merge walk over all columns; relies on column indices being sorted within the row
+            for (int col = 0, j = startIn; col < numColumns; col++) {
+                if (j < endEx && col == columnIndices[j]) {
+                    float v = values[j++];
+                    procedure.apply(col, v);
+                } else {
+                    procedure.apply(col, 0.f);
+                }
+            }
+        } else {
+            for (int i = startIn; i < endEx; i++) {
+                procedure.apply(columnIndices[i], values[i]);
+            }
+        }
+    }
+
+    @Override
+    public void eachNonZeroInRow(@Nonnegative final int row,
+            @Nonnull final VectorProcedure procedure) {
+        checkRowIndex(row, numRows);
+
+        final int startIn = rowPointers[row];
+        final int endEx = rowPointers[row + 1];
+        for (int i = startIn; i < endEx; i++) {
+            int col = columnIndices[i];
+            final float v = values[i];
+            if (v != 0.f) {
+                procedure.apply(col, v);
+            }
+        }
+    }
+
+    @Override
+    public void eachColumnIndexInRow(@Nonnegative final int row,
+            @Nonnull final VectorProcedure procedure) {
+        checkRowIndex(row, numRows);
+
+        final int startIn = rowPointers[row];
+        final int endEx = rowPointers[row + 1];
+
+        for (int i = startIn; i < endEx; i++) {
+            procedure.apply(columnIndices[i]);
+        }
+    }
+
+    /**
+     * Converts this matrix into an equivalent CSC matrix (a change of storage layout, not a
+     * transpose). Row indices of the result are sorted within each column.
+     */
+    @Override
+    @Nonnull
+    public CSCFloatMatrix toColumnMajorMatrix() {
+        final int[] columnPointers = new int[numColumns + 1];
+        final int[] rowIndices = new int[nnz];
+        final float[] cscValues = new float[nnz];
+
+        // count the stored entries of each column
+        for (int j = 0; j < columnIndices.length; j++) {
+            columnPointers[columnIndices[j]]++;
+        }
+        // exclusive prefix sum: columnPointers[j] becomes the start offset of column j
+        for (int j = 0, sum = 0; j < numColumns; j++) {
+            int curr = columnPointers[j];
+            columnPointers[j] = sum;
+            sum += curr;
+        }
+        columnPointers[numColumns] = nnz;
+
+        // scatter row by row, so row indices come out ascending within each column
+        for (int i = 0; i < numRows; i++) {
+            for (int j = rowPointers[i], last = rowPointers[i + 1]; j < last; j++) {
+                int col = columnIndices[j];
+                int dst = columnPointers[col];
+
+                rowIndices[dst] = i;
+                cscValues[dst] = values[j];
+
+                columnPointers[col]++;
+            }
+        }
+
+        // the scatter advanced columnPointers[j] to the start of column j + 1; shift them back
+        for (int j = 0, last = 0; j <= numColumns; j++) {
+            int tmp = columnPointers[j];
+            columnPointers[j] = last;
+            last = tmp;
+        }
+
+        return new CSCFloatMatrix(columnPointers, rowIndices, cscValues, numRows, numColumns);
+    }
+
+    @Override
+    public CSRMatrixBuilder builder() {
+        return new CSRMatrixBuilder(values.length);
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
new file mode 100644
index 0000000..10929fb
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
@@ -0,0 +1,401 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse.floats;
+
+import hivemall.annotations.Experimental;
+import hivemall.math.matrix.AbstractMatrix;
+import hivemall.math.matrix.FloatMatrix;
+import hivemall.math.matrix.MatrixUtils;
+import hivemall.math.matrix.builders.DoKMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Long2FloatOpenHashTable;
+import hivemall.utils.collections.maps.Long2FloatOpenHashTable.IMapIterator;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Dictionary Of Keys based sparse matrix.
+ *
+ * This is an efficient structure for constructing a sparse matrix incrementally.
+ */
+@Experimental
+public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix {
+
+ @Nonnull
+ private final Long2FloatOpenHashTable elements;
+ @Nonnegative
+ private int numRows;
+ @Nonnegative
+ private int numColumns;
+ @Nonnegative
+ private int nnz;
+
+ public DoKFloatMatrix() {
+ this(0, 0);
+ }
+
+ public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
+ this(numRows, numCols, 0.05f);
+ }
+
+ public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols,
+ @Nonnegative float sparsity) {
+ super();
+ Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: "
+ + sparsity);
+ int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity));
+ this.elements = new Long2FloatOpenHashTable(initialCapacity);
+ elements.defaultReturnValue(0.f);
+ this.numRows = numRows;
+ this.numColumns = numCols;
+ this.nnz = 0;
+ }
+
+ public DoKFloatMatrix(@Nonnegative int initSize) {
+ super();
+ int initialCapacity = Math.max(initSize, 16384);
+ this.elements = new Long2FloatOpenHashTable(initialCapacity);
+ elements.defaultReturnValue(0.f);
+ this.numRows = 0;
+ this.numColumns = 0;
+ this.nnz = 0;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return true;
+ }
+
+ @Override
+ public boolean isRowMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean isColumnMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return false;
+ }
+
+ @Override
+ public boolean swappable() {
+ return true;
+ }
+
+ @Override
+ public int nnz() {
+ return nnz;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int numColumns(@Nonnegative final int row) {
+ int count = 0;
+ for (int j = 0; j < numColumns; j++) {
+ long index = index(row, j);
+ if (elements.containsKey(index)) {
+ count++;
+ }
+ }
+ return count;
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int index) {
+ double[] dst = row();
+ return getRow(index, dst);
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) {
+ checkRowIndex(row, numRows);
+
+ final int end = Math.min(dst.length, numColumns);
+ for (int col = 0; col < end; col++) {
+ long k = index(row, col);
+ float v = elements.get(k);
+ dst[col] = v;
+ }
+
+ return dst;
+ }
+
+ @Override
+ public float[] getRow(@Nonnegative final int row, @Nonnull final float[] dst) {
+ checkRowIndex(row, numRows);
+
+ final int end = Math.min(dst.length, numColumns);
+ for (int col = 0; col < end; col++) {
+ long k = index(row, col);
+ float v = elements.get(k);
+ dst[col] = v;
+ }
+
+ return dst;
+ }
+
+ @Override
+ public void getRow(@Nonnegative final int index, @Nonnull final Vector row) {
+ checkRowIndex(index, numRows);
+ row.clear();
+
+ for (int col = 0; col < numColumns; col++) {
+ long k = index(index, col);
+ final float v = elements.get(k, 0.f);
+ if (v != 0.f) {
+ row.set(col, v);
+ }
+ }
+ }
+
+ @Override
+ public float get(@Nonnegative final int row, @Nonnegative final int col,
+ final float defaultValue) {
+ long index = index(row, col);
+ return elements.get(index, defaultValue);
+ }
+
+ @Override
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) {
+ checkIndex(row, col);
+
+ final long index = index(row, col);
+ if (value == 0.f && elements.containsKey(index) == false) {
+ return;
+ }
+
+ if (elements.put(index, value, 0.f) == 0.f) {
+ nnz++;
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+ }
+
+ @Override
+ public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) {
+ checkIndex(row, col);
+
+ final long index = index(row, col);
+ if (value == 0.f && elements.containsKey(index) == false) {
+ return 0.f;
+ }
+
+ final float old = elements.put(index, value, 0.f);
+ if (old == 0.f) {
+ nnz++;
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+ return old;
+ }
+
+ @Override
+ public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
+ checkRowIndex(row1, numRows);
+ checkRowIndex(row2, numRows);
+
+ for (int j = 0; j < numColumns; j++) {
+ final long i1 = index(row1, j);
+ final long i2 = index(row2, j);
+
+ final int k1 = elements._findKey(i1);
+ final int k2 = elements._findKey(i2);
+
+ if (k1 >= 0) {
+ if (k2 >= 0) {
+ float v1 = elements._get(k1);
+ float v2 = elements._set(k2, v1);
+ elements._set(k1, v2);
+ } else {// k1>=0 and k2<0
+ float v1 = elements._remove(k1);
+ elements.put(i2, v1);
+ }
+ } else if (k2 >= 0) {// k2>=0 and k1 < 0
+ float v2 = elements._remove(k2);
+ elements.put(i1, v2);
+ } else {//k1<0 and k2<0
+ continue;
+ }
+ }
+ }
+
+ @Override
+ public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(col, 0.f);
+ }
+ } else {
+ float v = elements._get(key);
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final float v = elements.get(i, 0.f);
+ if (v != 0.f) {
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachColumnIndexInRow(int row, VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key != -1) {
+ procedure.apply(col);
+ }
+ }
+ }
+
+ @Override
+ public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(row, 0.f);
+ }
+ } else {
+ float v = elements._get(key);
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(@Nonnegative final int col,
+ @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final float v = elements.get(i, 0.f);
+ if (v != 0.f) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
+ if (nnz == 0) {
+ return;
+ }
+ final IMapIterator itor = elements.entries();
+ while (itor.next() != -1) {
+ long k = itor.getKey();
+ int row = Primitives.getHigh(k);
+ int col = Primitives.getLow(k);
+ float value = itor.getValue();
+ procedure.apply(row, col, value);
+ }
+ }
+
+ @Override
+ public CSRFloatMatrix toRowMajorMatrix() {
+ final int nnz = elements.size();
+ final int[] rows = new int[nnz];
+ final int[] cols = new int[nnz];
+ final float[] data = new float[nnz];
+
+ final IMapIterator itor = elements.entries();
+ for (int i = 0; i < nnz; i++) {
+ if (itor.next() == -1) {
+ throw new IllegalStateException("itor.next() returns -1 where i=" + i);
+ }
+ long k = itor.getKey();
+ rows[i] = Primitives.getHigh(k);
+ cols[i] = Primitives.getLow(k);
+ data[i] = itor.getValue();
+ }
+
+ return MatrixUtils.coo2csr(rows, cols, data, numRows, numColumns, true);
+ }
+
+ @Override
+ public CSCFloatMatrix toColumnMajorMatrix() {
+ final int nnz = elements.size();
+ final int[] rows = new int[nnz];
+ final int[] cols = new int[nnz];
+ final float[] data = new float[nnz];
+
+ final IMapIterator itor = elements.entries();
+ for (int i = 0; i < nnz; i++) {
+ if (itor.next() == -1) {
+ throw new IllegalStateException("itor.next() returns -1 where i=" + i);
+ }
+ long k = itor.getKey();
+ rows[i] = Primitives.getHigh(k);
+ cols[i] = Primitives.getLow(k);
+ data[i] = itor.getValue();
+ }
+
+ return MatrixUtils.coo2csc(rows, cols, data, numRows, numColumns, true);
+ }
+
+ @Override
+ public DoKMatrixBuilder builder() {
+ return new DoKMatrixBuilder(elements.size());
+ }
+
+ @Nonnegative
+ private static long index(@Nonnegative final int row, @Nonnegative final int col) {
+ return Primitives.toLong(row, col);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/vector/AbstractVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/AbstractVector.java b/core/src/main/java/hivemall/math/vector/AbstractVector.java
index 88bed7b..7c4579f 100644
--- a/core/src/main/java/hivemall/math/vector/AbstractVector.java
+++ b/core/src/main/java/hivemall/math/vector/AbstractVector.java
@@ -29,6 +29,16 @@ public abstract class AbstractVector implements Vector {
return get(index, 0.d);
}
+    /**
+     * Float overload of {@code get}: widens {@code defaultValue} to double, delegates to the
+     * double overload, and narrows the result back to float.
+     */
+    @Override
+    public float get(@Nonnegative final int index, final float defaultValue) {
+        final double widenedDefault = defaultValue;
+        final double result = get(index, widenedDefault);
+        return (float) result;
+    }
+
+    /**
+     * Float overload of {@code set}: widens {@code value} to double and delegates to the double
+     * overload.
+     */
+    @Override
+    public void set(@Nonnegative final int index, final float value) {
+        final double widened = value;
+        set(index, widened);
+    }
+
protected static final void checkIndex(final int index) {
if (index < 0) {
throw new IndexOutOfBoundsException("Invalid index " + index);