You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/10/31 10:58:32 UTC

[incubator-hivemall] 02/03: Added SparseDMatrixBuilder

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git

commit 3985bbf1b74ea0414850b607d1ad9e205db6757c
Author: Makoto Yui <my...@apache.org>
AuthorDate: Thu Oct 31 19:17:54 2019 +0900

    Added SparseDMatrixBuilder
---
 .../xgboost/utils/SparseDMatrixBuilder.java        | 79 ++++++++++++++++++++++
 1 file changed, 79 insertions(+)

diff --git a/xgboost/src/main/java/hivemall/xgboost/utils/SparseDMatrixBuilder.java b/xgboost/src/main/java/hivemall/xgboost/utils/SparseDMatrixBuilder.java
new file mode 100644
index 0000000..6f8e739
--- /dev/null
+++ b/xgboost/src/main/java/hivemall/xgboost/utils/SparseDMatrixBuilder.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.xgboost.utils;
+
+import hivemall.utils.collections.lists.FloatArrayList;
+import hivemall.utils.collections.lists.LongArrayList;
+import matrix4j.utils.collections.lists.IntArrayList;
+import ml.dmlc.xgboost4j.java.DMatrix;
+import ml.dmlc.xgboost4j.java.XGBoostError;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Incrementally collects the three arrays of a sparse matrix (row pointers, column indices
+ * and values) and passes them to {@link DMatrix} as {@link DMatrix.SparseType#CSR}.
+ *
+ * <p>The row-pointer list is seeded with {@code 0} on construction and {@link #nextRow()}
+ * appends the number of values stored so far, so {@code nextRow()} is presumably meant to be
+ * invoked once after the columns of each row have been added (TODO confirm against callers).
+ */
+public class SparseDMatrixBuilder {
+
+    /** Row pointers; starts with a single {@code 0} entry. */
+    @Nonnull
+    private final LongArrayList rowPointers;
+    /** Column index of every stored value, in insertion order. */
+    @Nonnull
+    private final IntArrayList columnIndices;
+    /** Stored values, parallel to {@link #columnIndices}. */
+    @Nonnull
+    private final FloatArrayList values;
+
+    /** Largest {@code col + 1} seen so far across all calls to {@code nextColumn}. */
+    private int maxNumColumns;
+
+    /**
+     * Creates an empty builder.
+     *
+     * @param initSize initial size handed to the backing lists (the row-pointer list gets
+     *        {@code initSize + 1})
+     */
+    public SparseDMatrixBuilder(@Nonnegative int initSize) {
+        this.maxNumColumns = 0;
+        this.values = new FloatArrayList(initSize);
+        this.columnIndices = new IntArrayList(initSize);
+        this.rowPointers = new LongArrayList(initSize + 1);
+        // The first row always begins at offset 0.
+        rowPointers.add(0);
+    }
+
+    /**
+     * Records a row boundary by appending the current number of stored values to the row
+     * pointers.
+     *
+     * @return this builder, for chaining
+     */
+    public SparseDMatrixBuilder nextRow() {
+        rowPointers.add(values.size());
+        return this;
+    }
+
+    /**
+     * Adds one entry to the current row. A value equal to zero is not stored, but its column
+     * index still counts toward the tracked number of columns.
+     *
+     * @param col zero-based column index
+     * @param value the entry's value
+     * @return this builder, for chaining
+     * @throws IllegalArgumentException if {@code col} is negative
+     */
+    public SparseDMatrixBuilder nextColumn(@Nonnegative int col, float value) {
+        if (col < 0) {
+            throw new IllegalArgumentException("Found negative column index: " + col);
+        }
+
+        // Track the column count before the zero check so skipped entries still widen it.
+        final int numColumns = col + 1;
+        if (numColumns > maxNumColumns) {
+            this.maxNumColumns = numColumns;
+        }
+
+        if (value != 0.0f) {
+            columnIndices.add(col);
+            values.add(value);
+        }
+        return this;
+    }
+
+    /**
+     * Builds a {@link DMatrix} of type {@link DMatrix.SparseType#CSR} from the collected row
+     * pointers, column indices and values, passing the tracked column count as the last
+     * constructor argument.
+     *
+     * <p>NOTE(review): each list is converted with {@code toArray(true)}; the effect of that
+     * flag is defined by the list classes and is not visible here, so whether the builder
+     * remains usable afterwards should be confirmed.
+     *
+     * @return the newly constructed matrix
+     * @throws XGBoostError if thrown by the {@link DMatrix} constructor
+     */
+    @Nonnull
+    public DMatrix buildMatrix() throws XGBoostError {
+        final long[] rowHeaders = rowPointers.toArray(true);
+        final int[] colIndices = columnIndices.toArray(true);
+        final float[] data = values.toArray(true);
+        return new DMatrix(rowHeaders, colIndices, data, DMatrix.SparseType.CSR, maxNumColumns);
+    }
+}