You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/07/10 05:58:48 UTC

[incubator-hivemall] branch master updated: Refactor Matrix module for NNZ and zero value handling

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new 17fbfcd  Refactor Matrix module for NNZ and zero value handling
17fbfcd is described below

commit 17fbfcdb038941b96b1cffd7b9b9827bce246f4a
Author: Makoto Yui <my...@apache.org>
AuthorDate: Wed Jul 10 14:58:39 2019 +0900

    Refactor Matrix module for NNZ and zero value handling
    
    ## What changes were proposed in this pull request?
    
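    Refactor Matrix module for NNZ and zero value handling: DoK matrices now derive
    nnz() from the number of stored entries (explicitly set zeros included), matrix
    builders reject negative column indices, the CSR and dense builders update the
    column count even when they skip a zero value, and AbstractMatrix gains a
    toString() that prints the top-left 7x7 corner of the matrix.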
    
    ## What type of PR is it?
    
    Hot Fix, Refactoring
    
    ## What is the Jira issue?
    
    no JIRA issue
    
    ## How was this patch tested?
    
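    Unit tests (updated MatrixBuilderTest; new DoKMatrixTest cases covering nnz,
    row/column counts, and iteration over an empty matrix).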
    
    ## Checklist
    
    (Please remove this section if not needed; check `x` for YES, blank for NO)
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [ ] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <my...@apache.org>
    
    Closes #196 from myui/refactor_randomforest.
---
 .../java/hivemall/math/matrix/AbstractMatrix.java  | 24 ++++++++++++
 .../math/matrix/builders/CSCMatrixBuilder.java     |  4 +-
 .../math/matrix/builders/CSRMatrixBuilder.java     |  4 +-
 .../builders/ColumnMajorDenseMatrixBuilder.java    |  4 +-
 .../math/matrix/builders/DoKMatrixBuilder.java     |  2 +
 .../math/matrix/builders/MatrixBuilder.java        |  6 +++
 .../builders/RowMajorDenseMatrixBuilder.java       | 11 +++++-
 .../hivemall/math/matrix/sparse/CSCMatrix.java     |  2 +-
 .../hivemall/math/matrix/sparse/DoKMatrix.java     | 39 ++++++--------------
 .../math/matrix/sparse/floats/DoKFloatMatrix.java  | 37 +++++--------------
 .../hivemall/math/matrix/MatrixBuilderTest.java    | 43 ++++++++++++++++------
 .../hivemall/math/matrix/sparse/DoKMatrixTest.java | 40 ++++++++++++++++++++
 12 files changed, 145 insertions(+), 71 deletions(-)

diff --git a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
index 627ef9c..b9abd0f 100644
--- a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
@@ -109,4 +109,28 @@ public abstract class AbstractMatrix implements Matrix {
         throw new UnsupportedOperationException("Not yet supported");
     }
 
+    @Override
+    public String toString() {
+        final int printSize = 7;
+        final StringBuilder buf = new StringBuilder();
+
+        final int rows = numRows();
+        final int cols = numColumns();
+
+        final String newline = cols > printSize ? "...\n" : "\n";
+
+        for (int i = 0, maxRows = Math.min(printSize, rows); i < maxRows; i++) {
+            for (int j = 0, maxCols = Math.min(printSize, cols); j < maxCols; j++) {
+                buf.append(String.format("%8.4f  ", get(i, j)));
+            }
+            buf.append(newline);
+        }
+
+        if (rows > printSize) {
+            buf.append("  ...\n");
+        }
+
+        return buf.toString();
+    }
+
 }
diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
index 5c546d5..ffe0cba 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
@@ -56,9 +56,11 @@ public final class CSCMatrixBuilder extends MatrixBuilder {
 
     @Override
     public CSCMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+        checkColIndex(col);
+
         rows.add(row);
         cols.add(col);
-        values.add((float) value);
+        values.add(value);
         this.maxNumColumns = Math.max(col + 1, maxNumColumns);
         return this;
     }
diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
index 2467056..83be589 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
@@ -57,13 +57,15 @@ public final class CSRMatrixBuilder extends MatrixBuilder {
 
     @Override
     public CSRMatrixBuilder nextColumn(@Nonnegative int col, double value) {
+        checkColIndex(col);
+
+        this.maxNumColumns = Math.max(col + 1, maxNumColumns);
         if (value == 0.d) {
             return this;
         }
 
         columnIndices.add(col);
         values.add(value);
-        this.maxNumColumns = Math.max(col + 1, maxNumColumns);
         return this;
     }
 
diff --git a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
index b830219..9130efb 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
@@ -51,6 +51,9 @@ public final class ColumnMajorDenseMatrixBuilder extends MatrixBuilder {
     @Override
     public ColumnMajorDenseMatrixBuilder nextColumn(@Nonnegative final int col,
             final double value) {
+        checkColIndex(col);
+
+        this.maxNumColumns = Math.max(col + 1, maxNumColumns);
         if (value == 0.d) {
             return this;
         }
@@ -61,7 +64,6 @@ public final class ColumnMajorDenseMatrixBuilder extends MatrixBuilder {
             col2rows.put(col, rows);
         }
         rows.put(row, value);
-        this.maxNumColumns = Math.max(col + 1, maxNumColumns);
         nnz++;
         return this;
     }
diff --git a/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
index 556a8d8..f6e9781 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
@@ -44,6 +44,8 @@ public final class DoKMatrixBuilder extends MatrixBuilder {
 
     @Override
     public DoKMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+        checkColIndex(col);
+
         matrix.set(row, col, value);
         return this;
     }
diff --git a/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
index 7688086..5b10c43 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
@@ -27,6 +27,12 @@ public abstract class MatrixBuilder {
 
     public MatrixBuilder() {}
 
+    protected static final void checkColIndex(final int col) {
+        if (col < 0) {
+            throw new IllegalArgumentException("Found negative column index: " + col);
+        }
+    }
+
     public void nextRow(@Nonnull final double[] row) {
         for (int col = 0; col < row.length; col++) {
             nextColumn(col, row[col]);
diff --git a/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
index b6d0588..5b32101 100644
--- a/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
+++ b/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
@@ -47,6 +47,9 @@ public final class RowMajorDenseMatrixBuilder extends MatrixBuilder {
 
     @Override
     public RowMajorDenseMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+        checkColIndex(col);
+
+        this.maxNumColumns = Math.max(col + 1, maxNumColumns);
         if (value == 0.d) {
             return this;
         }
@@ -59,12 +62,18 @@ public final class RowMajorDenseMatrixBuilder extends MatrixBuilder {
     public RowMajorDenseMatrixBuilder nextRow() {
         double[] row = rowProbe.toArray();
         rowProbe.clear();
-        nextRow(row);
+        rows.add(row);
+        //this.maxNumColumns = Math.max(row.length, maxNumColumns);
         return this;
     }
 
     @Override
     public void nextRow(@Nonnull double[] row) {
+        for (double v : row) {
+            if (v != 0.d) {
+                nnz++;
+            }
+        }
         rows.add(row);
         this.maxNumColumns = Math.max(row.length, maxNumColumns);
     }
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
index 317a0f1..4ea4ad5 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
@@ -158,7 +158,7 @@ public final class CSCMatrix extends ColumnMajorMatrix {
     public double get(final int row, final int col, final double defaultValue) {
         checkIndex(row, col, numRows, numColumns);
 
-        int index = getIndex(row, col);
+        final int index = getIndex(row, col);
         if (index < 0) {
             return defaultValue;
         }
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
index 8ac0cf9..30c107d 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
@@ -35,7 +35,7 @@ import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 
 /**
- * Dictionary Of Keys based sparse matrix.
+ * Dictionary of Keys based sparse matrix.
  *
  * This is an efficient structure for constructing a sparse matrix incrementally.
  */
@@ -48,8 +48,6 @@ public final class DoKMatrix extends AbstractMatrix {
     private int numRows;
     @Nonnegative
     private int numColumns;
-    @Nonnegative
-    private int nnz;
 
     public DoKMatrix() {
         this(0, 0);
@@ -69,7 +67,6 @@ public final class DoKMatrix extends AbstractMatrix {
         elements.defaultReturnValue(0.d);
         this.numRows = numRows;
         this.numColumns = numCols;
-        this.nnz = 0;
     }
 
     public DoKMatrix(@Nonnegative int initSize) {
@@ -79,7 +76,6 @@ public final class DoKMatrix extends AbstractMatrix {
         elements.defaultReturnValue(0.d);
         this.numRows = 0;
         this.numColumns = 0;
-        this.nnz = 0;
     }
 
     @Override
@@ -109,7 +105,7 @@ public final class DoKMatrix extends AbstractMatrix {
 
     @Override
     public int nnz() {
-        return nnz;
+        return elements.size();
     }
 
     @Override
@@ -179,16 +175,10 @@ public final class DoKMatrix extends AbstractMatrix {
     public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
         checkIndex(row, col);
 
-        final long index = index(row, col);
-        if (value == 0.d && elements.containsKey(index) == false) {
-            return;
-        }
-
-        if (elements.put(index, value, 0.d) == 0.d) {
-            nnz++;
-            this.numRows = Math.max(numRows, row + 1);
-            this.numColumns = Math.max(numColumns, col + 1);
-        }
+        long index = index(row, col);
+        elements.put(index, value);
+        this.numRows = Math.max(numRows, row + 1);
+        this.numColumns = Math.max(numColumns, col + 1);
     }
 
     @Override
@@ -196,17 +186,10 @@ public final class DoKMatrix extends AbstractMatrix {
             final double value) {
         checkIndex(row, col);
 
-        final long index = index(row, col);
-        if (value == 0.d && elements.containsKey(index) == false) {
-            return 0.d;
-        }
-
-        final double old = elements.put(index, value, 0.d);
-        if (old == 0.d) {
-            nnz++;
-            this.numRows = Math.max(numRows, row + 1);
-            this.numColumns = Math.max(numColumns, col + 1);
-        }
+        long index = index(row, col);
+        double old = elements.put(index, value);
+        this.numRows = Math.max(numRows, row + 1);
+        this.numColumns = Math.max(numColumns, col + 1);
         return old;
     }
 
@@ -320,7 +303,7 @@ public final class DoKMatrix extends AbstractMatrix {
     }
 
     public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
-        if (nnz == 0) {
+        if (elements.size() == 0) {
             return;
         }
         final IMapIterator itor = elements.entries();
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
index 36b8d7a..9153566 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java
@@ -47,8 +47,6 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
     private int numRows;
     @Nonnegative
     private int numColumns;
-    @Nonnegative
-    private int nnz;
 
     public DoKFloatMatrix() {
         this(0, 0);
@@ -68,7 +66,6 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
         elements.defaultReturnValue(0.f);
         this.numRows = numRows;
         this.numColumns = numCols;
-        this.nnz = 0;
     }
 
     public DoKFloatMatrix(@Nonnegative int initSize) {
@@ -78,7 +75,6 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
         elements.defaultReturnValue(0.f);
         this.numRows = 0;
         this.numColumns = 0;
-        this.nnz = 0;
     }
 
     @Override
@@ -108,7 +104,7 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
 
     @Override
     public int nnz() {
-        return nnz;
+        return elements.size();
     }
 
     @Override
@@ -192,16 +188,10 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
     public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) {
         checkIndex(row, col);
 
-        final long index = index(row, col);
-        if (value == 0.f && elements.containsKey(index) == false) {
-            return;
-        }
-
-        if (elements.put(index, value, 0.f) == 0.f) {
-            nnz++;
-            this.numRows = Math.max(numRows, row + 1);
-            this.numColumns = Math.max(numColumns, col + 1);
-        }
+        long index = index(row, col);
+        elements.put(index, value);
+        this.numRows = Math.max(numRows, row + 1);
+        this.numColumns = Math.max(numColumns, col + 1);
     }
 
     @Override
@@ -209,17 +199,10 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
             final float value) {
         checkIndex(row, col);
 
-        final long index = index(row, col);
-        if (value == 0.f && elements.containsKey(index) == false) {
-            return 0.f;
-        }
-
-        final float old = elements.put(index, value, 0.f);
-        if (old == 0.f) {
-            nnz++;
-            this.numRows = Math.max(numRows, row + 1);
-            this.numColumns = Math.max(numColumns, col + 1);
-        }
+        long index = index(row, col);
+        float old = elements.put(index, value);
+        this.numRows = Math.max(numRows, row + 1);
+        this.numColumns = Math.max(numColumns, col + 1);
         return old;
     }
 
@@ -334,7 +317,7 @@ public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix
 
     @Override
     public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
-        if (nnz == 0) {
+        if (elements.size() == 0) {
             return;
         }
         final IMapIterator itor = elements.entries();
diff --git a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
index af3f024..0cf65c0 100644
--- a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
+++ b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
@@ -18,8 +18,6 @@
  */
 package hivemall.math.matrix;
 
-import hivemall.math.matrix.Matrix;
-import hivemall.math.matrix.RowMajorMatrix;
 import hivemall.math.matrix.builders.CSCMatrixBuilder;
 import hivemall.math.matrix.builders.CSRMatrixBuilder;
 import hivemall.math.matrix.builders.ColumnMajorDenseMatrixBuilder;
@@ -159,6 +157,8 @@ public class MatrixBuilderTest {
     public void testCSC2CSR() {
         CSCMatrix csc = cscMatrixFromLibSVM();
         RowMajorMatrix csr = csc.toRowMajorMatrix();
+        Assert.assertEquals(csc.toString(), csr.toString());
+
         Assert.assertTrue(csr instanceof CSRMatrix);
         Assert.assertEquals(6, csr.numRows());
         Assert.assertEquals(6, csr.numColumns());
@@ -295,14 +295,15 @@ public class MatrixBuilderTest {
     @Test
     public void testReadOnlyDenseMatrix2dSparseInput() {
         Matrix matrix = denseMatrixSparseInput();
-        Assert.assertEquals(6, matrix.numRows());
-        Assert.assertEquals(6, matrix.numColumns());
+        Assert.assertEquals(7, matrix.numRows());
+        Assert.assertEquals(7, matrix.numColumns());
         Assert.assertEquals(4, matrix.numColumns(0));
         Assert.assertEquals(3, matrix.numColumns(1));
         Assert.assertEquals(6, matrix.numColumns(2));
         Assert.assertEquals(5, matrix.numColumns(3));
         Assert.assertEquals(6, matrix.numColumns(4));
         Assert.assertEquals(6, matrix.numColumns(5));
+        Assert.assertEquals(2, matrix.numColumns(6));
 
         Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
         Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
@@ -318,12 +319,29 @@ public class MatrixBuilderTest {
         Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
         Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
         Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+        Assert.assertEquals(77d, matrix.get(6, 1), 0.d);
 
         Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
-
         Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
         Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
         Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+        Assert.assertEquals(0.d, matrix.get(6, 6), 0.d);
+    }
+
+    @Test(expected = IndexOutOfBoundsException.class)
+    public void testReadOnlyDenseMatrix2dSparseColOutOfBounds() {
+        Matrix matrix = denseMatrixSparseInput();
+        Assert.assertEquals(7, matrix.numRows());
+        Assert.assertEquals(7, matrix.numColumns());
+        matrix.get(6, 7);
+    }
+
+    @Test(expected = IndexOutOfBoundsException.class)
+    public void testReadOnlyDenseMatrix2dSparseRowOutOfBounds() {
+        Matrix matrix = denseMatrixSparseInput();
+        Assert.assertEquals(7, matrix.numRows());
+        Assert.assertEquals(7, matrix.numColumns());
+        matrix.get(7, 6);
     }
 
     @Test
@@ -418,6 +436,7 @@ public class MatrixBuilderTest {
     public void testDenseMatrixColumnMajor2RowMajor() {
         ColumnMajorDenseMatrix2d colMatrix = columnMajorDenseMatrix();
         RowMajorDenseMatrix2d rowMatrix = colMatrix.toRowMajorMatrix();
+        Assert.assertEquals(colMatrix.toString(), rowMatrix.toString());
 
         Assert.assertEquals(6, rowMatrix.numRows());
         Assert.assertEquals(6, rowMatrix.numColumns());
@@ -612,12 +631,13 @@ public class MatrixBuilderTest {
 
     private static RowMajorDenseMatrix2d denseMatrixSparseInput() {
         /*
-        11  12  13  14  0   0
-        0   22  23  0   0   0
-        0   0   33  34  35  36
-        0   0   0   44  45  0
-        0   0   0   0   0   56
-        0   0   0   0   0   66
+        11  12  13  14  0   0   0
+        0   22  23  0   0   0   0
+        0   0   33  34  35  36  0
+        0   0   0   44  45  0   0
+        0   0   0   0   0   56  0
+        0   0   0   0   0   66  0
+        0   77  0   0   0   0   0
         */
         RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024);
         builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
@@ -626,6 +646,7 @@ public class MatrixBuilderTest {
         builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
         builder.nextColumn(5, 56).nextRow();
         builder.nextColumn(5, 66).nextRow();
+        builder.nextColumn(1, 77).nextColumn(6, 0).nextRow();
         return builder.buildMatrix();
     }
 
diff --git a/core/src/test/java/hivemall/math/matrix/sparse/DoKMatrixTest.java b/core/src/test/java/hivemall/math/matrix/sparse/DoKMatrixTest.java
index be6b424..0ae0a8c 100644
--- a/core/src/test/java/hivemall/math/matrix/sparse/DoKMatrixTest.java
+++ b/core/src/test/java/hivemall/math/matrix/sparse/DoKMatrixTest.java
@@ -18,6 +18,10 @@
  */
 package hivemall.math.matrix.sparse;
 
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.Primitives;
+
+import java.util.HashSet;
 import java.util.Random;
 
 import org.junit.Assert;
@@ -37,7 +41,43 @@ public class DoKMatrixTest {
             matrix.set(row, col, v);
             Assert.assertEquals(v, matrix.get(row, col), 0.00001d);
         }
+    }
+
+    @Test
+    public void testNumRowsNumCols() {
+        DoKMatrix matrix = new DoKMatrix();
+        Random rnd = new Random(43);
+        HashSet<Long> bitset = new HashSet<>(100000);
+
+        int numRows = -1, numCols = -1;
+        for (int i = 0; i < 100000; i++) {
+            int row = Math.abs(rnd.nextInt());
+            int col = Math.abs(rnd.nextInt());
+            numRows = Math.max(row + 1, numRows);
+            numCols = Math.max(col + 1, numCols);
+            double v = rnd.nextDouble();
+            if (v >= 0.8) {
+                v = 0.d;
+            }
+            matrix.getAndSet(row, col, v);
+            bitset.add(Primitives.toLong(row, col));
+            Assert.assertEquals(v, matrix.get(row, col), 0.00001d);
+        }
 
+        Assert.assertEquals(numRows, matrix.numRows());
+        Assert.assertEquals(numCols, matrix.numColumns());
+        Assert.assertEquals(bitset.size(), matrix.nnz());
+    }
+
+    @Test
+    public void testEmpty() {
+        DoKMatrix matrix = new DoKMatrix();
+        matrix.eachNonZeroCell(new VectorProcedure() {
+            @Override
+            public void apply(int i, int j, double value) {
+                Assert.fail("should not be called");
+            }
+        });
     }
 
 }