You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/07/10 05:58:48 UTC
[incubator-hivemall] branch master updated: Refactor Matrix module for NNZ and zero value handling
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new 17fbfcd Refactor Matrix module for NNZ and zero value handling
17fbfcd is described below
commit 17fbfcdb038941b96b1cffd7b9b9827bce246f4a
Author: Makoto Yui <my...@apache.org>
AuthorDate: Wed Jul 10 14:58:39 2019 +0900
Refactor Matrix module for NNZ and zero value handling
## What changes were proposed in this pull request?
Refactor Matrix module for NNZ and zero value handling.
## What type of PR is it?
Hot Fix, Refactoring
## What is the Jira issue?
no JIRA issue
## How was this patch tested?
Unit tests
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [ ] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #196 from myui/refactor_randomforest.
---
.../java/hivemall/math/matrix/AbstractMatrix.java | 24 ++++++++++++
.../math/matrix/builders/CSCMatrixBuilder.java | 4 +-
.../math/matrix/builders/CSRMatrixBuilder.java | 4 +-
.../builders/ColumnMajorDenseMatrixBuilder.java | 4 +-
.../math/matrix/builders/DoKMatrixBuilder.java | 2 +
.../math/matrix/builders/MatrixBuilder.java | 6 +++
.../builders/RowMajorDenseMatrixBuilder.java | 11 +++++-
.../hivemall/math/matrix/sparse/CSCMatrix.java | 2 +-
.../hivemall/math/matrix/sparse/DoKMatrix.java | 39 ++++++--------------
.../math/matrix/sparse/floats/DoKFloatMatrix.java | 37 +++++--------------
.../hivemall/math/matrix/MatrixBuilderTest.java | 43 ++++++++++++++++------
.../hivemall/math/matrix/sparse/DoKMatrixTest.java | 40 ++++++++++++++++++++
12 files changed, 145 insertions(+), 71 deletions(-)
diff --git a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
index 627ef9c..b9abd0f 100644
--- a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
@@ -109,4 +109,28 @@ public abstract class AbstractMatrix implements Matrix {
throw new UnsupportedOperationException("Not yet supported");
}
+ @Override
+ public String toString() {
+ final int printSize = 7;
+ final StringBuilder buf = new StringBuilder();
+
+ final int rows = numRows();
+ final int cols = numColumns();
+
+ final String newline = cols > printSize ? "...\n" : "\n";
+
+ for (int i = 0, maxRows = Math.min(printSize, rows); i < maxRows; i++) {
+ for (int j = 0, maxCols = Math.min(printSize, cols); j < maxCols; j++) {
+ buf.append(String.format("%8.4f ", get(i, j)));
+ }
+ buf.append(newline);
+ }
+
+ if (rows > printSize) {
+ buf.append(" ...\n");
+ }
+
+ return buf.toString();
+ }
+
}
diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
index 5c546d5..ffe0cba 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
@@ -56,9 +56,11 @@ public final class CSCMatrixBuilder extends MatrixBuilder {
@Override
public CSCMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+ checkColIndex(col);
+
rows.add(row);
cols.add(col);
- values.add((float) value);
+ values.add(value);
this.maxNumColumns = Math.max(col + 1, maxNumColumns);
return this;
}
diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
index 2467056..83be589 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
@@ -57,13 +57,15 @@ public final class CSRMatrixBuilder extends MatrixBuilder {
@Override
public CSRMatrixBuilder nextColumn(@Nonnegative int col, double value) {
+ checkColIndex(col);
+
+ this.maxNumColumns = Math.max(col + 1, maxNumColumns);
if (value == 0.d) {
return this;
}
columnIndices.add(col);
values.add(value);
- this.maxNumColumns = Math.max(col + 1, maxNumColumns);
return this;
}
diff --git a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
index b830219..9130efb 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
@@ -51,6 +51,9 @@ public final class ColumnMajorDenseMatrixBuilder extends MatrixBuilder {
@Override
public ColumnMajorDenseMatrixBuilder nextColumn(@Nonnegative final int col,
final double value) {
+ checkColIndex(col);
+
+ this.maxNumColumns = Math.max(col + 1, maxNumColumns);
if (value == 0.d) {
return this;
}
@@ -61,7 +64,6 @@ public final class ColumnMajorDenseMatrixBuilder extends MatrixBuilder {
col2rows.put(col, rows);
}
rows.put(row, value);
- this.maxNumColumns = Math.max(col + 1, maxNumColumns);
nnz++;
return this;
}
diff --git a/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
index 556a8d8..f6e9781 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
@@ -44,6 +44,8 @@ public final class DoKMatrixBuilder extends MatrixBuilder {
@Override
public DoKMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+ checkColIndex(col);
+
matrix.set(row, col, value);
return this;
}
diff --git a/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
index 7688086..5b10c43 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
@@ -27,6 +27,12 @@ public abstract class MatrixBuilder {
public MatrixBuilder() {}
+ protected static final void checkColIndex(final int col) {
+ if (col < 0) {
+ throw new IllegalArgumentException("Found negative column index: " + col);
+ }
+ }
+
public void nextRow(@Nonnull final double[] row) {
for (int col = 0; col < row.length; col++) {
nextColumn(col, row[col]);
diff --git a/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
index b6d0588..5b32101 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
@@ -47,6 +47,9 @@ public final class RowMajorDenseMatrixBuilder extends MatrixBuilder {
@Override
public RowMajorDenseMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+ checkColIndex(col);
+
+ this.maxNumColumns = Math.max(col + 1, maxNumColumns);
if (value == 0.d) {
return this;
}
@@ -59,12 +62,18 @@ public final class RowMajorDenseMatrixBuilder extends MatrixBuilder {
public RowMajorDenseMatrixBuilder nextRow() {
double[] row = rowProbe.toArray();
rowProbe.clear();
- nextRow(row);
+ rows.add(row);
+ //this.maxNumColumns = Math.max(row.length, maxNumColumns);
return this;
}
@Override
public void nextRow(@Nonnull double[] row) {
+ for (double v : row) {
+ if (v != 0.d) {
+ nnz++;
+ }
+ }
rows.add(row);
this.maxNumColumns = Math.max(row.length, maxNumColumns);
}
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
index 317a0f1..4ea4ad5 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
@@ -158,7 +158,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
public double get(final int row, final int col, final double defaultValue) {
checkIndex(row, col, numRows, numColumns);
- int index = getIndex(row, col);
+ final int index = getIndex(row, col);
if (index < 0) {
return defaultValue;
}
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
index 8ac0cf9..30c107d 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
@@ -35,7 +35,7 @@ import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
- * Dictionary Of Keys based sparse matrix.
+ * Dictionary of Keys based sparse matrix.
*
* This is an efficient structure for constructing a sparse matrix incrementally.
*/
@@ -48,8 +48,6 @@ public final class DoKMatrix extends AbstractMatrix {
private int numRows;
@Nonnegative
private int numColumns;
- @Nonnegative
- private int nnz;
public DoKMatrix() {
this(0, 0);
@@ -69,7 +67,6 @@ public final class DoKMatrix extends AbstractMatrix {
elements.defaultReturnValue(0.d);
this.numRows = numRows;
this.numColumns = numCols;
- this.nnz = 0;
}
public DoKMatrix(@Nonnegative int initSize) {
@@ -79,7 +76,6 @@ public final class DoKMatrix extends AbstractMatrix {
elements.defaultReturnValue(0.d);
this.numRows = 0;
this.numColumns = 0;
- this.nnz = 0;
}
@Override
@@ -109,7 +105,7 @@ public final class DoKMatrix extends AbstractMatrix {
@Override
public int nnz() {
- return nnz;
+ return elements.size();
}
@Override
@@ -179,16 +175,10 @@ public final class DoKMatrix extends AbstractMatrix {
public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
checkIndex(row, col);
- final long index = index(row, col);
- if (value == 0.d && elements.containsKey(index) == false) {
- return;
- }
-
- if (elements.put(index, value, 0.d) == 0.d) {
- nnz++;
- this.numRows = Math.max(numRows, row + 1);
- this.numColumns = Math.max(numColumns, col + 1);
- }
+ long index = index(row, col);
+ elements.put(index, value);
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
}
@Override
@@ -196,17 +186,10 @@ public final class DoKMatrix extends AbstractMatrix {
final double value) {
checkIndex(row, col);
- final long index = index(row, col);
- if (value == 0.d && elements.containsKey(index) == false) {
- return 0.d;
- }
-
- final double old = elements.put(index, value, 0.d);
- if (old == 0.d) {
- nnz++;
- this.numRows = Math.max(numRows, row + 1);
- this.numColumns = Math.max(numColumns, col + 1);
- }
+ long index = index(row, col);
+ double old = elements.put(index, value);
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
return old;
}
@@ -320,7 +303,7 @@ public final class DoKMatrix extends AbstractMatrix {
}
public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
- if (nnz == 0) {
+ if (elements.size() == 0) {
return;
}
final IMapIterator itor = elements.entries();
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
index 36b8d7a..9153566 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
@@ -47,8 +47,6 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
private int numRows;
@Nonnegative
private int numColumns;
- @Nonnegative
- private int nnz;
public DoKFloatMatrix() {
this(0, 0);
@@ -68,7 +66,6 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
elements.defaultReturnValue(0.f);
this.numRows = numRows;
this.numColumns = numCols;
- this.nnz = 0;
}
public DoKFloatMatrix(@Nonnegative int initSize) {
@@ -78,7 +75,6 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
elements.defaultReturnValue(0.f);
this.numRows = 0;
this.numColumns = 0;
- this.nnz = 0;
}
@Override
@@ -108,7 +104,7 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
@Override
public int nnz() {
- return nnz;
+ return elements.size();
}
@Override
@@ -192,16 +188,10 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) {
checkIndex(row, col);
- final long index = index(row, col);
- if (value == 0.f && elements.containsKey(index) == false) {
- return;
- }
-
- if (elements.put(index, value, 0.f) == 0.f) {
- nnz++;
- this.numRows = Math.max(numRows, row + 1);
- this.numColumns = Math.max(numColumns, col + 1);
- }
+ long index = index(row, col);
+ elements.put(index, value);
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
}
@Override
@@ -209,17 +199,10 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
final float value) {
checkIndex(row, col);
- final long index = index(row, col);
- if (value == 0.f && elements.containsKey(index) == false) {
- return 0.f;
- }
-
- final float old = elements.put(index, value, 0.f);
- if (old == 0.f) {
- nnz++;
- this.numRows = Math.max(numRows, row + 1);
- this.numColumns = Math.max(numColumns, col + 1);
- }
+ long index = index(row, col);
+ float old = elements.put(index, value);
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
return old;
}
@@ -334,7 +317,7 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
@Override
public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
- if (nnz == 0) {
+ if (elements.size() == 0) {
return;
}
final IMapIterator itor = elements.entries();
diff --git a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
index af3f024..0cf65c0 100644
--- a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
+++ b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
@@ -18,8 +18,6 @@
*/
package hivemall.math.matrix;
-import hivemall.math.matrix.Matrix;
-import hivemall.math.matrix.RowMajorMatrix;
import hivemall.math.matrix.builders.CSCMatrixBuilder;
import hivemall.math.matrix.builders.CSRMatrixBuilder;
import hivemall.math.matrix.builders.ColumnMajorDenseMatrixBuilder;
@@ -159,6 +157,8 @@ public class MatrixBuilderTest {
public void testCSC2CSR() {
CSCMatrix csc = cscMatrixFromLibSVM();
RowMajorMatrix csr = csc.toRowMajorMatrix();
+ Assert.assertEquals(csc.toString(), csr.toString());
+
Assert.assertTrue(csr instanceof CSRMatrix);
Assert.assertEquals(6, csr.numRows());
Assert.assertEquals(6, csr.numColumns());
@@ -295,14 +295,15 @@ public class MatrixBuilderTest {
@Test
public void testReadOnlyDenseMatrix2dSparseInput() {
Matrix matrix = denseMatrixSparseInput();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(7, matrix.numRows());
+ Assert.assertEquals(7, matrix.numColumns());
Assert.assertEquals(4, matrix.numColumns(0));
Assert.assertEquals(3, matrix.numColumns(1));
Assert.assertEquals(6, matrix.numColumns(2));
Assert.assertEquals(5, matrix.numColumns(3));
Assert.assertEquals(6, matrix.numColumns(4));
Assert.assertEquals(6, matrix.numColumns(5));
+ Assert.assertEquals(2, matrix.numColumns(6));
Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
@@ -318,12 +319,29 @@ public class MatrixBuilderTest {
Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+ Assert.assertEquals(77d, matrix.get(6, 1), 0.d);
Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
-
Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, matrix.get(6, 6), 0.d);
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyDenseMatrix2dSparseColOutOfBounds() {
+ Matrix matrix = denseMatrixSparseInput();
+ Assert.assertEquals(7, matrix.numRows());
+ Assert.assertEquals(7, matrix.numColumns());
+ matrix.get(6, 7);
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyDenseMatrix2dSparseRowOutOfBounds() {
+ Matrix matrix = denseMatrixSparseInput();
+ Assert.assertEquals(7, matrix.numRows());
+ Assert.assertEquals(7, matrix.numColumns());
+ matrix.get(7, 6);
}
@Test
@@ -418,6 +436,7 @@ public class MatrixBuilderTest {
public void testDenseMatrixColumnMajor2RowMajor() {
ColumnMajorDenseMatrix2d colMatrix = columnMajorDenseMatrix();
RowMajorDenseMatrix2d rowMatrix = colMatrix.toRowMajorMatrix();
+ Assert.assertEquals(colMatrix.toString(), rowMatrix.toString());
Assert.assertEquals(6, rowMatrix.numRows());
Assert.assertEquals(6, rowMatrix.numColumns());
@@ -612,12 +631,13 @@ public class MatrixBuilderTest {
private static RowMajorDenseMatrix2d denseMatrixSparseInput() {
/*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
+ 11 12 13 14 0 0 0
+ 0 22 23 0 0 0 0
+ 0 0 33 34 35 36 0
+ 0 0 0 44 45 0 0
+ 0 0 0 0 0 56 0
+ 0 0 0 0 0 66 0
+ 0 77 0 0 0 0 0
*/
RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024);
builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
@@ -626,6 +646,7 @@ public class MatrixBuilderTest {
builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
builder.nextColumn(5, 56).nextRow();
builder.nextColumn(5, 66).nextRow();
+ builder.nextColumn(1, 77).nextColumn(6, 0).nextRow();
return builder.buildMatrix();
}
diff --git a/core/src/test/java/hivemall/math/matrix/sparse/DoKMatrixTest.java b/core/src/test/java/hivemall/math/matrix/sparse/DoKMatrixTest.java
index be6b424..0ae0a8c 100644
--- a/core/src/test/java/hivemall/math/matrix/sparse/DoKMatrixTest.java
+++ b/core/src/test/java/hivemall/math/matrix/sparse/DoKMatrixTest.java
@@ -18,6 +18,10 @@
*/
package hivemall.math.matrix.sparse;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.Primitives;
+
+import java.util.HashSet;
import java.util.Random;
import org.junit.Assert;
@@ -37,7 +41,43 @@ public class DoKMatrixTest {
matrix.set(row, col, v);
Assert.assertEquals(v, matrix.get(row, col), 0.00001d);
}
+ }
+
+ @Test
+ public void testNumRowsNumCols() {
+ DoKMatrix matrix = new DoKMatrix();
+ Random rnd = new Random(43);
+ HashSet<Long> bitset = new HashSet<>(100000);
+
+ int numRows = -1, numCols = -1;
+ for (int i = 0; i < 100000; i++) {
+ int row = Math.abs(rnd.nextInt());
+ int col = Math.abs(rnd.nextInt());
+ numRows = Math.max(row + 1, numRows);
+ numCols = Math.max(col + 1, numCols);
+ double v = rnd.nextDouble();
+ if (v >= 0.8) {
+ v = 0.d;
+ }
+ matrix.getAndSet(row, col, v);
+ bitset.add(Primitives.toLong(row, col));
+ Assert.assertEquals(v, matrix.get(row, col), 0.00001d);
+ }
+ Assert.assertEquals(numRows, matrix.numRows());
+ Assert.assertEquals(numCols, matrix.numColumns());
+ Assert.assertEquals(bitset.size(), matrix.nnz());
+ }
+
+ @Test
+ public void testEmpty() {
+ DoKMatrix matrix = new DoKMatrix();
+ matrix.eachNonZeroCell(new VectorProcedure() {
+ @Override
+ public void apply(int i, int j, double value) {
+ Assert.fail("should not be called");
+ }
+ });
}
}