Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/09 21:32:13 UTC
[01/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75] Support Sparse Vector Format as the input of RandomForest
Repository: incubator-hivemall
Updated Branches:
refs/heads/master 7956b5f28 -> 8dc3a024d
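For context, the sparse vector format this series enables as RandomForest input is Hivemall's LIBSVM-like "index:value" encoding, where indices that are not listed are implicitly zero. A minimal, hypothetical sketch of that encoding (the class and parsing below are illustrative only, not code from this commit):

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical helper: parses features such as {"1:0.5", "7:2.0"}
    // into an index -> value map; absent indices are treated as zero.
    public final class SparseFeatureSketch {
        static Map<Integer, Double> parse(String[] features) {
            Map<Integer, Double> vec = new HashMap<Integer, Double>(features.length);
            for (String f : features) {
                int pos = f.indexOf(':');
                vec.put(Integer.valueOf(f.substring(0, pos)),
                    Double.valueOf(f.substring(pos + 1)));
            }
            return vec;
        }
    }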
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java
new file mode 100644
index 0000000..7951b0b
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.maps.Int2LongOpenHashTable;
+import hivemall.utils.lang.ObjectUtils;
+
+import java.io.IOException;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class Int2LongOpenHashMapTest {
+
+ @Test
+ public void testSize() {
+ Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
+ map.put(1, 3L);
+ Assert.assertEquals(3L, map.get(1));
+ map.put(1, 5L);
+ Assert.assertEquals(5L, map.get(1));
+ Assert.assertEquals(1, map.size());
+ }
+
+ @Test
+ public void testDefaultReturnValue() {
+ Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
+ Assert.assertEquals(0, map.size());
+ Assert.assertEquals(-1L, map.get(1));
+ long ret = Long.MIN_VALUE;
+ map.defaultReturnValue(ret);
+ Assert.assertEquals(ret, map.get(1));
+ }
+
+ @Test
+ public void testPutAndGet() {
+ Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1L, map.put(i, i));
+ }
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ long v = map.get(i);
+ Assert.assertEquals(i, v);
+ }
+ }
+
+ @Test
+ public void testSerde() throws IOException, ClassNotFoundException {
+ Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1L, map.put(i, i));
+ }
+
+ byte[] b = ObjectUtils.toCompressedBytes(map);
+ map = new Int2LongOpenHashTable(16384);
+ ObjectUtils.readCompressedObject(b, map);
+
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ long v = map.get(i);
+ Assert.assertEquals(i, v);
+ }
+ }
+
+ @Test
+ public void testIterator() {
+ Int2LongOpenHashTable map = new Int2LongOpenHashTable(1000);
+ Int2LongOpenHashTable.IMapIterator itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1L, map.put(i, i));
+ }
+ Assert.assertEquals(numEntries, map.size());
+
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ int k = itor.getKey();
+ long v = itor.getValue();
+ Assert.assertEquals(k, v);
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashMapTest.java
new file mode 100644
index 0000000..675c586
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashMapTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.maps.IntOpenHashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class IntOpenHashMapTest {
+
+ @Test
+ public void testSize() {
+ IntOpenHashMap<Float> map = new IntOpenHashMap<Float>(16384);
+ map.put(1, Float.valueOf(3.f));
+ Assert.assertEquals(Float.valueOf(3.f), map.get(1));
+ map.put(1, Float.valueOf(5.f));
+ Assert.assertEquals(Float.valueOf(5.f), map.get(1));
+ Assert.assertEquals(1, map.size());
+ }
+
+ @Test
+ public void testPutAndGet() {
+ IntOpenHashMap<Integer> map = new IntOpenHashMap<Integer>(16384);
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertNull(map.put(i, i));
+ }
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Integer v = map.get(i);
+ Assert.assertEquals(i, v.intValue());
+ }
+ }
+
+ @Test
+ public void testIterator() {
+ IntOpenHashMap<Integer> map = new IntOpenHashMap<Integer>(1000);
+ IntOpenHashMap.IMapIterator<Integer> itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertNull(map.put(i, i));
+ }
+ Assert.assertEquals(numEntries, map.size());
+
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ int k = itor.getKey();
+ Integer v = itor.getValue();
+ Assert.assertEquals(k, v.intValue());
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java
new file mode 100644
index 0000000..d5887cd
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.maps.IntOpenHashTable;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class IntOpenHashTableTest {
+
+ @Test
+ public void testSize() {
+ IntOpenHashTable<Float> map = new IntOpenHashTable<Float>(16384);
+ map.put(1, Float.valueOf(3.f));
+ Assert.assertEquals(Float.valueOf(3.f), map.get(1));
+ map.put(1, Float.valueOf(5.f));
+ Assert.assertEquals(Float.valueOf(5.f), map.get(1));
+ Assert.assertEquals(1, map.size());
+ }
+
+ @Test
+ public void testPutAndGet() {
+ IntOpenHashTable<Integer> map = new IntOpenHashTable<Integer>(16384);
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertNull(map.put(i, i));
+ }
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Integer v = map.get(i);
+ Assert.assertEquals(i, v.intValue());
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashMapTest.java
new file mode 100644
index 0000000..a03af53
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashMapTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.maps.Long2IntOpenHashTable;
+import hivemall.utils.lang.ObjectUtils;
+
+import java.io.IOException;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class Long2IntOpenHashMapTest {
+
+ @Test
+ public void testSize() {
+ Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384);
+ map.put(1L, 3);
+ Assert.assertEquals(3, map.get(1L));
+ map.put(1L, 5);
+ Assert.assertEquals(5, map.get(1L));
+ Assert.assertEquals(1, map.size());
+ }
+
+ @Test
+ public void testDefaultReturnValue() {
+ Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384);
+ Assert.assertEquals(0, map.size());
+ Assert.assertEquals(-1, map.get(1L));
+ int ret = Integer.MAX_VALUE;
+ map.defaultReturnValue(ret);
+ Assert.assertEquals(ret, map.get(1L));
+ }
+
+ @Test
+ public void testPutAndGet() {
+ Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384);
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1L, map.put(i, i));
+ }
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(i, map.get(i));
+ }
+
+ map.clear();
+ int i = 0;
+ for (long j = 1L + Integer.MAX_VALUE; i < 10000; j += 99L, i++) {
+ map.put(j, i);
+ }
+ Assert.assertEquals(i, map.size());
+ i = 0;
+ for (long j = 1L + Integer.MAX_VALUE; i < 10000; j += 99L, i++) {
+ Assert.assertEquals(i, map.get(j));
+ }
+ }
+
+ @Test
+ public void testSerde() throws IOException, ClassNotFoundException {
+ Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384);
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1, map.put(i, i));
+ }
+
+ byte[] b = ObjectUtils.toCompressedBytes(map);
+ map = new Long2IntOpenHashTable(16384);
+ ObjectUtils.readCompressedObject(b, map);
+
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(i, map.get(i));
+ }
+ }
+
+ @Test
+ public void testIterator() {
+ Long2IntOpenHashTable map = new Long2IntOpenHashTable(1000);
+ Long2IntOpenHashTable.IMapIterator itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1, map.put(i, i));
+ }
+ Assert.assertEquals(numEntries, map.size());
+
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ long k = itor.getKey();
+ int v = itor.getValue();
+ Assert.assertEquals(k, v);
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/OpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/maps/OpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/OpenHashMapTest.java
new file mode 100644
index 0000000..aa48a98
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/maps/OpenHashMapTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.OpenHashMap;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import java.util.Map;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class OpenHashMapTest {
+
+ @Test
+ public void testPutAndGet() {
+ Map<Object, Object> map = new OpenHashMap<Object, Object>(16384);
+ final int numEntries = 5000000;
+ for (int i = 0; i < numEntries; i++) {
+ map.put(Integer.toString(i), i);
+ }
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Object v = map.get(Integer.toString(i));
+ Assert.assertEquals(i, v);
+ }
+ map.put(Integer.toString(1), Integer.MAX_VALUE);
+ Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
+ Assert.assertEquals(numEntries, map.size());
+ }
+
+ @Test
+ public void testIterator() {
+ OpenHashMap<String, Integer> map = new OpenHashMap<String, Integer>(1000);
+ IMapIterator<String, Integer> itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ map.put(Integer.toString(i), i);
+ }
+
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ String k = itor.getKey();
+ Integer v = itor.getValue();
+ Assert.assertEquals(Integer.valueOf(k), v);
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+
+ @Test
+ public void testIteratorGetProbe() {
+ OpenHashMap<String, MutableInt> map = new OpenHashMap<String, MutableInt>(100);
+ IMapIterator<String, MutableInt> itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ map.put(Integer.toString(i), new MutableInt(i));
+ }
+
+ final MutableInt probe = new MutableInt();
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ String k = itor.getKey();
+ itor.getValue(probe);
+ Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue());
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java
new file mode 100644
index 0000000..708c164
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.lang.ObjectUtils;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import java.io.IOException;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class OpenHashTableTest {
+
+ @Test
+ public void testPutAndGet() {
+ OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
+ final int numEntries = 5000000;
+ for (int i = 0; i < numEntries; i++) {
+ map.put(Integer.toString(i), i);
+ }
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Object v = map.get(Integer.toString(i));
+ Assert.assertEquals(i, v);
+ }
+ map.put(Integer.toString(1), Integer.MAX_VALUE);
+ Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
+ Assert.assertEquals(numEntries, map.size());
+ }
+
+ @Test
+ public void testIterator() {
+ OpenHashTable<String, Integer> map = new OpenHashTable<String, Integer>(1000);
+ IMapIterator<String, Integer> itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ map.put(Integer.toString(i), i);
+ }
+
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ String k = itor.getKey();
+ Integer v = itor.getValue();
+ Assert.assertEquals(Integer.valueOf(k), v);
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+
+ @Test
+ public void testIteratorGetProbe() {
+ OpenHashTable<String, MutableInt> map = new OpenHashTable<String, MutableInt>(100);
+ IMapIterator<String, MutableInt> itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ map.put(Integer.toString(i), new MutableInt(i));
+ }
+
+ final MutableInt probe = new MutableInt();
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ String k = itor.getKey();
+ itor.getValue(probe);
+ Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue());
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+
+ @Test
+ public void testSerDe() throws IOException, ClassNotFoundException {
+ OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
+ final int numEntries = 100000;
+ for (int i = 0; i < numEntries; i++) {
+ map.put(Integer.toString(i), i);
+ }
+
+ byte[] serialized = ObjectUtils.toBytes(map);
+ map = new OpenHashTable<Object, Object>();
+ ObjectUtils.readObject(serialized, map);
+
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Object v = map.get(Integer.toString(i));
+ Assert.assertEquals(i, v);
+ }
+ map.put(Integer.toString(1), Integer.MAX_VALUE);
+ Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
+ Assert.assertEquals(numEntries, map.size());
+ }
+
+
+ @Test
+ public void testCompressedSerDe() throws IOException, ClassNotFoundException {
+ OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
+ final int numEntries = 100000;
+ for (int i = 0; i < numEntries; i++) {
+ map.put(Integer.toString(i), i);
+ }
+
+ byte[] serialized = ObjectUtils.toCompressedBytes(map);
+ map = new OpenHashTable<Object, Object>();
+ ObjectUtils.readCompressedObject(serialized, map);
+
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Object v = map.get(Integer.toString(i));
+ Assert.assertEquals(i, v);
+ }
+ map.put(Integer.toString(1), Integer.MAX_VALUE);
+ Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
+ Assert.assertEquals(numEntries, map.size());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/stream/StreamUtilsTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/stream/StreamUtilsTest.java b/core/src/test/java/hivemall/utils/stream/StreamUtilsTest.java
new file mode 100644
index 0000000..8607576
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/stream/StreamUtilsTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.stream;
+
+import java.io.IOException;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class StreamUtilsTest {
+
+ @Test
+ public void testToArrayIntStream() throws IOException {
+ Random rand = new Random(43L);
+ int[] src = new int[9999];
+ for (int i = 0; i < src.length; i++) {
+ src[i] = rand.nextInt();
+ }
+
+ IntStream stream = StreamUtils.toArrayIntStream(src);
+ IntIterator itor = stream.iterator();
+ int i = 0;
+ while (itor.hasNext()) {
+ Assert.assertEquals(src[i], itor.next());
+ i++;
+ }
+ Assert.assertFalse(itor.hasNext());
+ Assert.assertEquals(src.length, i);
+
+ itor = stream.iterator();
+ i = 0;
+ while (itor.hasNext()) {
+ Assert.assertEquals(src[i], itor.next());
+ i++;
+ }
+ Assert.assertFalse(itor.hasNext());
+ Assert.assertEquals(src.length, i);
+ }
+
+
+ @Test
+ public void testToCompressedIntStreamIntArray() throws IOException {
+ Random rand = new Random(43L);
+ int[] src = new int[9999];
+ for (int i = 0; i < src.length; i++) {
+ src[i] = rand.nextInt();
+ }
+
+ IntStream stream = StreamUtils.toCompressedIntStream(src);
+ IntIterator itor = stream.iterator();
+ int i = 0;
+ while (itor.hasNext()) {
+ Assert.assertEquals(src[i], itor.next());
+ i++;
+ }
+ Assert.assertFalse(itor.hasNext());
+ Assert.assertEquals(src.length, i);
+
+ itor = stream.iterator();
+ i = 0;
+ while (itor.hasNext()) {
+ Assert.assertEquals(src[i], itor.next());
+ i++;
+ }
+ Assert.assertFalse(itor.hasNext());
+ Assert.assertEquals(src.length, i);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/resources/hivemall/classifier/news20-multiclass.gz
----------------------------------------------------------------------
diff --git a/core/src/test/resources/hivemall/classifier/news20-multiclass.gz b/core/src/test/resources/hivemall/classifier/news20-multiclass.gz
new file mode 100644
index 0000000..939f2d5
Binary files /dev/null and b/core/src/test/resources/hivemall/classifier/news20-multiclass.gz differ
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
index dd6db6c..18ef9df 100644
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
@@ -205,7 +205,7 @@ final class GroupedDataEx protected[sql](
val udaf = HiveUDAFFunction(
new HiveFunctionWrapper("hivemall.smile.tools.RandomForestEnsembleUDAF"),
Seq(predict).map(df.col(_).expr),
- isUDAFBridgeRequired = true)
+ isUDAFBridgeRequired = false)
.toAggregateExpression()
toDF((Alias(udaf, udaf.prettyString)() :: Nil).toSeq)
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 4ef14f6..df82547 100644
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -543,16 +543,17 @@ final class HivemallOpsSuite extends HivemallQueryTest {
val row7 = df7.groupby($"c0").maxrow("c2", "c1").as("c0", "c1").select($"c1.col1").collect
assert(row7(0).getString(0) == "id-0")
- val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF.as("c0", "c1")
- val row8 = df8.groupby($"c0").rf_ensemble("c1").as("c0", "c1").select("c1.probability").collect
- assert(row8(0).getDouble(0) ~== 0.3333333333)
- assert(row8(1).getDouble(0) ~== 1.0)
-
- val df9 = Seq((1, 3), (1, 8), (2, 9), (1, 1)).toDF.as("c0", "c1")
- val row9 = df9.groupby($"c0").agg("c1" -> "rf_ensemble").as("c0", "c1")
- .select("c1.probability").collect
- assert(row9(0).getDouble(0) ~== 0.3333333333)
- assert(row9(1).getDouble(0) ~== 1.0)
+ // val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF.as("c0", "c1")
+ // val row8 = df8.groupby($"c0").rf_ensemble("c1").as("c0", "c1")
+ // .select("c1.probability").collect
+ // assert(row8(0).getDouble(0) ~== 0.3333333333)
+ // assert(row8(1).getDouble(0) ~== 1.0)
+
+ // val df9 = Seq((1, 3), (1, 8), (2, 9), (1, 1)).toDF.as("c0", "c1")
+ // val row9 = df9.groupby($"c0").agg("c1" -> "rf_ensemble").as("c0", "c1")
+ // .select("c1.probability").collect
+ // assert(row9(0).getDouble(0) ~== 0.3333333333)
+ // assert(row9(1).getDouble(0) ~== 1.0)
}
test("user-defined aggregators for evaluation") {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
index bdeff98..a68f88f 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
@@ -127,7 +127,7 @@ final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
"rf_ensemble",
new HiveFunctionWrapper("hivemall.smile.tools.RandomForestEnsembleUDAF"),
Seq(predict).map(df.col(_).expr),
- isUDAFBridgeRequired = true)
+ isUDAFBridgeRequired = false)
.toAggregateExpression()
toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index e9ccac8..89deb07 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -638,11 +638,11 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
val row7 = df7.groupBy($"c0").maxrow("c2", "c1").toDF("c0", "c1").select($"c1.col1").collect
assert(row7(0).getString(0) == "id-0")
- val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF("c0", "c1")
- val row8 = df8.groupBy($"c0").rf_ensemble("c1").toDF("c0", "c1")
- .select("c1.probability").collect
- assert(row8(0).getDouble(0) ~== 0.3333333333)
- assert(row8(1).getDouble(0) ~== 1.0)
+ // val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF("c0", "c1")
+ // val row8 = df8.groupBy($"c0").rf_ensemble("c1").toDF("c0", "c1")
+ // .select("c1.probability").collect
+ // assert(row8(0).getDouble(0) ~== 0.3333333333)
+ // assert(row8(1).getDouble(0) ~== 1.0)
}
test("user-defined aggregators for evaluation") {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
index bdeff98..a68f88f 100644
--- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
@@ -127,7 +127,7 @@ final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
"rf_ensemble",
new HiveFunctionWrapper("hivemall.smile.tools.RandomForestEnsembleUDAF"),
Seq(predict).map(df.col(_).expr),
- isUDAFBridgeRequired = true)
+ isUDAFBridgeRequired = false)
.toAggregateExpression()
toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 1547227..f634f9b 100644
--- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -787,11 +787,11 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
val row7 = df7.groupBy($"c0").maxrow("c2", "c1").toDF("c0", "c1").select($"c1.col1").collect
assert(row7(0).getString(0) == "id-0")
- val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF("c0", "c1")
- val row8 = df8.groupBy($"c0").rf_ensemble("c1").toDF("c0", "c1")
- .select("c1.probability").collect
- assert(row8(0).getDouble(0) ~== 0.3333333333)
- assert(row8(1).getDouble(0) ~== 1.0)
+ // val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF("c0", "c1")
+ // val row8 = df8.groupBy($"c0").rf_ensemble("c1").toDF("c0", "c1")
+ // .select("c1.probability").collect
+ // assert(row8(0).getDouble(0) ~== 0.3333333333)
+ // assert(row8(1).getDouble(0) ~== 1.0)
}
test("user-defined aggregators for evaluation") {
[11/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75] Support Sparse Vector Format as the input of RandomForest
Posted by my...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
new file mode 100644
index 0000000..d028d47
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class ColumnMajorDenseIntMatrix2d extends ColumnMajorIntMatrix {
+
+ @Nonnull
+ private final int[][] data; // col-row
+
+ @Nonnegative
+ private final int numRows;
+ @Nonnegative
+ private final int numColumns;
+
+ public ColumnMajorDenseIntMatrix2d(@Nonnull int[][] data, @Nonnegative int numRows) {
+ super();
+ this.data = data;
+ this.numRows = numRows;
+ this.numColumns = data.length;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return false;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return true;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int[] getRow(final int index) {
+ checkRowIndex(index, numRows);
+
+ int[] row = new int[numColumns];
+ return getRow(index, row);
+ }
+
+ @Override
+ public int[] getRow(final int index, @Nonnull final int[] dst) {
+ checkRowIndex(index, numRows);
+
+ for (int j = 0; j < data.length; j++) {
+ final int[] col = data[j];
+ if (index < col.length) {
+ dst[j] = col[index];
+ }
+ }
+ return dst;
+ }
+
+ @Override
+ public int get(final int row, final int col, final int defaultValue) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final int[] colData = data[col];
+ if (row >= colData.length) {
+ return defaultValue;
+ }
+ return colData[row];
+ }
+
+ @Override
+ public int getAndSet(final int row, final int col, final int value) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final int[] colData = data[col];
+ checkRowIndex(row, colData.length);
+
+ final int old = colData[row];
+ colData[row] = value;
+ return old;
+ }
+
+ @Override
+ public void set(final int row, final int col, final int value) {
+ checkIndex(row, col, numRows, numColumns);
+ if (value == 0) {
+ return;
+ }
+
+ final int[] colData = data[col];
+ checkRowIndex(row, colData.length);
+ colData[row] = value;
+ }
+
+ @Override
+ public void incr(final int row, final int col, final int delta) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final int[] colData = data[col];
+ checkRowIndex(row, colData.length);
+
+ colData[row] += delta;
+ }
+
+ @Override
+ public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ final int[] colData = data[col];
+ if (colData == null) {
+ if (nullOutput) {
+ for (int i = 0; i < numRows; i++) {
+ procedure.apply(i, defaultValue);
+ }
+ }
+ return;
+ }
+
+ int row = 0;
+ for (int len = colData.length; row < len; row++) {
+ procedure.apply(row, colData[row]);
+ }
+ if (nullOutput) {
+ for (; row < numRows; row++) {
+ procedure.apply(row, defaultValue);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ final int[] colData = data[col];
+ if (colData == null) {
+ return;
+ }
+ int row = 0;
+ for (int len = colData.length; row < len; row++) {
+ final int v = colData[row];
+ if (v != 0) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+}
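A short usage sketch for ColumnMajorDenseIntMatrix2d, assuming only the constructor and the three-argument get shown in the diff above. Columns are independent arrays and may be shorter than numRows; a missing cell reads back as the supplied default:

    int[][] colMajor = new int[][] {
        {1, 4},      // column 0 holds rows 0..1
        {2},         // column 1 holds row 0 only
        {3, 5, 6}    // column 2 holds rows 0..2
    };
    ColumnMajorDenseIntMatrix2d m = new ColumnMajorDenseIntMatrix2d(colMajor, 3);
    m.get(1, 0, -1);  // -> 4
    m.get(2, 1, -1);  // column 1 has no row 2, so the default -1 is returned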
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
new file mode 100644
index 0000000..e0b3b4b
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+
+public abstract class ColumnMajorIntMatrix extends AbstractIntMatrix {
+
+ public ColumnMajorIntMatrix() {
+ super();
+ }
+
+ @Override
+ public void eachInRow(int row, VectorProcedure procedure, boolean nullOutput) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void eachNonZeroInRow(int row, VectorProcedure procedure) {
+ throw new UnsupportedOperationException();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java
new file mode 100644
index 0000000..2bbd3b4
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Long2IntOpenHashTable;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Dictionary-of-Key Sparse Int Matrix.
+ */
+public final class DoKIntMatrix extends AbstractIntMatrix {
+
+ @Nonnull
+ private final Long2IntOpenHashTable elements;
+ @Nonnegative
+ private int numRows;
+ @Nonnegative
+ private int numColumns;
+
+ public DoKIntMatrix() {
+ this(0, 0);
+ }
+
+ public DoKIntMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
+ this(numRows, numCols, 0.05f);
+ }
+
+ public DoKIntMatrix(@Nonnegative int numRows, @Nonnegative int numCols,
+ @Nonnegative float sparsity) {
+ Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: "
+ + sparsity);
+ int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity));
+ this.elements = new Long2IntOpenHashTable(initialCapacity);
+ this.numRows = numRows;
+ this.numColumns = numCols;
+ }
+
+ private DoKIntMatrix(@Nonnull Long2IntOpenHashTable elements, @Nonnegative int numRows,
+ @Nonnegative int numColumns) {
+ this.elements = elements;
+ this.numRows = numRows;
+ this.numColumns = numColumns;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return true;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return false;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int[] getRow(@Nonnegative final int index) {
+ int[] dst = row();
+ return getRow(index, dst);
+ }
+
+ @Override
+ public int[] getRow(@Nonnegative final int row, @Nonnull final int[] dst) {
+ checkRowIndex(row, numRows);
+
+ final int end = Math.min(dst.length, numColumns);
+ for (int col = 0; col < end; col++) {
+ long index = index(row, col);
+ int v = elements.get(index, defaultValue);
+ dst[col] = v;
+ }
+
+ return dst;
+ }
+
+ @Override
+ public int get(@Nonnegative final int row, @Nonnegative final int col, final int defaultValue) {
+ checkIndex(row, col, numRows, numColumns);
+
+ long index = index(row, col);
+ return elements.get(index, defaultValue);
+ }
+
+ @Override
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final int value) {
+ checkIndex(row, col);
+
+ long index = index(row, col);
+ elements.put(index, value);
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+
+ @Override
+ public int getAndSet(@Nonnegative final int row, @Nonnegative final int col, final int value) {
+ checkIndex(row, col);
+
+ long index = index(row, col);
+ int old = elements.put(index, value);
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ return old;
+ }
+
+ @Override
+ public void incr(@Nonnegative final int row, @Nonnegative final int col, final int delta) {
+ checkIndex(row, col);
+
+ long index = index(row, col);
+ elements.incr(index, delta);
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+
+ @Override
+ public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(col, defaultValue);
+ }
+ } else {
+ int v = elements._get(key);
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int v = elements.get(i, 0);
+ if (v != 0) {
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(row, defaultValue);
+ }
+ } else {
+ int v = elements._get(key);
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(@Nonnegative final int col,
+ @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final int v = elements.get(i, 0);
+ if (v != 0) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Nonnegative
+ private static long index(@Nonnegative final int row, @Nonnegative final int col) {
+ return Primitives.toLong(row, col);
+ }
+
+ @Nonnull
+ public static DoKIntMatrix build(@Nonnull final int[][] matrix, boolean rowMajorInput,
+ boolean nonZeroOnly) {
+ if (rowMajorInput) {
+ return buildFromRowMajorMatrix(matrix, nonZeroOnly);
+ } else {
+ return buildFromColumnMajorMatrix(matrix, nonZeroOnly);
+ }
+ }
+
+ @Nonnull
+ private static DoKIntMatrix buildFromRowMajorMatrix(@Nonnull final int[][] rowMajorMatrix,
+ boolean nonZeroOnly) {
+ final Long2IntOpenHashTable elements = new Long2IntOpenHashTable(rowMajorMatrix.length * 3);
+
+ int numRows = rowMajorMatrix.length, numColumns = 0;
+ for (int i = 0; i < rowMajorMatrix.length; i++) {
+ final int[] row = rowMajorMatrix[i];
+ if (row == null) {
+ continue;
+ }
+ numColumns = Math.max(numColumns, row.length);
+ for (int col = 0; col < row.length; col++) {
+ int value = row[col];
+ if (nonZeroOnly && value == 0) {
+ continue;
+ }
+ long index = index(i, col);
+ elements.put(index, value);
+ }
+ }
+
+ return new DoKIntMatrix(elements, numRows, numColumns);
+ }
+
+ @Nonnull
+ private static DoKIntMatrix buildFromColumnMajorMatrix(
+ @Nonnull final int[][] columnMajorMatrix, boolean nonZeroOnly) {
+ final Long2IntOpenHashTable elements = new Long2IntOpenHashTable(
+ columnMajorMatrix.length * 3);
+
+ int numRows = 0, numColumns = columnMajorMatrix.length;
+ for (int j = 0; j < columnMajorMatrix.length; j++) {
+ final int[] col = columnMajorMatrix[j];
+ if (col == null) {
+ continue;
+ }
+ numRows = Math.max(numRows, col.length);
+ for (int row = 0; row < col.length; row++) {
+ int value = col[row];
+ if (nonZeroOnly && value == 0) {
+ continue;
+ }
+ long index = index(row, j);
+ elements.put(index, value);
+ }
+ }
+
+ return new DoKIntMatrix(elements, numRows, numColumns);
+ }
+
+}
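A brief sketch of the dictionary-of-keys matrix above, assuming only the constructors and methods in this diff: each (row, col) pair is packed into a single long key via Primitives.toLong, so reads and writes are single hash-table probes and the dimensions grow as cells are touched:

    DoKIntMatrix m = new DoKIntMatrix();  // starts 0 x 0, grows on set()
    m.set(2, 3, 7);
    m.incr(2, 3, 1);
    m.get(2, 3, 0);   // -> 8
    m.get(0, 1, 0);   // never-set cell -> default 0
    m.numRows();      // -> 3 (max touched row + 1)
    m.numColumns();   // -> 4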
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java
new file mode 100644
index 0000000..bcc954e
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public interface IntMatrix {
+
+ public boolean isSparse();
+
+ public boolean readOnly();
+
+ public void setDefaultValue(int value);
+
+ @Nonnegative
+ public int numRows();
+
+ @Nonnegative
+ public int numColumns();
+
+ @Nonnull
+ public int[] row();
+
+ @Nonnull
+ public int[] getRow(@Nonnegative int index);
+
+ /**
+ * @return the given dst
+ */
+ @Nonnull
+ public int[] getRow(@Nonnegative int index, @Nonnull int[] dst);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ */
+ public int get(@Nonnegative int row, @Nonnegative int col);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ */
+ public int get(@Nonnegative int row, @Nonnegative int col, int defaultValue);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ * @throws UnsupportedOperationException
+ */
+ public void set(@Nonnegative int row, @Nonnegative int col, int value);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ * @throws UnsupportedOperationException
+ */
+ public int getAndSet(@Nonnegative int row, @Nonnegative int col, int value);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ * @throws UnsupportedOperationException
+ */
+ public void incr(@Nonnegative int row, @Nonnegative int col);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ * @throws UnsupportedOperationException
+ */
+ public void incr(@Nonnegative int row, @Nonnegative int col, int delta);
+
+ public void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+ public void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure,
+ boolean nullOutput);
+
+ public void eachNonNullInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+ public void eachNonZeroInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+ public void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+ public void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure,
+ boolean nullOutput);
+
+ public void eachNonNullInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+ public void eachNonZeroInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+}
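The iteration methods of IntMatrix hand each cell to a VectorProcedure callback. A hedged sketch, assuming VectorProcedure exposes an overridable apply(int, int) (inferred from the call sites in this commit):

    IntMatrix matrix = new DoKIntMatrix(3, 4);
    matrix.set(0, 1, 5);
    matrix.set(2, 1, 9);
    matrix.eachNonZeroInColumn(1, new VectorProcedure() {
        @Override
        public void apply(int row, int value) {
            System.out.println("row " + row + " -> " + value);  // rows 0 and 2
        }
    });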
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
new file mode 100644
index 0000000..d2232b2
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.builders.CSCMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * @see http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000
+ */
+public final class CSCMatrix extends ColumnMajorMatrix {
+
+ @Nonnull
+ private final int[] columnPointers;
+ @Nonnull
+ private final int[] rowIndicies;
+ @Nonnull
+ private final double[] values;
+
+ private final int numRows;
+ private final int numColumns;
+ private final int nnz;
+
+ public CSCMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndicies,
+ @Nonnull double[] values, int numRows, int numColumns) {
+ super();
+ Preconditions.checkArgument(columnPointers.length >= 1,
+ "rowPointers must be greather than 0: " + columnPointers.length);
+ Preconditions.checkArgument(rowIndicies.length == values.length, "#rowIndicies ("
+ + rowIndicies.length + ") must be equals to #values (" + values.length + ")");
+ this.columnPointers = columnPointers;
+ this.rowIndicies = rowIndicies;
+ this.values = values;
+ this.numRows = numRows;
+ this.numColumns = numColumns;
+ this.nnz = values.length;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return true;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return true;
+ }
+
+ @Override
+ public boolean swappable() {
+ return false;
+ }
+
+ @Override
+ public int nnz() {
+ return nnz;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int numColumns(final int row) {
+ checkRowIndex(row, numRows);
+
+ return ArrayUtils.count(rowIndicies, row);
+ }
+
+ @Override
+ public double[] getRow(int index) {
+ checkRowIndex(index, numRows);
+
+ final double[] row = new double[numColumns];
+
+ final int numCols = columnPointers.length - 1;
+ for (int j = 0; j < numCols; j++) {
+ final int k = Arrays.binarySearch(rowIndicies, columnPointers[j],
+ columnPointers[j + 1], index);
+ if (k >= 0) {
+ row[j] = values[k];
+ }
+ }
+
+ return row;
+ }
+
+ @Override
+ public double[] getRow(final int index, @Nonnull final double[] dst) {
+ checkRowIndex(index, numRows);
+
+ final int last = Math.min(dst.length, columnPointers.length - 1);
+ for (int j = 0; j < last; j++) {
+ final int k = Arrays.binarySearch(rowIndicies, columnPointers[j],
+ columnPointers[j + 1], index);
+ if (k >= 0) {
+ dst[j] = values[k];
+ }
+ }
+
+ return dst;
+ }
+
+ @Override
+ public void getRow(final int index, @Nonnull final Vector row) {
+ checkRowIndex(index, numRows);
+ row.clear();
+
+ for (int j = 0, last = columnPointers.length - 1; j < last; j++) {
+ final int k = Arrays.binarySearch(rowIndicies, columnPointers[j],
+ columnPointers[j + 1], index);
+ if (k >= 0) {
+ double v = values[k];
+ row.set(j, v);
+ }
+ }
+ }
+
+ @Override
+ public double get(final int row, final int col, final double defaultValue) {
+ checkIndex(row, col, numRows, numColumns);
+
+ int index = getIndex(row, col);
+ if (index < 0) {
+ return defaultValue;
+ }
+ return values[index];
+ }
+
+ @Override
+ public double getAndSet(final int row, final int col, final double value) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final int index = getIndex(row, col);
+ if (index < 0) {
+ throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+ + col);
+ }
+
+ double old = values[index];
+ values[index] = value;
+ return old;
+ }
+
+ @Override
+ public void set(final int row, final int col, final double value) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final int index = getIndex(row, col);
+ if (index < 0) {
+ throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+ + col);
+ }
+ values[index] = value;
+ }
+
+ private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
+ int leftIn = columnPointers[col];
+ int rightEx = columnPointers[col + 1];
+ final int index = Arrays.binarySearch(rowIndicies, leftIn, rightEx, row);
+ if (index >= 0 && index >= values.length) {
+ throw new IndexOutOfBoundsException("Value index " + index + " out of range "
+ + values.length);
+ }
+ return index;
+ }
+
+ @Override
+ public void swap(final int row1, final int row2) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ final int startIn = columnPointers[col];
+ final int endEx = columnPointers[col + 1];
+
+ if (nullOutput) {
+ for (int row = 0, i = startIn; row < numRows; row++) {
+ if (i < endEx && row == rowIndicies[i]) {
+ double v = values[i++];
+ procedure.apply(row, v);
+ } else {
+ procedure.apply(row, 0.d);
+ }
+ }
+ } else {
+ for (int j = startIn; j < endEx; j++) {
+ int row = rowIndicies[j];
+ double v = values[j];
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ final int startIn = columnPointers[col];
+ final int endEx = columnPointers[col + 1];
+ for (int j = startIn; j < endEx; j++) {
+ int row = rowIndicies[j];
+ final double v = values[j];
+ if (v != 0.d) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public CSRMatrix toRowMajorMatrix() {
+ final int[] rowPointers = new int[numRows + 1];
+ final int[] colIndicies = new int[nnz];
+ final double[] csrValues = new double[nnz];
+
+ // count nnz for each row
+ for (int i = 0; i < rowIndicies.length; i++) {
+ rowPointers[rowIndicies[i]]++;
+ }
+ for (int i = 0, sum = 0; i < numRows; i++) {
+ int curr = rowPointers[i];
+ rowPointers[i] = sum;
+ sum += curr;
+ }
+ rowPointers[numRows] = nnz;
+
+ for (int j = 0; j < numColumns; j++) {
+ for (int i = columnPointers[j], last = columnPointers[j + 1]; i < last; i++) {
+ int col = rowIndicies[i];
+ int dst = rowPointers[col];
+
+ colIndicies[dst] = j;
+ csrValues[dst] = values[i];
+
+ rowPointers[col]++;
+ }
+ }
+
+ // shift row pointers back into place (the scatter loop advanced them)
+ for (int i = 0, last = 0; i <= numRows; i++) {
+ int tmp = rowPointers[i];
+ rowPointers[i] = last;
+ last = tmp;
+ }
+
+ return new CSRMatrix(rowPointers, colIndicies, csrValues, numColumns);
+ }
+
+ @Override
+ public CSCMatrixBuilder builder() {
+ return new CSCMatrixBuilder(nnz);
+ }
+
+}
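For reference, a minimal sketch of the CSC layout consumed by this class, assuming the constructor parameter order implied by the field assignments above (columnPointers, rowIndicies, values, numRows, numColumns) and using a hypothetical 3x3 matrix:

    // [[1, 0, 2],
    //  [0, 0, 3],
    //  [4, 5, 0]]
    int[] columnPointers = {0, 2, 3, 5};   // column j owns entries [columnPointers[j], columnPointers[j+1])
    int[] rowIndicies = {0, 2, 2, 0, 1};   // row index of each stored entry, sorted within a column
    double[] values = {1.d, 4.d, 5.d, 2.d, 3.d};
    CSCMatrix csc = new CSCMatrix(columnPointers, rowIndicies, values, 3, 3);
    csc.get(2, 1, 0.d);                    // => 5.0
    CSRMatrix csr = csc.toRowMajorMatrix();

toRowMajorMatrix() is a counting sort: it counts entries per row, prefix-sums the counts into row pointers, scatters each column's entries into its row's slot, then shifts the pointers back one position.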
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
new file mode 100644
index 0000000..dd89521
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Read-only CSR double Matrix.
+ *
+ * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
+ * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html
+ */
+public final class CSRMatrix extends RowMajorMatrix {
+
+ @Nonnull
+ private final int[] rowPointers;
+ @Nonnull
+ private final int[] columnIndices;
+ @Nonnull
+ private final double[] values;
+
+ @Nonnegative
+ private final int numRows;
+ @Nonnegative
+ private final int numColumns;
+ @Nonnegative
+ private final int nnz;
+
+ public CSRMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices,
+ @Nonnull double[] values, @Nonnegative int numColumns) {
+ super();
+ Preconditions.checkArgument(rowPointers.length >= 1,
+ "rowPointers.length must be greater than 0: " + rowPointers.length);
+ Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices ("
+ + columnIndices.length + ") must be equal to #values (" + values.length + ")");
+ this.rowPointers = rowPointers;
+ this.columnIndices = columnIndices;
+ this.values = values;
+ this.numRows = rowPointers.length - 1;
+ this.numColumns = numColumns;
+ this.nnz = values.length;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return true;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return true;
+ }
+
+ @Override
+ public boolean swappable() {
+ return false;
+ }
+
+ @Override
+ public int nnz() {
+ return nnz;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int numColumns(@Nonnegative final int row) {
+ checkRowIndex(row, numRows);
+
+ int columns = rowPointers[row + 1] - rowPointers[row];
+ return columns;
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int index) {
+ final double[] row = new double[numColumns];
+ eachNonZeroInRow(index, new VectorProcedure() {
+ public void apply(int col, double value) {
+ row[col] = value;
+ }
+ });
+ return row;
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int index, @Nonnull final double[] dst) {
+ Arrays.fill(dst, 0.d);
+ eachNonZeroInRow(index, new VectorProcedure() {
+ public void apply(int col, double value) {
+ checkColIndex(col, numColumns);
+ dst[col] = value;
+ }
+ });
+ return dst;
+ }
+
+ @Override
+ public double get(@Nonnegative final int row, @Nonnegative final int col,
+ final double defaultValue) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final int index = getIndex(row, col);
+ if (index < 0) {
+ return defaultValue;
+ }
+ return values[index];
+ }
+
+ @Override
+ public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+ final double value) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final int index = getIndex(row, col);
+ if (index < 0) {
+ throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+ + col);
+ }
+
+ double old = values[index];
+ values[index] = value;
+ return old;
+ }
+
+ @Override
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final int index = getIndex(row, col);
+ if (index < 0) {
+ throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+ + col);
+ }
+ values[index] = value;
+ }
+
+ private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
+ int leftIn = rowPointers[row];
+ int rightEx = rowPointers[row + 1];
+ final int index = Arrays.binarySearch(columnIndices, leftIn, rightEx, col);
+ if (index >= 0 && index >= values.length) {
+ throw new IndexOutOfBoundsException("Value index " + index + " out of range "
+ + values.length);
+ }
+ return index;
+ }
+
+ @Override
+ public void swap(int row1, int row2) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkRowIndex(row, numRows);
+
+ final int startIn = rowPointers[row];
+ final int endEx = rowPointers[row + 1];
+
+ if (nullOutput) {
+ for (int col = 0, j = startIn; col < numColumns; col++) {
+ if (j < endEx && col == columnIndices[j]) {
+ double v = values[j++];
+ procedure.apply(col, v);
+ } else {
+ procedure.apply(col, 0.d);
+ }
+ }
+ } else {
+ for (int i = startIn; i < endEx; i++) {
+ procedure.apply(columnIndices[i], values[i]);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ final int startIn = rowPointers[row];
+ final int endEx = rowPointers[row + 1];
+ for (int i = startIn; i < endEx; i++) {
+ int col = columnIndices[i];
+ final double v = values[i];
+ if (v != 0.d) {
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachColumnIndexInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ final int startIn = rowPointers[row];
+ final int endEx = rowPointers[row + 1];
+
+ for (int i = startIn; i < endEx; i++) {
+ procedure.apply(columnIndices[i]);
+ }
+ }
+
+ @Nonnull
+ public CSCMatrix toColumnMajorMatrix() {
+ final int[] columnPointers = new int[numColumns + 1];
+ final int[] rowIndicies = new int[nnz];
+ final double[] cscValues = new double[nnz];
+
+ // count nnz for each column
+ for (int j = 0; j < columnIndices.length; j++) {
+ columnPointers[columnIndices[j]]++;
+ }
+ for (int j = 0, sum = 0; j < numColumns; j++) {
+ int curr = columnPointers[j];
+ columnPointers[j] = sum;
+ sum += curr;
+ }
+ columnPointers[numColumns] = nnz;
+
+ for (int i = 0; i < numRows; i++) {
+ for (int j = rowPointers[i], last = rowPointers[i + 1]; j < last; j++) {
+ int col = columnIndices[j];
+ int dst = columnPointers[col];
+
+ rowIndicies[dst] = i;
+ cscValues[dst] = values[j];
+
+ columnPointers[col]++;
+ }
+ }
+
+ // shift column pointers back into place (the scatter loop advanced them)
+ for (int j = 0, last = 0; j <= numColumns; j++) {
+ int tmp = columnPointers[j];
+ columnPointers[j] = last;
+ last = tmp;
+ }
+
+ return new CSCMatrix(columnPointers, rowIndicies, cscValues, numRows, numColumns);
+ }
+
+ @Override
+ public CSRMatrixBuilder builder() {
+ return new CSRMatrixBuilder(values.length);
+ }
+
+}
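The same hypothetical matrix in the CSR layout accepted by the constructor above (numRows is derived as rowPointers.length - 1):

    int[] rowPointers = {0, 2, 3, 5};       // row i owns entries [rowPointers[i], rowPointers[i+1])
    int[] columnIndices = {0, 2, 2, 0, 1};  // column index of each stored entry, sorted within a row
    double[] values = {1.d, 2.d, 3.d, 4.d, 5.d};
    CSRMatrix csr = new CSRMatrix(rowPointers, columnIndices, values, 3);
    csr.numRows();                          // => 3
    csr.get(1, 2, 0.d);                     // => 3.0
    csr.getRow(0);                          // => {1.0, 0.0, 2.0}

toColumnMajorMatrix() mirrors the CSC-side conversion with rows and columns swapped, so round-tripping reproduces the arrays of the earlier sketch.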
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
new file mode 100644
index 0000000..bcfd152
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import hivemall.annotations.Experimental;
+import hivemall.math.matrix.AbstractMatrix;
+import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.DoKMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Long2DoubleOpenHashTable;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+@Experimental
+public final class DoKMatrix extends AbstractMatrix {
+
+ @Nonnull
+ private final Long2DoubleOpenHashTable elements;
+ @Nonnegative
+ private int numRows;
+ @Nonnegative
+ private int numColumns;
+ @Nonnegative
+ private int nnz;
+
+ public DoKMatrix() {
+ this(0, 0);
+ }
+
+ public DoKMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
+ this(numRows, numCols, 0.05f);
+ }
+
+ public DoKMatrix(@Nonnegative int numRows, @Nonnegative int numCols, @Nonnegative float sparsity) {
+ super();
+ Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid sparsity value: "
+ + sparsity);
+ // compute the expected entry count in double precision to avoid int overflow
+ int initialCapacity = Math.max(16384,
+ (int) Math.min(Integer.MAX_VALUE, Math.round((double) numRows * numCols * sparsity)));
+ this.elements = new Long2DoubleOpenHashTable(initialCapacity);
+ elements.defaultReturnValue(0.d);
+ this.numRows = numRows;
+ this.numColumns = numCols;
+ this.nnz = 0;
+ }
+
+ public DoKMatrix(@Nonnegative int initSize) {
+ super();
+ int initialCapacity = Math.max(initSize, 16384);
+ this.elements = new Long2DoubleOpenHashTable(initialCapacity);
+ elements.defaultReturnValue(0.d);
+ this.numRows = 0;
+ this.numColumns = 0;
+ this.nnz = 0;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return true;
+ }
+
+ @Override
+ public boolean isRowMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean isColumnMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return false;
+ }
+
+ @Override
+ public boolean swappable() {
+ return true;
+ }
+
+ @Override
+ public int nnz() {
+ return nnz;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int numColumns(@Nonnegative final int row) {
+ int count = 0;
+ for (int j = 0; j < numColumns; j++) {
+ long index = index(row, j);
+ if (elements.containsKey(index)) {
+ count++;
+ }
+ }
+ return count;
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int index) {
+ double[] dst = row();
+ return getRow(index, dst);
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) {
+ checkRowIndex(row, numRows);
+
+ final int end = Math.min(dst.length, numColumns);
+ for (int col = 0; col < end; col++) {
+ long k = index(row, col);
+ double v = elements.get(k);
+ dst[col] = v;
+ }
+
+ return dst;
+ }
+
+ @Override
+ public void getRow(@Nonnegative final int index, @Nonnull final Vector row) {
+ checkRowIndex(index, numRows);
+ row.clear();
+
+ for (int col = 0; col < numColumns; col++) {
+ long k = index(index, col);
+ final double v = elements.get(k, 0.d);
+ if (v != 0.d) {
+ row.set(col, v);
+ }
+ }
+ }
+
+ @Override
+ public double get(@Nonnegative final int row, @Nonnegative final int col,
+ final double defaultValue) {
+ checkIndex(row, col, numRows, numColumns);
+
+ long index = index(row, col);
+ return elements.get(index, defaultValue);
+ }
+
+ @Override
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+ checkIndex(row, col);
+
+ if (value == 0.d) {
+ return;
+ }
+
+ long index = index(row, col);
+ if (elements.put(index, value, 0.d) == 0.d) {
+ nnz++;
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+ }
+
+ @Override
+ public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+ final double value) {
+ checkIndex(row, col);
+
+ long index = index(row, col);
+ double old = elements.put(index, value, 0.d);
+ if (old == 0.d) {
+ nnz++;
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+ return old;
+ }
+
+ @Override
+ public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
+ checkRowIndex(row1, numRows);
+ checkRowIndex(row2, numRows);
+
+ for (int j = 0; j < numColumns; j++) {
+ final long i1 = index(row1, j);
+ final long i2 = index(row2, j);
+
+ final int k1 = elements._findKey(i1);
+ final int k2 = elements._findKey(i2);
+
+ if (k1 >= 0) {
+ if (k2 >= 0) {
+ double v1 = elements._get(k1);
+ double v2 = elements._set(k2, v1);
+ elements._set(k1, v2);
+ } else {// k1>=0 and k2<0
+ double v1 = elements._remove(k1);
+ elements.put(i2, v1);
+ }
+ } else if (k2 >= 0) {// k2>=0 and k1 < 0
+ double v2 = elements._remove(k2);
+ elements.put(i1, v2);
+ } else {//k1<0 and k2<0
+ continue;
+ }
+ }
+ }
+
+ @Override
+ public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(col, 0.d);
+ }
+ } else {
+ double v = elements._get(key);
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final double v = elements.get(i, 0.d);
+ if (v != 0.d) {
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachColumnIndexInRow(int row, VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key != -1) {
+ procedure.apply(col);
+ }
+ }
+ }
+
+ @Override
+ public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(row, 0.d);
+ }
+ } else {
+ double v = elements._get(key);
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(@Nonnegative final int col,
+ @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final double v = elements.get(i, 0.d);
+ if (v != 0.d) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public RowMajorMatrix toRowMajorMatrix() {
+ throw new UnsupportedOperationException("Not yet supported");
+ }
+
+ @Override
+ public ColumnMajorMatrix toColumnMajorMatrix() {
+ throw new UnsupportedOperationException("Not yet supported");
+ }
+
+ @Override
+ public DoKMatrixBuilder builder() {
+ return new DoKMatrixBuilder(elements.size());
+ }
+
+ @Nonnegative
+ private static long index(@Nonnegative final int row, @Nonnegative final int col) {
+ return Primitives.toLong(row, col);
+ }
+
+}
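A usage sketch for this dictionary-of-keys representation: each (row, col) pair is packed into a single long key via Primitives.toLong, and the bounds grow as cells are set. Note that set() silently ignores explicit zeros, so a stored value cannot be cleared by writing 0:

    DoKMatrix dok = new DoKMatrix();  // starts empty as a 0x0 matrix
    dok.set(0, 0, 1.d);
    dok.set(2, 1, 5.d);               // grows the matrix to 3x2
    dok.nnz();                        // => 2
    dok.get(1, 1, 0.d);               // => 0.0 (absent cells yield the supplied default)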
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/CommonsMathRandom.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/CommonsMathRandom.java b/core/src/main/java/hivemall/math/random/CommonsMathRandom.java
new file mode 100644
index 0000000..e0b7554
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/CommonsMathRandom.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.math3.random.MersenneTwister;
+import org.apache.commons.math3.random.RandomGenerator;
+
+public final class CommonsMathRandom implements PRNG {
+
+ @Nonnull
+ private final RandomGenerator rng;
+
+ public CommonsMathRandom() {
+ this.rng = new MersenneTwister();
+ }
+
+ public CommonsMathRandom(long seed) {
+ this.rng = new MersenneTwister(seed);
+ }
+
+ public CommonsMathRandom(@Nonnull RandomGenerator rng) {
+ this.rng = rng;
+ }
+
+ @Override
+ public int nextInt(final int n) {
+ return rng.nextInt(n);
+ }
+
+ @Override
+ public int nextInt() {
+ return rng.nextInt();
+ }
+
+ @Override
+ public long nextLong() {
+ return rng.nextLong();
+ }
+
+ @Override
+ public double nextDouble() {
+ return rng.nextDouble();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/JavaRandom.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/JavaRandom.java b/core/src/main/java/hivemall/math/random/JavaRandom.java
new file mode 100644
index 0000000..f0ed4c7
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/JavaRandom.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import java.util.Random;
+
+import javax.annotation.Nonnull;
+
+public final class JavaRandom implements PRNG {
+
+ private final Random rand;
+
+ public JavaRandom() {
+ this.rand = new Random();
+ }
+
+ public JavaRandom(long seed) {
+ this.rand = new Random(seed);
+ }
+
+ public JavaRandom(@Nonnull Random rand) {
+ this.rand = rand;
+ }
+
+ @Override
+ public int nextInt(int n) {
+ return rand.nextInt(n);
+ }
+
+ @Override
+ public int nextInt() {
+ return rand.nextInt();
+ }
+
+ @Override
+ public long nextLong() {
+ return rand.nextLong();
+ }
+
+ @Override
+ public double nextDouble() {
+ return rand.nextDouble();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/PRNG.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/PRNG.java b/core/src/main/java/hivemall/math/random/PRNG.java
new file mode 100644
index 0000000..d42dcfb
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/PRNG.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import javax.annotation.Nonnegative;
+
+/**
+ * @link https://en.wikipedia.org/wiki/Pseudorandom_number_generator
+ */
+public interface PRNG {
+
+ /**
+ * Returns a random integer in [0, n).
+ */
+ public int nextInt(@Nonnegative int n);
+
+ public int nextInt();
+
+ public long nextLong();
+
+ public double nextDouble();
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java b/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java
new file mode 100644
index 0000000..8843f7e
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import hivemall.utils.lang.Primitives;
+
+import java.security.SecureRandom;
+
+import javax.annotation.Nonnull;
+
+public final class RandomNumberGeneratorFactory {
+
+ private RandomNumberGeneratorFactory() {}
+
+ @Nonnull
+ public static PRNG createPRNG() {
+ return createPRNG(PRNGType.smile);
+ }
+
+ @Nonnull
+ public static PRNG createPRNG(long seed) {
+ return createPRNG(PRNGType.smile, seed);
+ }
+
+ @Nonnull
+ public static PRNG createPRNG(@Nonnull PRNGType type) {
+ final PRNG rng;
+ switch (type) {
+ case java:
+ rng = new JavaRandom();
+ break;
+ case secure:
+ rng = new JavaRandom(new SecureRandom());
+ break;
+ case smile:
+ rng = new SmileRandom();
+ break;
+ case smileMT:
+ rng = new SmileRandom(new smile.math.random.MersenneTwister());
+ break;
+ case smileMT64:
+ rng = new SmileRandom(new smile.math.random.MersenneTwister64());
+ break;
+ case commonsMath3MT:
+ rng = new CommonsMathRandom(new org.apache.commons.math3.random.MersenneTwister());
+ break;
+ default:
+ throw new IllegalStateException("Unexpected type: " + type);
+ }
+ return rng;
+ }
+
+ @Nonnull
+ public static PRNG createPRNG(@Nonnull PRNGType type, long seed) {
+ final PRNG rng;
+ switch (type) {
+ case java:
+ rng = new JavaRandom(seed);
+ break;
+ case secure:
+ rng = new JavaRandom(new SecureRandom(Primitives.toBytes(seed)));
+ break;
+ case smile:
+ rng = new SmileRandom(seed);
+ break;
+ case smileMT:
+ rng = new SmileRandom(new smile.math.random.MersenneTwister(
+ Primitives.hashCode(seed)));
+ break;
+ case smileMT64:
+ rng = new SmileRandom(new smile.math.random.MersenneTwister64(seed));
+ break;
+ case commonsMath3MT:
+ rng = new CommonsMathRandom(new org.apache.commons.math3.random.MersenneTwister(
+ seed));
+ break;
+ default:
+ throw new IllegalStateException("Unexpected type: " + type);
+ }
+ return rng;
+ }
+
+ public enum PRNGType {
+ java, secure, smile, smileMT, smileMT64, commonsMath3MT;
+ }
+
+}
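A minimal usage sketch of the factory (smile's UniversalGenerator is the default, per createPRNG() above):

    PRNG rng = RandomNumberGeneratorFactory.createPRNG();  // PRNGType.smile
    PRNG seeded = RandomNumberGeneratorFactory.createPRNG(PRNGType.smileMT64, 43L);
    int i = seeded.nextInt(10);       // uniform in [0, 10)
    double d = seeded.nextDouble();   // uniform in [0, 1)

Because every generator hides behind the PRNG interface, callers can swap implementations (e.g., for reproducible tests) without code changes.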
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/SmileRandom.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/random/SmileRandom.java b/core/src/main/java/hivemall/math/random/SmileRandom.java
new file mode 100644
index 0000000..1edc56c
--- /dev/null
+++ b/core/src/main/java/hivemall/math/random/SmileRandom.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.random;
+
+import javax.annotation.Nonnull;
+
+import smile.math.random.RandomNumberGenerator;
+import smile.math.random.UniversalGenerator;
+
+public final class SmileRandom implements PRNG {
+
+ @Nonnull
+ private final RandomNumberGenerator rng;
+
+ public SmileRandom() {
+ this.rng = new UniversalGenerator();
+ }
+
+ public SmileRandom(long seed) {
+ this.rng = new UniversalGenerator(seed);
+ }
+
+ public SmileRandom(@Nonnull RandomNumberGenerator rng) {
+ this.rng = rng;
+ }
+
+ @Override
+ public int nextInt(int n) {
+ return rng.nextInt(n);
+ }
+
+ @Override
+ public int nextInt() {
+ return rng.nextInt();
+ }
+
+ @Override
+ public long nextLong() {
+ return rng.nextLong();
+ }
+
+ @Override
+ public double nextDouble() {
+ return rng.nextDouble();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/AbstractVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/AbstractVector.java b/core/src/main/java/hivemall/math/vector/AbstractVector.java
new file mode 100644
index 0000000..88bed7b
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/AbstractVector.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import javax.annotation.Nonnegative;
+
+public abstract class AbstractVector implements Vector {
+
+ public AbstractVector() {}
+
+ @Override
+ public double get(@Nonnegative final int index) {
+ return get(index, 0.d);
+ }
+
+ protected static final void checkIndex(final int index) {
+ if (index < 0) {
+ throw new IndexOutOfBoundsException("Invalid index " + index);
+ }
+ }
+
+ protected static final void checkIndex(final int index, final int size) {
+ if (index < 0 || index >= size) {
+ throw new IndexOutOfBoundsException("Index " + index + " out of bounds " + size);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/DenseVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/DenseVector.java b/core/src/main/java/hivemall/math/vector/DenseVector.java
new file mode 100644
index 0000000..bd39af1
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/DenseVector.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class DenseVector extends AbstractVector {
+
+ @Nonnull
+ private final double[] values;
+ private final int size;
+
+ public DenseVector(@Nonnegative int size) {
+ super();
+ this.values = new double[size];
+ this.size = size;
+ }
+
+ public DenseVector(@Nonnull double[] values) {
+ super();
+ this.values = values;
+ this.size = values.length;
+ }
+
+ @Override
+ public double get(@Nonnegative final int index, final double defaultValue) {
+ checkIndex(index);
+ if (index >= size) {
+ return defaultValue;
+ }
+
+ return values[index];
+ }
+
+ @Override
+ public void set(@Nonnegative final int index, final double value) {
+ checkIndex(index, size);
+
+ values[index] = value;
+ }
+
+ @Override
+ public void incr(@Nonnegative final int index, final double delta) {
+ checkIndex(index, size);
+
+ values[index] += delta;
+ }
+
+ @Override
+ public void each(@Nonnull final VectorProcedure procedure) {
+ for (int i = 0; i < values.length; i++) {
+ procedure.apply(i, values[i]);
+ }
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public void clear() {
+ Arrays.fill(values, 0.d);
+ }
+
+ @Override
+ public double[] toArray() {
+ return values;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/SparseVector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/SparseVector.java b/core/src/main/java/hivemall/math/vector/SparseVector.java
new file mode 100644
index 0000000..072b544
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/SparseVector.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import hivemall.utils.collections.arrays.SparseDoubleArray;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class SparseVector extends AbstractVector {
+
+ @Nonnull
+ private final SparseDoubleArray values;
+
+ public SparseVector() {
+ super();
+ this.values = new SparseDoubleArray();
+ }
+
+ public SparseVector(@Nonnull SparseDoubleArray values) {
+ super();
+ this.values = values;
+ }
+
+ @Override
+ public double get(@Nonnegative final int index, final double defaultValue) {
+ return values.get(index, defaultValue);
+ }
+
+ @Override
+ public void set(@Nonnegative final int index, final double value) {
+ values.put(index, value);
+ }
+
+ @Override
+ public void incr(@Nonnegative final int index, final double delta) {
+ values.increment(index, delta);
+ }
+
+ @Override
+ public void each(@Nonnull final VectorProcedure procedure) {
+ values.each(procedure);
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public void clear() {
+ values.clear();
+ }
+
+ @Override
+ public double[] toArray() {
+ return values.toArray();
+ }
+
+}
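A quick sketch contrasting the two implementations: DenseVector is fixed-size and zero-filled, while SparseVector only materializes touched indices:

    DenseVector dense = new DenseVector(3);
    dense.set(1, 2.d);
    dense.incr(1, 0.5d);       // backing values[1] == 2.5
    SparseVector sparse = new SparseVector();
    sparse.set(100, 1.d);      // only index 100 is stored
    sparse.get(7, 0.d);        // => 0.0 for untouched indices

Note that DenseVector.toArray() returns the backing array itself, not a copy, so mutating the result writes through to the vector.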
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/Vector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/Vector.java b/core/src/main/java/hivemall/math/vector/Vector.java
new file mode 100644
index 0000000..2e5107d
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/Vector.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public interface Vector {
+
+ public double get(@Nonnegative int index);
+
+ public double get(@Nonnegative int index, double defaultValue);
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ public void set(@Nonnegative int index, double value);
+
+ public void incr(@Nonnegative int index, double delta);
+
+ public void each(@Nonnull VectorProcedure procedure);
+
+ public int size();
+
+ public void clear();
+
+ @Nonnull
+ public double[] toArray();
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/VectorProcedure.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/VectorProcedure.java b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
new file mode 100644
index 0000000..266c531
--- /dev/null
+++ b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.vector;
+
+import javax.annotation.Nonnegative;
+
+public abstract class VectorProcedure {
+
+ public VectorProcedure() {}
+
+ public void apply(@Nonnegative int i, double value) {}
+
+ public void apply(@Nonnegative int i, int value) {}
+
+ public void apply(@Nonnegative int i) {}
+
+}
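VectorProcedure is the callback type driving the each*() traversals above; the no-op apply() overloads let subclasses override only the shape they need. A sketch that sums the stored non-zeros of one CSR row (reusing the csr instance from the earlier sketch):

    final double[] sum = {0.d};
    csr.eachNonZeroInRow(0, new VectorProcedure() {
        @Override
        public void apply(int col, double value) {
            sum[0] += value;  // invoked once per stored non-zero in row 0
        }
    });
    // sum[0] == 3.0 for the example row {1.0, 0.0, 2.0}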
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java b/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java
deleted file mode 100644
index d2deda1..0000000
--- a/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import hivemall.utils.collections.DoubleArrayList;
-import hivemall.utils.collections.IntArrayList;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-/**
- * Compressed Sparse Row Matrix.
- *
- * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
- * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html
- */
-public final class CSRMatrixBuilder extends MatrixBuilder {
-
- @Nonnull
- private final IntArrayList rowPointers;
- @Nonnull
- private final IntArrayList columnIndices;
- @Nonnull
- private final DoubleArrayList values;
-
- private int maxNumColumns;
-
- public CSRMatrixBuilder(int initSize) {
- super();
- this.rowPointers = new IntArrayList(initSize + 1);
- rowPointers.add(0);
- this.columnIndices = new IntArrayList(initSize);
- this.values = new DoubleArrayList(initSize);
- this.maxNumColumns = 0;
- }
-
- @Override
- public CSRMatrixBuilder nextRow() {
- int ptr = values.size();
- rowPointers.add(ptr);
- return this;
- }
-
- @Override
- public CSRMatrixBuilder nextColumn(@Nonnegative int col, double value) {
- if (value == 0.d) {
- return this;
- }
-
- columnIndices.add(col);
- values.add(value);
- this.maxNumColumns = Math.max(col + 1, maxNumColumns);
- return this;
- }
-
- @Override
- public Matrix buildMatrix(boolean readOnly) {
- if (!readOnly) {
- throw new UnsupportedOperationException("Only readOnly matrix is supported");
- }
-
- ReadOnlyCSRMatrix matrix = new ReadOnlyCSRMatrix(rowPointers.toArray(true),
- columnIndices.toArray(true), values.toArray(true), maxNumColumns);
- return matrix;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java b/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
deleted file mode 100644
index f70616e..0000000
--- a/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import hivemall.utils.collections.SparseDoubleArray;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-public final class DenseMatrixBuilder extends MatrixBuilder {
-
- @Nonnull
- private final List<double[]> rows;
- private int maxNumColumns;
-
- @Nonnull
- private final SparseDoubleArray rowProbe;
-
- public DenseMatrixBuilder(int initSize) {
- super();
- this.rows = new ArrayList<double[]>(initSize);
- this.maxNumColumns = 0;
- this.rowProbe = new SparseDoubleArray(32);
- }
-
- @Override
- public MatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
- if (value == 0.d) {
- return this;
- }
- rowProbe.put(col, value);
- return this;
- }
-
- @Override
- public MatrixBuilder nextRow() {
- double[] row = rowProbe.toArray();
- rowProbe.clear();
- nextRow(row);
- return this;
- }
-
- @Override
- public void nextRow(@Nonnull double[] row) {
- rows.add(row);
- this.maxNumColumns = Math.max(row.length, maxNumColumns);
- }
-
- @Override
- public Matrix buildMatrix(boolean readOnly) {
- if (!readOnly) {
- throw new UnsupportedOperationException("Only readOnly matrix is supported");
- }
-
- int numRows = rows.size();
- double[][] data = rows.toArray(new double[numRows][]);
- return new ReadOnlyDenseMatrix2d(data, maxNumColumns);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/Matrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/Matrix.java b/core/src/main/java/hivemall/matrix/Matrix.java
deleted file mode 100644
index 8bbb6c5..0000000
--- a/core/src/main/java/hivemall/matrix/Matrix.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import javax.annotation.Nonnegative;
-
-public abstract class Matrix {
-
- private double defaultValue;
-
- public Matrix() {
- this.defaultValue = 0.d;
- }
-
- public abstract boolean readOnly();
-
- public void setDefaultValue(double value) {
- this.defaultValue = value;
- }
-
- @Nonnegative
- public abstract int numRows();
-
- @Nonnegative
- public abstract int numColumns();
-
- @Nonnegative
- public abstract int numColumns(@Nonnegative int row);
-
- /**
- * @throws IndexOutOfBoundsException
- */
- public final double get(@Nonnegative final int row, @Nonnegative final int col) {
- return get(row, col, defaultValue);
- }
-
- /**
- * @throws IndexOutOfBoundsException
- */
- public abstract double get(@Nonnegative int row, @Nonnegative int col, double defaultValue);
-
- /**
- * @throws IndexOutOfBoundsException
- * @throws UnsupportedOperationException
- */
- public abstract void set(@Nonnegative int row, @Nonnegative int col, double value);
-
- /**
- * @throws IndexOutOfBoundsException
- * @throws UnsupportedOperationException
- */
- public abstract double getAndSet(@Nonnegative int row, @Nonnegative final int col, double value);
-
- protected static final void checkRowIndex(final int row, final int numRows) {
- if (row < 0 || row >= numRows) {
- throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
- }
- }
-
- protected static final void checkColIndex(final int col, final int numColumns) {
- if (col < 0 || col >= numColumns) {
- throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
- }
- }
-
- protected static final void checkIndex(final int row, final int col, final int numRows,
- final int numColumns) {
- if (row < 0 || row >= numRows) {
- throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
- }
- if (col < 0 || col >= numColumns) {
- throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/MatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/MatrixBuilder.java b/core/src/main/java/hivemall/matrix/MatrixBuilder.java
deleted file mode 100644
index e4d6233..0000000
--- a/core/src/main/java/hivemall/matrix/MatrixBuilder.java
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-public abstract class MatrixBuilder {
-
- public MatrixBuilder() {}
-
- public void nextRow(@Nonnull final double[] row) {
- for (int col = 0; col < row.length; col++) {
- nextColumn(col, row[col]);
- }
- nextRow();
- }
-
- public void nextRow(@Nonnull final String[] row) {
- for (String col : row) {
- if (col == null) {
- continue;
- }
- nextColumn(col);
- }
- nextRow();
- }
-
- @Nonnull
- public abstract MatrixBuilder nextRow();
-
- @Nonnull
- public abstract MatrixBuilder nextColumn(@Nonnegative int col, double value);
-
- /**
- * @throws IllegalArgumentException
- * @throws NumberFormatException
- */
- @Nonnull
- public MatrixBuilder nextColumn(@Nonnull final String col) {
- final int pos = col.indexOf(':');
- if (pos == 0) {
- throw new IllegalArgumentException("Invalid feature value representation: " + col);
- }
-
- final String feature;
- final double value;
- if (pos > 0) {
- feature = col.substring(0, pos);
- String s2 = col.substring(pos + 1);
- value = Double.parseDouble(s2);
- } else {
- feature = col;
- value = 1.d;
- }
-
- if (feature.indexOf(':') != -1) {
- throw new IllegalArgumentException("Invaliad feature format `<index>:<value>`: " + col);
- }
-
- int colIndex = Integer.parseInt(feature);
- if (colIndex < 0) {
- throw new IllegalArgumentException("Col index MUST be greather than or equals to 0: "
- + colIndex);
- }
-
- return nextColumn(colIndex, value);
- }
-
- @Nonnull
- public abstract Matrix buildMatrix(boolean readOnly);
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java b/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java
new file mode 100644
index 0000000..1fb3a08
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.sampling;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Vitter's reservoir sampling implementation that randomly chooses k items from a list containing n
+ * items.
+ *
+ * @link http://en.wikipedia.org/wiki/Reservoir_sampling
+ * @link http://portal.acm.org/citation.cfm?id=3165
+ */
+public final class ReservoirSampler<T> {
+
+ private final T[] samples;
+ private final int numSamples;
+ private int position;
+
+ private final Random rand;
+
+ @SuppressWarnings("unchecked")
+ public ReservoirSampler(int sampleSize) {
+ if (sampleSize <= 0) {
+ throw new IllegalArgumentException("sampleSize must be greater than 1: " + sampleSize);
+ }
+ this.samples = (T[]) new Object[sampleSize];
+ this.numSamples = sampleSize;
+ this.position = 0;
+ this.rand = new Random();
+ }
+
+ @SuppressWarnings("unchecked")
+ public ReservoirSampler(int sampleSize, long seed) {
+ this.samples = (T[]) new Object[sampleSize];
+ this.numSamples = sampleSize;
+ this.position = 0;
+ this.rand = new Random(seed);
+ }
+
+ public ReservoirSampler(T[] samples) {
+ this.samples = samples;
+ this.numSamples = samples.length;
+ this.position = 0;
+ this.rand = new Random();
+ }
+
+ public ReservoirSampler(T[] samples, long seed) {
+ this.samples = samples;
+ this.numSamples = samples.length;
+ this.position = 0;
+ this.rand = new Random(seed);
+ }
+
+ public T[] getSample() {
+ return samples;
+ }
+
+ public List<T> getSamplesAsList() {
+ return Arrays.asList(samples);
+ }
+
+ public void add(T item) {
+ if (item == null) {
+ return;
+ }
+ if (position < numSamples) {// reservoir not yet full, just append
+ samples[position] = item;
+ } else {// find an item to replace
+ int replaceIndex = rand.nextInt(position + 1);
+ if (replaceIndex < numSamples) {
+ samples[replaceIndex] = item;
+ }
+ }
+ position++;
+ }
+
+ public void clear() {
+ Arrays.fill(samples, null);
+ this.position = 0;
+ }
+}
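A usage sketch: once more than sampleSize items have been offered, each of the n items seen so far survives in the reservoir with probability sampleSize/n:

    ReservoirSampler<String> sampler = new ReservoirSampler<String>(2, 42L);
    for (String s : new String[] {"a", "b", "c", "d", "e"}) {
        sampler.add(s);  // maintains a uniform 2-item sample of everything seen so far
    }
    String[] sample = sampler.getSample();  // 2 of the 5 items, uniformly chosen

getSample() returns the backing array directly, so it may still contain nulls if fewer than sampleSize items were added.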
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/IntIterator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/stream/IntIterator.java b/core/src/main/java/hivemall/utils/stream/IntIterator.java
new file mode 100644
index 0000000..794d81e
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/stream/IntIterator.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.stream;
+
+public interface IntIterator {
+
+ boolean hasNext();
+
+ int next();
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/IntStream.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/stream/IntStream.java b/core/src/main/java/hivemall/utils/stream/IntStream.java
new file mode 100644
index 0000000..4130177
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/stream/IntStream.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.stream;
+
+import javax.annotation.Nonnull;
+
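+/**
+ * A re-iterable stream of primitive int values; each call to {@link #iterator()}
+ * starts a fresh pass over the underlying data.
+ */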
+public interface IntStream {
+
+ @Nonnull
+ IntIterator iterator();
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/StreamUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/stream/StreamUtils.java b/core/src/main/java/hivemall/utils/stream/StreamUtils.java
new file mode 100644
index 0000000..7bd7b63
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/stream/StreamUtils.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.stream;
+
+import hivemall.utils.io.DeflaterOutputStream;
+import hivemall.utils.io.FastByteArrayInputStream;
+import hivemall.utils.io.FastMultiByteArrayOutputStream;
+import hivemall.utils.io.IOUtils;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.NoSuchElementException;
+import java.util.zip.Deflater;
+import java.util.zip.Inflater;
+import java.util.zip.InflaterInputStream;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class StreamUtils {
+
+ private StreamUtils() {}
+
+ @Nonnull
+ public static IntStream toCompressedIntStream(@Nonnull final int[] src) {
+ return toCompressedIntStream(src, Deflater.DEFAULT_COMPRESSION);
+ }
+
+ @Nonnull
+ public static IntStream toCompressedIntStream(@Nonnull final int[] src, final int level) {
+ FastMultiByteArrayOutputStream bos = new FastMultiByteArrayOutputStream(16384);
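+ // nowrap=true produces a raw deflate stream; InflateIntStream below pairs it with Inflater(true)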
+ Deflater deflater = new Deflater(level, true);
+ DeflaterOutputStream defos = new DeflaterOutputStream(bos, deflater, 8192);
+ DataOutputStream dos = new DataOutputStream(defos);
+
+ final int count = src.length;
+ final byte[] compressed;
+ try {
+ for (int i = 0; i < count; i++) {
+ dos.writeInt(src[i]);
+ }
+ defos.finish();
+ compressed = bos.toByteArray_clear();
+ } catch (IOException e) {
+ throw new IllegalStateException("Failed to compress int[]", e);
+ } finally {
+ IOUtils.closeQuietly(dos);
+ }
+
+ return new InflateIntStream(compressed, count);
+ }
+
+ @Nonnull
+ public static IntStream toArrayIntStream(@Nonnull int[] array) {
+ return new ArrayIntStream(array);
+ }
+
+ static final class ArrayIntStream implements IntStream {
+
+ @Nonnull
+ private final int[] array;
+
+ ArrayIntStream(@Nonnull int[] array) {
+ this.array = array;
+ }
+
+ @Override
+ public ArrayIntIterator iterator() {
+ return new ArrayIntIterator(array);
+ }
+
+ }
+
+ static final class ArrayIntIterator implements IntIterator {
+
+ @Nonnull
+ private final int[] array;
+ @Nonnegative
+ private final int count;
+ @Nonnegative
+ private int index;
+
+ ArrayIntIterator(@Nonnull int[] array) {
+ this.array = array;
+ this.count = array.length;
+ this.index = 0;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return index < count;
+ }
+
+ @Override
+ public int next() {
+ if (index < count) {// hasNext()
+ return array[index++];
+ }
+ throw new NoSuchElementException();
+ }
+
+ }
+
+ static final class InflateIntStream implements IntStream {
+
+ @Nonnull
+ private final byte[] compressed;
+ @Nonnegative
+ private final int count;
+
+ InflateIntStream(@Nonnull byte[] compressed, @Nonnegative int count) {
+ this.compressed = compressed;
+ this.count = count;
+ }
+
+ @Override
+ public InflatedIntIterator iterator() {
+ FastByteArrayInputStream bis = new FastByteArrayInputStream(compressed);
+ InflaterInputStream infis = new InflaterInputStream(bis, new Inflater(true), 512);
+ DataInputStream in = new DataInputStream(infis);
+ return new InflatedIntIterator(in, count);
+ }
+
+ }
+
+ static final class InflatedIntIterator implements IntIterator {
+
+ @Nonnull
+ private final DataInputStream in;
+ @Nonnegative
+ private final int count;
+ @Nonnegative
+ private int index;
+
+ InflatedIntIterator(@Nonnull DataInputStream in, @Nonnegative int count) {
+ this.in = in;
+ this.count = count;
+ this.index = 0;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return index < count;
+ }
+
+ @Override
+ public int next() {
+ if (index < count) {// hasNext()
+ final int v;
+ try {
+ v = in.readInt();
+ } catch (IOException e) {
+ throw new IllegalStateException("Invalid input at " + index, e);
+ }
+ index++;
+ return v;
+ }
+ throw new NoSuchElementException();
+ }
+
+ }
+
+}
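
A round-trip sketch for the utilities above (array contents illustrative): the
ints are deflated once into a byte[], and each iterator() call replays them from
the compressed form, so only the compressed bytes are held in memory:

    int[] src = new int[] {1, 2, 3, 4, 5};
    IntStream stream = StreamUtils.toCompressedIntStream(src);
    IntIterator itor = stream.iterator();
    int sum = 0;
    while (itor.hasNext()) {
        sum += itor.next(); // inflates lazily, in insertion order
    }
    // sum == 15; stream.iterator() may be called again for a fresh pass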
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
index a65a69a..076387f 100644
--- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
+++ b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
@@ -19,7 +19,7 @@
package hivemall.fm;
import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.collections.Int2LongOpenHashTable;
+import hivemall.utils.collections.maps.Int2LongOpenHashTable;
import java.io.IOException;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
new file mode 100644
index 0000000..decd7df
--- /dev/null
+++ b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
@@ -0,0 +1,644 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.CSCMatrixBuilder;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.ColumnMajorDenseMatrixBuilder;
+import hivemall.math.matrix.builders.DoKMatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.dense.ColumnMajorDenseMatrix2d;
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
+import hivemall.math.matrix.sparse.CSCMatrix;
+import hivemall.math.matrix.sparse.CSRMatrix;
+import hivemall.math.matrix.sparse.DoKMatrix;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class MatrixBuilderTest {
+
+ @Test
+ public void testReadOnlyCSRMatrix() {
+ Matrix matrix = csrMatrix();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(2, matrix.numColumns(1));
+ Assert.assertEquals(4, matrix.numColumns(2));
+ Assert.assertEquals(2, matrix.numColumns(3));
+ Assert.assertEquals(1, matrix.numColumns(4));
+ Assert.assertEquals(1, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyCSRMatrixFromLibSVM() {
+ Matrix matrix = csrMatrixFromLibSVM();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(2, matrix.numColumns(1));
+ Assert.assertEquals(4, matrix.numColumns(2));
+ Assert.assertEquals(2, matrix.numColumns(3));
+ Assert.assertEquals(1, matrix.numColumns(4));
+ Assert.assertEquals(1, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyCSRMatrixNoRow() {
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ Matrix matrix = builder.buildMatrix();
+ Assert.assertEquals(0, matrix.numRows());
+ Assert.assertEquals(0, matrix.numColumns());
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyCSRMatrixGetFail1() {
+ Matrix matrix = csrMatrix();
+ matrix.get(7, 5);
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyCSRMatrixGetFail2() {
+ Matrix matrix = csrMatrix();
+ matrix.get(6, 7);
+ }
+
+ @Test
+ public void testCSCMatrixFromLibSVM() {
+ CSCMatrix matrix = cscMatrixFromLibSVM();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(2, matrix.numColumns(1));
+ Assert.assertEquals(4, matrix.numColumns(2));
+ Assert.assertEquals(2, matrix.numColumns(3));
+ Assert.assertEquals(1, matrix.numColumns(4));
+ Assert.assertEquals(1, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testCSC2CSR() {
+ CSCMatrix csc = cscMatrixFromLibSVM();
+ RowMajorMatrix csr = csc.toRowMajorMatrix();
+ Assert.assertTrue(csr instanceof CSRMatrix);
+ Assert.assertEquals(6, csr.numRows());
+ Assert.assertEquals(6, csr.numColumns());
+ Assert.assertEquals(4, csr.numColumns(0));
+ Assert.assertEquals(2, csr.numColumns(1));
+ Assert.assertEquals(4, csr.numColumns(2));
+ Assert.assertEquals(2, csr.numColumns(3));
+ Assert.assertEquals(1, csr.numColumns(4));
+ Assert.assertEquals(1, csr.numColumns(5));
+
+ Assert.assertEquals(11d, csr.get(0, 0), 0.d);
+ Assert.assertEquals(12d, csr.get(0, 1), 0.d);
+ Assert.assertEquals(13d, csr.get(0, 2), 0.d);
+ Assert.assertEquals(14d, csr.get(0, 3), 0.d);
+ Assert.assertEquals(22d, csr.get(1, 1), 0.d);
+ Assert.assertEquals(23d, csr.get(1, 2), 0.d);
+ Assert.assertEquals(33d, csr.get(2, 2), 0.d);
+ Assert.assertEquals(34d, csr.get(2, 3), 0.d);
+ Assert.assertEquals(35d, csr.get(2, 4), 0.d);
+ Assert.assertEquals(36d, csr.get(2, 5), 0.d);
+ Assert.assertEquals(44d, csr.get(3, 3), 0.d);
+ Assert.assertEquals(45d, csr.get(3, 4), 0.d);
+ Assert.assertEquals(56d, csr.get(4, 5), 0.d);
+ Assert.assertEquals(66d, csr.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, csr.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, csr.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, csr.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testCSC2CSR2CSC() {
+ CSCMatrix csc = cscMatrixFromLibSVM();
+ CSCMatrix csc2 = csc.toRowMajorMatrix().toColumnMajorMatrix();
+ Assert.assertEquals(csc.nnz(), csc2.nnz());
+ Assert.assertEquals(6, csc2.numRows());
+ Assert.assertEquals(6, csc2.numColumns());
+ Assert.assertEquals(4, csc2.numColumns(0));
+ Assert.assertEquals(2, csc2.numColumns(1));
+ Assert.assertEquals(4, csc2.numColumns(2));
+ Assert.assertEquals(2, csc2.numColumns(3));
+ Assert.assertEquals(1, csc2.numColumns(4));
+ Assert.assertEquals(1, csc2.numColumns(5));
+
+ Assert.assertEquals(11d, csc2.get(0, 0), 0.d);
+ Assert.assertEquals(12d, csc2.get(0, 1), 0.d);
+ Assert.assertEquals(13d, csc2.get(0, 2), 0.d);
+ Assert.assertEquals(14d, csc2.get(0, 3), 0.d);
+ Assert.assertEquals(22d, csc2.get(1, 1), 0.d);
+ Assert.assertEquals(23d, csc2.get(1, 2), 0.d);
+ Assert.assertEquals(33d, csc2.get(2, 2), 0.d);
+ Assert.assertEquals(34d, csc2.get(2, 3), 0.d);
+ Assert.assertEquals(35d, csc2.get(2, 4), 0.d);
+ Assert.assertEquals(36d, csc2.get(2, 5), 0.d);
+ Assert.assertEquals(44d, csc2.get(3, 3), 0.d);
+ Assert.assertEquals(45d, csc2.get(3, 4), 0.d);
+ Assert.assertEquals(56d, csc2.get(4, 5), 0.d);
+ Assert.assertEquals(66d, csc2.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, csc2.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, csc2.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, csc2.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testDoKMatrixFromLibSVM() {
+ Matrix matrix = dokMatrixFromLibSVM();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(2, matrix.numColumns(1));
+ Assert.assertEquals(4, matrix.numColumns(2));
+ Assert.assertEquals(2, matrix.numColumns(3));
+ Assert.assertEquals(1, matrix.numColumns(4));
+ Assert.assertEquals(1, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+ Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
+
+ Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyDenseMatrix2d() {
+ Matrix matrix = rowMajorDenseMatrix();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(3, matrix.numColumns(1));
+ Assert.assertEquals(6, matrix.numColumns(2));
+ Assert.assertEquals(5, matrix.numColumns(3));
+ Assert.assertEquals(6, matrix.numColumns(4));
+ Assert.assertEquals(6, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyDenseMatrix2dSparseInput() {
+ Matrix matrix = denseMatrixSparseInput();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(3, matrix.numColumns(1));
+ Assert.assertEquals(6, matrix.numColumns(2));
+ Assert.assertEquals(5, matrix.numColumns(3));
+ Assert.assertEquals(6, matrix.numColumns(4));
+ Assert.assertEquals(6, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyDenseMatrix2dFromLibSVM() {
+ Matrix matrix = denseMatrixFromLibSVM();
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+ Assert.assertEquals(4, matrix.numColumns(0));
+ Assert.assertEquals(3, matrix.numColumns(1));
+ Assert.assertEquals(6, matrix.numColumns(2));
+ Assert.assertEquals(5, matrix.numColumns(3));
+ Assert.assertEquals(6, matrix.numColumns(4));
+ Assert.assertEquals(6, matrix.numColumns(5));
+
+ Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testReadOnlyDenseMatrix2dNoRow() {
+ Matrix matrix = new RowMajorDenseMatrixBuilder(1024).buildMatrix();
+ Assert.assertEquals(0, matrix.numRows());
+ Assert.assertEquals(0, matrix.numColumns());
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyDenseMatrix2dFailOutOfBound1() {
+ Matrix matrix = rowMajorDenseMatrix();
+ matrix.get(7, 5);
+ }
+
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testReadOnlyDenseMatrix2dFailOutOfBound2() {
+ Matrix matrix = rowMajorDenseMatrix();
+ matrix.get(6, 7);
+ }
+
+ @Test
+ public void testColumnMajorDenseMatrix2d() {
+ ColumnMajorDenseMatrix2d colMatrix = columnMajorDenseMatrix();
+
+ Assert.assertEquals(6, colMatrix.numRows());
+ Assert.assertEquals(6, colMatrix.numColumns());
+ Assert.assertEquals(4, colMatrix.numColumns(0));
+ Assert.assertEquals(2, colMatrix.numColumns(1));
+ Assert.assertEquals(4, colMatrix.numColumns(2));
+ Assert.assertEquals(2, colMatrix.numColumns(3));
+ Assert.assertEquals(1, colMatrix.numColumns(4));
+ Assert.assertEquals(1, colMatrix.numColumns(5));
+
+ Assert.assertEquals(11d, colMatrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, colMatrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, colMatrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, colMatrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, colMatrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, colMatrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, colMatrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, colMatrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, colMatrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, colMatrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, colMatrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, colMatrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, colMatrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, colMatrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, colMatrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, colMatrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testDenseMatrixColumnMajor2RowMajor() {
+ ColumnMajorDenseMatrix2d colMatrix = columnMajorDenseMatrix();
+ RowMajorDenseMatrix2d rowMatrix = colMatrix.toRowMajorMatrix();
+
+ Assert.assertEquals(6, rowMatrix.numRows());
+ Assert.assertEquals(6, rowMatrix.numColumns());
+ Assert.assertEquals(4, rowMatrix.numColumns(0));
+ Assert.assertEquals(3, rowMatrix.numColumns(1));
+ Assert.assertEquals(6, rowMatrix.numColumns(2));
+ Assert.assertEquals(5, rowMatrix.numColumns(3));
+ Assert.assertEquals(6, rowMatrix.numColumns(4));
+ Assert.assertEquals(6, rowMatrix.numColumns(5));
+
+ Assert.assertEquals(11d, rowMatrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, rowMatrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, rowMatrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, rowMatrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, rowMatrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, rowMatrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, rowMatrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, rowMatrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, rowMatrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, rowMatrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, rowMatrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, rowMatrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, rowMatrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, rowMatrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, rowMatrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, rowMatrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, rowMatrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, rowMatrix.get(1, 0), 0.d);
+
+ // convert back to column major matrix
+
+ colMatrix = rowMatrix.toColumnMajorMatrix();
+
+ Assert.assertEquals(6, colMatrix.numRows());
+ Assert.assertEquals(6, colMatrix.numColumns());
+ Assert.assertEquals(4, colMatrix.numColumns(0));
+ Assert.assertEquals(2, colMatrix.numColumns(1));
+ Assert.assertEquals(4, colMatrix.numColumns(2));
+ Assert.assertEquals(2, colMatrix.numColumns(3));
+ Assert.assertEquals(1, colMatrix.numColumns(4));
+ Assert.assertEquals(1, colMatrix.numColumns(5));
+
+ Assert.assertEquals(11d, colMatrix.get(0, 0), 0.d);
+ Assert.assertEquals(12d, colMatrix.get(0, 1), 0.d);
+ Assert.assertEquals(13d, colMatrix.get(0, 2), 0.d);
+ Assert.assertEquals(14d, colMatrix.get(0, 3), 0.d);
+ Assert.assertEquals(22d, colMatrix.get(1, 1), 0.d);
+ Assert.assertEquals(23d, colMatrix.get(1, 2), 0.d);
+ Assert.assertEquals(33d, colMatrix.get(2, 2), 0.d);
+ Assert.assertEquals(34d, colMatrix.get(2, 3), 0.d);
+ Assert.assertEquals(35d, colMatrix.get(2, 4), 0.d);
+ Assert.assertEquals(36d, colMatrix.get(2, 5), 0.d);
+ Assert.assertEquals(44d, colMatrix.get(3, 3), 0.d);
+ Assert.assertEquals(45d, colMatrix.get(3, 4), 0.d);
+ Assert.assertEquals(56d, colMatrix.get(4, 5), 0.d);
+ Assert.assertEquals(66d, colMatrix.get(5, 5), 0.d);
+
+ Assert.assertEquals(0.d, colMatrix.get(5, 4), 0.d);
+
+ Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d);
+ Assert.assertEquals(0.d, colMatrix.get(1, 3), 0.d);
+ Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d);
+ }
+
+ @Test
+ public void testCSRMatrixNullRow() {
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
+ builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
+ builder.nextRow();
+ builder.nextColumn(3, 66).nextRow();
+ Matrix matrix = builder.buildMatrix();
+ Assert.assertEquals(4, matrix.numRows());
+ }
+
+ private static CSRMatrix csrMatrix() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
+ builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
+ builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow();
+ builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
+ builder.nextColumn(5, 56).nextRow();
+ builder.nextColumn(5, 66).nextRow();
+ return builder.buildMatrix();
+ }
+
+ private static CSRMatrix csrMatrixFromLibSVM() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
+ builder.nextRow(new String[] {"1:22", "2:23"});
+ builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
+ builder.nextRow(new String[] {"3:44", "4:45"});
+ builder.nextRow(new String[] {"5:56"});
+ builder.nextRow(new String[] {"5:66"});
+ return builder.buildMatrix();
+ }
+
+ private static CSCMatrix cscMatrixFromLibSVM() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ CSCMatrixBuilder builder = new CSCMatrixBuilder(1024);
+ builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
+ builder.nextRow(new String[] {"1:22", "2:23"});
+ builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
+ builder.nextRow(new String[] {"3:44", "4:45"});
+ builder.nextRow(new String[] {"5:56"});
+ builder.nextRow(new String[] {"5:66"});
+ return builder.buildMatrix();
+ }
+
+ private static DoKMatrix dokMatrixFromLibSVM() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ DoKMatrixBuilder builder = new DoKMatrixBuilder(1024);
+ builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
+ builder.nextRow(new String[] {"1:22", "2:23"});
+ builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
+ builder.nextRow(new String[] {"3:44", "4:45"});
+ builder.nextRow(new String[] {"5:56"});
+ builder.nextRow(new String[] {"5:66"});
+ return builder.buildMatrix();
+ }
+
+ private static RowMajorDenseMatrix2d rowMajorDenseMatrix() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024);
+ builder.nextRow(new double[] {11, 12, 13, 14});
+ builder.nextRow(new double[] {0, 22, 23});
+ builder.nextRow(new double[] {0, 0, 33, 34, 35, 36});
+ builder.nextRow(new double[] {0, 0, 0, 44, 45});
+ builder.nextRow(new double[] {0, 0, 0, 0, 0, 56});
+ builder.nextRow(new double[] {0, 0, 0, 0, 0, 66});
+ return builder.buildMatrix();
+ }
+
+ private static ColumnMajorDenseMatrix2d columnMajorDenseMatrix() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ ColumnMajorDenseMatrixBuilder builder = new ColumnMajorDenseMatrixBuilder(1024);
+ builder.nextRow(new double[] {11, 12, 13, 14});
+ builder.nextRow(new double[] {0, 22, 23});
+ builder.nextRow(new double[] {0, 0, 33, 34, 35, 36});
+ builder.nextRow(new double[] {0, 0, 0, 44, 45});
+ builder.nextRow(new double[] {0, 0, 0, 0, 0, 56});
+ builder.nextRow(new double[] {0, 0, 0, 0, 0, 66});
+ return builder.buildMatrix();
+ }
+
+ private static RowMajorDenseMatrix2d denseMatrixSparseInput() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024);
+ builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
+ builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
+ builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow();
+ builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
+ builder.nextColumn(5, 56).nextRow();
+ builder.nextColumn(5, 66).nextRow();
+ return builder.buildMatrix();
+ }
+
+ private static RowMajorDenseMatrix2d denseMatrixFromLibSVM() {
+ RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024);
+ builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
+ builder.nextRow(new String[] {"1:22", "2:23"});
+ builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
+ builder.nextRow(new String[] {"3:44", "4:45"});
+ builder.nextRow(new String[] {"5:56"});
+ builder.nextRow(new String[] {"5:66"});
+ return builder.buildMatrix();
+ }
+
+}
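
A minimal sketch of the builder API exercised by the tests above (shape and
values illustrative), using the LibSVM-style "index:value" row format that this
change adds as input to the matrix builders:

    CSRMatrixBuilder builder = new CSRMatrixBuilder(64);
    builder.nextRow(new String[] {"0:1", "2:3"}); // row 0: columns 0 and 2
    builder.nextRow(new String[] {"1:2"});        // row 1: column 1
    Matrix matrix = builder.buildMatrix();
    double v = matrix.get(0, 2);       // 3.0
    double w = matrix.get(1, 0, -1.d); // -1.0, the caller-supplied default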
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java b/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java
new file mode 100644
index 0000000..f6a52fe
--- /dev/null
+++ b/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.matrix.ints.ColumnMajorDenseIntMatrix2d;
+import hivemall.math.matrix.ints.DoKIntMatrix;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class IntMatrixTest {
+
+ @Test
+ public void testDoKMatrixRowMajor() {
+ DoKIntMatrix matrix = DoKIntMatrix.build(rowMajorData(), true, true);
+
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+
+ Assert.assertEquals(11, matrix.get(0, 0));
+ Assert.assertEquals(12, matrix.get(0, 1));
+ Assert.assertEquals(13, matrix.get(0, 2));
+ Assert.assertEquals(14, matrix.get(0, 3));
+ Assert.assertEquals(22, matrix.get(1, 1));
+ Assert.assertEquals(23, matrix.get(1, 2));
+ Assert.assertEquals(33, matrix.get(2, 2));
+ Assert.assertEquals(34, matrix.get(2, 3));
+ Assert.assertEquals(35, matrix.get(2, 4));
+ Assert.assertEquals(36, matrix.get(2, 5));
+ Assert.assertEquals(44, matrix.get(3, 3));
+ Assert.assertEquals(45, matrix.get(3, 4));
+ Assert.assertEquals(56, matrix.get(4, 5));
+ Assert.assertEquals(66, matrix.get(5, 5));
+
+ Assert.assertEquals(0, matrix.get(5, 4));
+ Assert.assertEquals(0, matrix.get(1, 0));
+ Assert.assertEquals(0, matrix.get(1, 3));
+ Assert.assertEquals(-1, matrix.get(1, 0, -1));
+ }
+
+ @Test
+ public void testDoKMatrixColumnMajor() {
+ DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, true);
+
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+
+ Assert.assertEquals(11, matrix.get(0, 0));
+ Assert.assertEquals(12, matrix.get(0, 1));
+ Assert.assertEquals(13, matrix.get(0, 2));
+ Assert.assertEquals(14, matrix.get(0, 3));
+ Assert.assertEquals(22, matrix.get(1, 1));
+ Assert.assertEquals(23, matrix.get(1, 2));
+ Assert.assertEquals(33, matrix.get(2, 2));
+ Assert.assertEquals(34, matrix.get(2, 3));
+ Assert.assertEquals(35, matrix.get(2, 4));
+ Assert.assertEquals(36, matrix.get(2, 5));
+ Assert.assertEquals(44, matrix.get(3, 3));
+ Assert.assertEquals(45, matrix.get(3, 4));
+ Assert.assertEquals(56, matrix.get(4, 5));
+ Assert.assertEquals(66, matrix.get(5, 5));
+
+ Assert.assertEquals(0, matrix.get(5, 4));
+ Assert.assertEquals(0, matrix.get(1, 0));
+ Assert.assertEquals(0, matrix.get(1, 3));
+ Assert.assertEquals(-1, matrix.get(1, 0, -1));
+ }
+
+ @Test
+ public void testDoKMatrixColumnMajorNonZeroOnlyFalse() {
+ DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, false);
+
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+
+ Assert.assertEquals(0, matrix.get(5, 4));
+ Assert.assertEquals(0, matrix.get(1, 0));
+ Assert.assertEquals(0, matrix.get(1, 3));
+ Assert.assertEquals(0, matrix.get(1, 3, -1));
+ Assert.assertEquals(-1, matrix.get(1, 0, -1));
+
+ matrix.setDefaultValue(-1);
+ Assert.assertEquals(-1, matrix.get(5, 4));
+ Assert.assertEquals(-1, matrix.get(1, 0));
+ Assert.assertEquals(0, matrix.get(1, 3));
+ Assert.assertEquals(0, matrix.get(1, 0, 0));
+ }
+
+ @Test
+ public void testColumnMajorDenseMatrix() {
+ ColumnMajorDenseIntMatrix2d matrix = new ColumnMajorDenseIntMatrix2d(columnMajorData(), 6);
+ Assert.assertEquals(6, matrix.numRows());
+ Assert.assertEquals(6, matrix.numColumns());
+
+ Assert.assertEquals(11, matrix.get(0, 0));
+ Assert.assertEquals(12, matrix.get(0, 1));
+ Assert.assertEquals(13, matrix.get(0, 2));
+ Assert.assertEquals(14, matrix.get(0, 3));
+ Assert.assertEquals(22, matrix.get(1, 1));
+ Assert.assertEquals(23, matrix.get(1, 2));
+ Assert.assertEquals(33, matrix.get(2, 2));
+ Assert.assertEquals(34, matrix.get(2, 3));
+ Assert.assertEquals(35, matrix.get(2, 4));
+ Assert.assertEquals(36, matrix.get(2, 5));
+ Assert.assertEquals(44, matrix.get(3, 3));
+ Assert.assertEquals(45, matrix.get(3, 4));
+ Assert.assertEquals(56, matrix.get(4, 5));
+ Assert.assertEquals(66, matrix.get(5, 5));
+
+ Assert.assertEquals(0, matrix.get(5, 4));
+ Assert.assertEquals(0, matrix.get(1, 0));
+ Assert.assertEquals(0, matrix.get(1, 3));
+ Assert.assertEquals(-1, matrix.get(1, 0, -1));
+ }
+
+ @Test
+ public void testColumnMajorDenseMatrixEachColumn() {
+ ColumnMajorDenseIntMatrix2d matrix = new ColumnMajorDenseIntMatrix2d(columnMajorData(), 6);
+ matrix.setDefaultValue(-1);
+
+ final MutableInt count = new MutableInt(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, false);
+ }
+ Assert.assertEquals(1 + 2 + 3 + 4 + 4 + 6, count.getValue());
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, true);
+ }
+ Assert.assertEquals(6 * 6, count.getValue());
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachNonZeroInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ });
+ }
+ Assert.assertEquals(1 + 2 + 3 + 3 + 2 + 3, count.getValue());
+
+ // change default value to zero
+ matrix.setDefaultValue(0);
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, false);
+ }
+ Assert.assertEquals(1 + 2 + 3 + 4 + 4 + 6, count.getValue());
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, true);
+ }
+ Assert.assertEquals(6 * 6, count.getValue());
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachNonZeroInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ });
+ }
+ Assert.assertEquals(1 + 2 + 3 + 3 + 2 + 3, count.getValue());
+ }
+
+ @Test
+ public void testDoKMatrixColumnMajorNonZeroOnlyFalseEachColumn() {
+ DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, false);
+ matrix.setDefaultValue(-1);
+
+ final MutableInt count = new MutableInt(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, false);
+ }
+ Assert.assertEquals(1 + 2 + 3 + 4 + 4 + 6, count.getValue());
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, true);
+ }
+ Assert.assertEquals(6 * 6, count.getValue());
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachNonZeroInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ });
+ }
+ Assert.assertEquals(1 + 2 + 3 + 3 + 2 + 3, count.getValue());
+
+ // change default value to zero
+ matrix.setDefaultValue(0);
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, false);
+ }
+ Assert.assertEquals(1 + 2 + 3 + 4 + 4 + 6, count.getValue());
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, true);
+ }
+ Assert.assertEquals(6 * 6, count.getValue());
+
+ count.setValue(0);
+ for (int j = 0; j < 6; j++) {
+ matrix.eachNonZeroInColumn(j, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ });
+ }
+ Assert.assertEquals(1 + 2 + 3 + 3 + 2 + 3, count.getValue());
+ }
+
+ @Test
+ public void testDoKMatrixRowMajorNonZeroOnlyFalseEachColumn() {
+ DoKIntMatrix matrix = DoKIntMatrix.build(rowMajorData(), true, false);
+ matrix.setDefaultValue(-1);
+
+ final MutableInt count = new MutableInt(0);
+ for (int i = 0; i < 6; i++) {
+ matrix.eachInRow(i, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, false);
+ }
+ Assert.assertEquals(4 + 3 + 6 + 5 + 6 + 6, count.getValue());
+
+ count.setValue(0);
+ for (int i = 0; i < 6; i++) {
+ matrix.eachInRow(i, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ }, true);
+ }
+ Assert.assertEquals(6 * 6, count.getValue());
+
+ count.setValue(0);
+ for (int i = 0; i < 6; i++) {
+ matrix.eachNonZeroInRow(i, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ count.addValue(1);
+ }
+ });
+ }
+ Assert.assertEquals(4 + 2 + 4 + 2 + 1 + 1, count.getValue());
+ }
+
+ private static int[][] rowMajorData() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ int[][] data = new int[6][];
+ data[0] = new int[] {11, 12, 13, 14};
+ data[1] = new int[] {0, 22, 23};
+ data[2] = new int[] {0, 0, 33, 34, 35, 36};
+ data[3] = new int[] {0, 0, 0, 44, 45};
+ data[4] = new int[] {0, 0, 0, 0, 0, 56};
+ data[5] = new int[] {0, 0, 0, 0, 0, 66};
+ return data;
+ }
+
+ private static int[][] columnMajorData() {
+ /*
+ 11 12 13 14 0 0
+ 0 22 23 0 0 0
+ 0 0 33 34 35 36
+ 0 0 0 44 45 0
+ 0 0 0 0 0 56
+ 0 0 0 0 0 66
+ */
+ int[][] data = new int[6][];
+ data[0] = new int[] {11};
+ data[1] = new int[] {12, 22};
+ data[2] = new int[] {13, 23, 33};
+ data[3] = new int[] {14, 0, 34, 44};
+ data[4] = new int[] {0, 0, 35, 45};
+ data[5] = new int[] {0, 0, 36, 0, 56, 66};
+ return data;
+ }
+
+}
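
A small sketch of the procedure-based traversal used in the tests above (matrix
contents illustrative; same imports as the test class):

    final MutableInt sum = new MutableInt(0);
    DoKIntMatrix matrix = DoKIntMatrix.build(new int[][] { {11, 12}, {0, 22} }, true, true);
    matrix.eachNonZeroInRow(0, new VectorProcedure() {
        @Override
        public void apply(int j, int value) {
            sum.addValue(value); // visits only the non-zero cells 11 and 12
        }
    });
    // sum.getValue() == 23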
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java
deleted file mode 100644
index 5545631..0000000
--- a/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java
+++ /dev/null
@@ -1,329 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class MatrixBuilderTest {
-
- @Test
- public void testReadOnlyCSRMatrix() {
- Matrix matrix = csrMatrix();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(2, matrix.numColumns(1));
- Assert.assertEquals(4, matrix.numColumns(2));
- Assert.assertEquals(2, matrix.numColumns(3));
- Assert.assertEquals(1, matrix.numColumns(4));
- Assert.assertEquals(1, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
- Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
-
- matrix.setDefaultValue(Double.NaN);
- Assert.assertEquals(Double.NaN, matrix.get(5, 4), 0.d);
- }
-
- @Test
- public void testReadOnlyCSRMatrixFromLibSVM() {
- Matrix matrix = csrMatrixFromLibSVM();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(2, matrix.numColumns(1));
- Assert.assertEquals(4, matrix.numColumns(2));
- Assert.assertEquals(2, matrix.numColumns(3));
- Assert.assertEquals(1, matrix.numColumns(4));
- Assert.assertEquals(1, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
- Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d);
-
- matrix.setDefaultValue(Double.NaN);
- Assert.assertEquals(Double.NaN, matrix.get(5, 4), 0.d);
- }
-
-
- @Test
- public void testReadOnlyCSRMatrixNoRow() {
- CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
- Matrix matrix = builder.buildMatrix(true);
- Assert.assertEquals(0, matrix.numRows());
- Assert.assertEquals(0, matrix.numColumns());
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testReadOnlyCSRMatrixGetFail1() {
- Matrix matrix = csrMatrix();
- matrix.get(7, 5);
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testReadOnlyCSRMatrixGetFail2() {
- Matrix matrix = csrMatrix();
- matrix.get(6, 7);
- }
-
- @Test
- public void testReadOnlyDenseMatrix2d() {
- Matrix matrix = denseMatrix();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(3, matrix.numColumns(1));
- Assert.assertEquals(6, matrix.numColumns(2));
- Assert.assertEquals(5, matrix.numColumns(3));
- Assert.assertEquals(6, matrix.numColumns(4));
- Assert.assertEquals(6, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- }
-
- @Test
- public void testReadOnlyDenseMatrix2dSparseInput() {
- Matrix matrix = denseMatrixSparseInput();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(3, matrix.numColumns(1));
- Assert.assertEquals(6, matrix.numColumns(2));
- Assert.assertEquals(5, matrix.numColumns(3));
- Assert.assertEquals(6, matrix.numColumns(4));
- Assert.assertEquals(6, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- }
-
- @Test
- public void testReadOnlyDenseMatrix2dFromLibSVM() {
- Matrix matrix = denseMatrixFromLibSVM();
- Assert.assertEquals(6, matrix.numRows());
- Assert.assertEquals(6, matrix.numColumns());
- Assert.assertEquals(4, matrix.numColumns(0));
- Assert.assertEquals(3, matrix.numColumns(1));
- Assert.assertEquals(6, matrix.numColumns(2));
- Assert.assertEquals(5, matrix.numColumns(3));
- Assert.assertEquals(6, matrix.numColumns(4));
- Assert.assertEquals(6, matrix.numColumns(5));
-
- Assert.assertEquals(11d, matrix.get(0, 0), 0.d);
- Assert.assertEquals(12d, matrix.get(0, 1), 0.d);
- Assert.assertEquals(13d, matrix.get(0, 2), 0.d);
- Assert.assertEquals(14d, matrix.get(0, 3), 0.d);
- Assert.assertEquals(22d, matrix.get(1, 1), 0.d);
- Assert.assertEquals(23d, matrix.get(1, 2), 0.d);
- Assert.assertEquals(33d, matrix.get(2, 2), 0.d);
- Assert.assertEquals(34d, matrix.get(2, 3), 0.d);
- Assert.assertEquals(35d, matrix.get(2, 4), 0.d);
- Assert.assertEquals(36d, matrix.get(2, 5), 0.d);
- Assert.assertEquals(44d, matrix.get(3, 3), 0.d);
- Assert.assertEquals(45d, matrix.get(3, 4), 0.d);
- Assert.assertEquals(56d, matrix.get(4, 5), 0.d);
- Assert.assertEquals(66d, matrix.get(5, 5), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(5, 4), 0.d);
-
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 3), 0.d);
- Assert.assertEquals(0.d, matrix.get(1, 0), 0.d);
- }
-
- @Test
- public void testReadOnlyDenseMatrix2dNoRow() {
- Matrix matrix = new DenseMatrixBuilder(1024).buildMatrix(true);
- Assert.assertEquals(0, matrix.numRows());
- Assert.assertEquals(0, matrix.numColumns());
- }
-
- @Test(expected = UnsupportedOperationException.class)
- public void testReadOnlyDenseMatrix2dFailToChangeDefaultValue() {
- Matrix matrix = denseMatrix();
- matrix.setDefaultValue(Double.NaN);
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testReadOnlyDenseMatrix2dFailOutOfBound1() {
- Matrix matrix = denseMatrix();
- matrix.get(7, 5);
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testReadOnlyDenseMatrix2dFailOutOfBound2() {
- Matrix matrix = denseMatrix();
- matrix.get(6, 7);
- }
-
- private static Matrix csrMatrix() {
- /*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
- */
- CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
- builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
- builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
- builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow();
- builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
- builder.nextColumn(5, 56).nextRow();
- builder.nextColumn(5, 66).nextRow();
- return builder.buildMatrix(true);
- }
-
- private static Matrix csrMatrixFromLibSVM() {
- /*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
- */
- CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
- builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
- builder.nextRow(new String[] {"1:22", "2:23"});
- builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
- builder.nextRow(new String[] {"3:44", "4:45"});
- builder.nextRow(new String[] {"5:56"});
- builder.nextRow(new String[] {"5:66"});
- return builder.buildMatrix(true);
- }
-
- private static Matrix denseMatrix() {
- /*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
- */
- DenseMatrixBuilder builder = new DenseMatrixBuilder(1024);
- builder.nextRow(new double[] {11, 12, 13, 14});
- builder.nextRow(new double[] {0, 22, 23});
- builder.nextRow(new double[] {0, 0, 33, 34, 35, 36});
- builder.nextRow(new double[] {0, 0, 0, 44, 45});
- builder.nextRow(new double[] {0, 0, 0, 0, 0, 56});
- builder.nextRow(new double[] {0, 0, 0, 0, 0, 66});
- return builder.buildMatrix(true);
- }
-
- private static Matrix denseMatrixSparseInput() {
- /*
- 11 12 13 14 0 0
- 0 22 23 0 0 0
- 0 0 33 34 35 36
- 0 0 0 44 45 0
- 0 0 0 0 0 56
- 0 0 0 0 0 66
- */
- DenseMatrixBuilder builder = new DenseMatrixBuilder(1024);
- builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow();
- builder.nextColumn(1, 22).nextColumn(2, 23).nextRow();
- builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow();
- builder.nextColumn(3, 44).nextColumn(4, 45).nextRow();
- builder.nextColumn(5, 56).nextRow();
- builder.nextColumn(5, 66).nextRow();
- return builder.buildMatrix(true);
- }
-
- private static Matrix denseMatrixFromLibSVM() {
- DenseMatrixBuilder builder = new DenseMatrixBuilder(1024);
- builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"});
- builder.nextRow(new String[] {"1:22", "2:23"});
- builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"});
- builder.nextRow(new String[] {"3:44", "4:45"});
- builder.nextRow(new String[] {"5:56"});
- builder.nextRow(new String[] {"5:66"});
- return builder.buildMatrix(true);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
index 3c6116c..bb6de6b 100644
--- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
@@ -19,13 +19,13 @@
package hivemall.smile.classification;
import static org.junit.Assert.assertEquals;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
+import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.smile.classification.DecisionTree.Node;
import hivemall.smile.data.Attribute;
-import hivemall.smile.tools.TreePredictUDF;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.smile.vm.StackMachine;
-import hivemall.utils.lang.ArrayUtils;
import java.io.BufferedInputStream;
import java.io.IOException;
@@ -33,14 +33,9 @@ import java.io.InputStream;
import java.net.URL;
import java.text.ParseException;
+import javax.annotation.Nonnull;
+
import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.IntWritable;
import org.junit.Assert;
import org.junit.Test;
@@ -52,85 +47,76 @@ import smile.validation.LOOCV;
public class DecisionTreeTest {
private static final boolean DEBUG = false;
- /**
- * Test of learn method, of class DecisionTree.
- *
- * @throws ParseException
- * @throws IOException
- */
@Test
public void testWeather() throws IOException, ParseException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff");
- InputStream is = new BufferedInputStream(url.openStream());
-
- ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
-
- AttributeDataset weather = arffParser.parse(is);
- double[][] x = weather.toArray(new double[weather.size()][]);
- int[] y = weather.toArray(new int[weather.size()]);
-
- int n = x.length;
- LOOCV loocv = new LOOCV(n);
- int error = 0;
- for (int i = 0; i < n; i++) {
- double[][] trainx = Math.slice(x, loocv.train[i]);
- int[] trainy = Math.slice(y, loocv.train[i]);
+ int responseIndex = 4;
+ int numLeafs = 3;
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(weather.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 3);
- if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]]))
- error++;
- }
+ // dense matrix
+ int error = run(
+ "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff",
+ responseIndex, numLeafs, true);
+ assertEquals(5, error);
- debugPrint("Decision Tree error = " + error);
+ // sparse matrix
+ error = run(
+ "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff",
+ responseIndex, numLeafs, false);
assertEquals(5, error);
}
@Test
public void testIris() throws IOException, ParseException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
- InputStream is = new BufferedInputStream(url.openStream());
-
- ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
-
- AttributeDataset iris = arffParser.parse(is);
- double[][] x = iris.toArray(new double[iris.size()][]);
- int[] y = iris.toArray(new int[iris.size()]);
-
- int n = x.length;
- LOOCV loocv = new LOOCV(n);
- int error = 0;
- for (int i = 0; i < n; i++) {
- double[][] trainx = Math.slice(x, loocv.train[i]);
- int[] trainy = Math.slice(y, loocv.train[i]);
-
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- smile.math.Random rand = new smile.math.Random(i);
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, Integer.MAX_VALUE, rand);
- if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]]))
- error++;
- }
+ int responseIndex = 4;
+ int numLeafs = Integer.MAX_VALUE;
+ int error = run(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs, true);
+ assertEquals(8, error);
- debugPrint("Decision Tree error = " + error);
+ // sparse
+ error = run(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs, false);
assertEquals(8, error);
}
@Test
+ public void testIrisSparseDenseEquals() throws IOException, ParseException {
+ int responseIndex = 4;
+ int numLeafs = Integer.MAX_VALUE;
+ runAndCompareSparseAndDense(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs);
+ }
+
+ @Test
public void testIrisDepth4() throws IOException, ParseException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+ int responseIndex = 4;
+ int numLeafs = 4;
+ int error = run(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs, true);
+ assertEquals(7, error);
+
+ // sparse
+ error = run(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs, false);
+ assertEquals(7, error);
+ }
+
+ private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense)
+ throws IOException, ParseException {
+ URL url = new URL(datasetUrl);
InputStream is = new BufferedInputStream(url.openStream());
ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
+ arffParser.setResponseIndex(responseIndex);
- AttributeDataset iris = arffParser.parse(is);
- double[][] x = iris.toArray(new double[iris.size()][]);
- int[] y = iris.toArray(new int[iris.size()]);
+ AttributeDataset ds = arffParser.parse(is);
+ double[][] x = ds.toArray(new double[ds.size()][]);
+ int[] y = ds.toArray(new int[ds.size()]);
int n = x.length;
LOOCV loocv = new LOOCV(n);
@@ -139,52 +125,29 @@ public class DecisionTreeTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
- if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]]))
+ Attribute[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
+ DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), trainy, numLeafs,
+ RandomNumberGeneratorFactory.createPRNG(i));
+ if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) {
error++;
+ }
}
debugPrint("Decision Tree error = " + error);
- assertEquals(7, error);
+ return error;
}
- @Test
- public void testIrisStackmachine() throws IOException, ParseException, HiveException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+ private static void runAndCompareSparseAndDense(String datasetUrl, int responseIndex,
+ int numLeafs) throws IOException, ParseException {
+ URL url = new URL(datasetUrl);
InputStream is = new BufferedInputStream(url.openStream());
ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
- AttributeDataset iris = arffParser.parse(is);
- double[][] x = iris.toArray(new double[iris.size()][]);
- int[] y = iris.toArray(new int[iris.size()]);
-
- int n = x.length;
- LOOCV loocv = new LOOCV(n);
- for (int i = 0; i < n; i++) {
- double[][] trainx = Math.slice(x, loocv.train[i]);
- int[] trainy = Math.slice(y, loocv.train[i]);
-
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
- assertEquals(tree.predict(x[loocv.test[i]]),
- predictByStackMachine(tree, x[loocv.test[i]]));
- }
- }
-
- @Test
- public void testIrisJavascript() throws IOException, ParseException, HiveException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
- InputStream is = new BufferedInputStream(url.openStream());
+ arffParser.setResponseIndex(responseIndex);
- ArffParser arffParser = new ArffParser();
- arffParser.setResponseIndex(4);
- AttributeDataset iris = arffParser.parse(is);
- double[][] x = iris.toArray(new double[iris.size()][]);
- int[] y = iris.toArray(new int[iris.size()]);
+ AttributeDataset ds = arffParser.parse(is);
+ double[][] x = ds.toArray(new double[ds.size()][]);
+ int[] y = ds.toArray(new int[ds.size()]);
int n = x.length;
LOOCV loocv = new LOOCV(n);
@@ -192,10 +155,12 @@ public class DecisionTreeTest {
double[][] trainx = Math.slice(x, loocv.train[i]);
int[] trainy = Math.slice(y, loocv.train[i]);
- Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
- assertEquals(tree.predict(x[loocv.test[i]]),
- predictByJavascript(tree, x[loocv.test[i]]));
+ Attribute[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
+ DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), trainy, numLeafs,
+ RandomNumberGeneratorFactory.createPRNG(i));
+ DecisionTree stree = new DecisionTree(attrs, matrix(trainx, false), trainy, numLeafs,
+ RandomNumberGeneratorFactory.createPRNG(i));
+ Assert.assertEquals(dtree.predict(x[loocv.test[i]]), stree.predict(x[loocv.test[i]]));
}
}
@@ -218,7 +183,7 @@ public class DecisionTreeTest {
int[] trainy = Math.slice(y, loocv.train[i]);
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
+ DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);
byte[] b = tree.predictSerCodegen(false);
Node node = DecisionTree.deserializeNode(b, b.length, false);
@@ -245,7 +210,7 @@ public class DecisionTreeTest {
int[] trainy = Math.slice(y, loocv.train[i]);
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
+ DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);
byte[] b1 = tree.predictSerCodegen(true);
byte[] b2 = tree.predictSerCodegen(false);
@@ -256,52 +221,18 @@ public class DecisionTreeTest {
}
}
- private static int predictByStackMachine(DecisionTree tree, double[] x) throws HiveException,
- IOException {
- String script = tree.predictOpCodegen(StackMachine.SEP);
- debugPrint(script);
-
- TreePredictUDF udf = new TreePredictUDF();
- udf.initialize(new ObjectInspector[] {
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- PrimitiveObjectInspectorFactory.javaIntObjectInspector,
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
- ObjectInspectorUtils.getConstantObjectInspector(
- PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)});
- DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"),
- new DeferredJavaObject(ModelType.opscode.getId()), new DeferredJavaObject(script),
- new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(true)};
-
- IntWritable result = (IntWritable) udf.evaluate(arguments);
- result = (IntWritable) udf.evaluate(arguments);
- udf.close();
- return result.get();
- }
-
- private static int predictByJavascript(DecisionTree tree, double[] x) throws HiveException,
- IOException {
- String script = tree.predictJsCodegen();
- debugPrint(script);
-
- TreePredictUDF udf = new TreePredictUDF();
- udf.initialize(new ObjectInspector[] {
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- PrimitiveObjectInspectorFactory.javaIntObjectInspector,
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
- ObjectInspectorUtils.getConstantObjectInspector(
- PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)});
-
- DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"),
- new DeferredJavaObject(ModelType.javascript.getId()),
- new DeferredJavaObject(script), new DeferredJavaObject(ArrayUtils.toList(x)),
- new DeferredJavaObject(true)};
-
- IntWritable result = (IntWritable) udf.evaluate(arguments);
- result = (IntWritable) udf.evaluate(arguments);
- udf.close();
- return result.get();
+ @Nonnull
+ private static Matrix matrix(@Nonnull final double[][] x, boolean dense) {
+ if (dense) {
+ return new RowMajorDenseMatrix2d(x, x[0].length);
+ } else {
+ int numRows = x.length;
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ for (int i = 0; i < numRows; i++) {
+ builder.nextRow(x[i]);
+ }
+ return builder.buildMatrix();
+ }
}
private static void debugPrint(String msg) {
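The matrix(double[][], boolean) helper above is the pivot of this refactoring:
a single LOOCV harness now drives both backing representations. The same
pattern in isolation, as a sketch assuming only the RowMajorDenseMatrix2d
constructor and CSRMatrixBuilder API shown in this diff (the data values are
illustrative):

    // Build dense and sparse views of the same data, as matrix() does above.
    double[][] x = { {5.1, 3.5}, {4.9, 3.0}, {6.2, 3.4} };
    Matrix dense = new RowMajorDenseMatrix2d(x, x[0].length);

    CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
    for (double[] row : x) {
        builder.nextRow(row);
    }
    Matrix sparse = builder.buildMatrix();

    // Trees trained on dense and sparse inputs with the same PRNG seed are
    // expected to agree prediction-for-prediction, which is what
    // runAndCompareSparseAndDense() asserts.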
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java b/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
deleted file mode 100644
index 1c7a9a1..0000000
--- a/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import hivemall.utils.lang.Preconditions;
-
-import java.util.Arrays;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-/**
- * Read-only CSR Matrix.
- *
- * @see http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
- */
-public final class ReadOnlyCSRMatrix extends Matrix {
-
- @Nonnull
- private final int[] rowPointers;
- @Nonnull
- private final int[] columnIndices;
- @Nonnull
- private final double[] values;
-
- @Nonnegative
- private final int numRows;
- @Nonnegative
- private final int numColumns;
-
- public ReadOnlyCSRMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices,
- @Nonnull double[] values, @Nonnegative int numColumns) {
- super();
- Preconditions.checkArgument(rowPointers.length >= 1,
- "rowPointers must be greather than 0: " + rowPointers.length);
- Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices ("
- + columnIndices.length + ") must be equal to #values (" + values.length + ")");
- this.rowPointers = rowPointers;
- this.columnIndices = columnIndices;
- this.values = values;
- this.numRows = rowPointers.length - 1;
- this.numColumns = numColumns;
- }
-
- @Override
- public boolean readOnly() {
- return true;
- }
-
- @Override
- public int numRows() {
- return numRows;
- }
-
- @Override
- public int numColumns() {
- return numColumns;
- }
-
- @Override
- public int numColumns(@Nonnegative final int row) {
- checkRowIndex(row, numRows);
-
- int columns = rowPointers[row + 1] - rowPointers[row];
- return columns;
- }
-
- @Override
- public double get(@Nonnegative final int row, @Nonnegative final int col,
- final double defaultValue) {
- checkIndex(row, col, numRows, numColumns);
-
- final int index = getIndex(row, col);
- if (index < 0) {
- return defaultValue;
- }
- return values[index];
- }
-
- @Override
- public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
- final double value) {
- checkIndex(row, col, numRows, numColumns);
-
- final int index = getIndex(row, col);
- if (index < 0) {
- throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
- + col);
- }
-
- double old = values[index];
- values[index] = value;
- return old;
- }
-
- @Override
- public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
- checkIndex(row, col, numRows, numColumns);
-
- final int index = getIndex(row, col);
- if (index < 0) {
- throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
- + col);
- }
- values[index] = value;
- }
-
- private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
- int leftIn = rowPointers[row];
- int rightEx = rowPointers[row + 1];
- final int index = Arrays.binarySearch(columnIndices, leftIn, rightEx, col);
- if (index >= 0 && index >= values.length) {
- throw new IndexOutOfBoundsException("Value index " + index + " out of range "
- + values.length);
- }
- return index;
- }
-
-}
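The class removed here (its successor lives under hivemall.math.matrix, per
the new imports in DecisionTree.java below) resolves get(row, col) by
binary-searching columnIndices between rowPointers[row] and
rowPointers[row + 1]. A worked example using the 6x6 matrix from the builder
tests; the array contents are derived by hand, not taken from the patch:

    // CSR encoding of
    //   11 12 13 14  0  0
    //    0 22 23  0  0  0
    //    0  0 33 34 35 36
    //    0  0  0 44 45  0
    //    0  0  0  0  0 56
    //    0  0  0  0  0 66
    int[]    rowPointers   = {0, 4, 6, 10, 12, 13, 14};  // length == numRows + 1
    int[]    columnIndices = {0, 1, 2, 3,  1, 2,  2, 3, 4, 5,  3, 4,  5,  5};
    double[] values        = {11, 12, 13, 14,  22, 23,  33, 34, 35, 36,  44, 45,  56,  66};
    // get(2, 4): binary-search columnIndices[6..10) = {2, 3, 4, 5} for column 4
    // -> hit at index 8 -> values[8] == 35; a miss returns the defaultValue.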
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java b/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
deleted file mode 100644
index 040fef8..0000000
--- a/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.matrix;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-
-public final class ReadOnlyDenseMatrix2d extends Matrix {
-
- @Nonnull
- private final double[][] data;
-
- @Nonnegative
- private final int numRows;
- @Nonnegative
- private final int numColumns;
-
- public ReadOnlyDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numColumns) {
- this.data = data;
- this.numRows = data.length;
- this.numColumns = numColumns;
- }
-
- @Override
- public boolean readOnly() {
- return true;
- }
-
- @Override
- public void setDefaultValue(double value) {
- throw new UnsupportedOperationException("The defaultValue of a DenseMatrix is fixed to 0.d");
- }
-
- @Override
- public int numRows() {
- return numRows;
- }
-
- @Override
- public int numColumns() {
- return numColumns;
- }
-
- @Override
- public int numColumns(@Nonnegative final int row) {
- checkRowIndex(row, numRows);
-
- return data[row].length;
- }
-
- @Override
- public double get(@Nonnegative final int row, @Nonnegative final int col,
- final double defaultValue) {
- checkIndex(row, col, numRows, numColumns);
-
- final double[] rowData = data[row];
- if (col >= rowData.length) {
- return defaultValue;
- }
- return rowData[col];
- }
-
- @Override
- public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
- final double value) {
- checkIndex(row, col, numRows, numColumns);
-
- final double[] rowData = data[row];
- checkColIndex(col, rowData.length);
-
- double old = rowData[col];
- rowData[col] = value;
- return old;
- }
-
- @Override
- public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
- checkIndex(row, col, numRows, numColumns);
-
- final double[] rowData = data[row];
- checkColIndex(col, rowData.length);
-
- rowData[col] = value;
- }
-
-}
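Note one behavior of the dense wrapper removed above: the backing double[][]
may be ragged, and get(row, col, defaultValue) answers the default for
columns beyond a row's physical length, while set()/getAndSet() reject them
via checkColIndex. A small sketch against the API exactly as shown:

    // Ragged rows read back the default instead of throwing.
    double[][] data = { {11, 12, 13, 14}, {0, 22, 23} };  // row 1 stores 3 cells
    ReadOnlyDenseMatrix2d m = new ReadOnlyDenseMatrix2d(data, 6);
    m.get(1, 2, 0.d);  // == 23.0, stored explicitly
    m.get(1, 5, 0.d);  // == 0.d, column 5 lies beyond row 1's storage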
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/mf/FactorizedModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/FactorizedModel.java b/core/src/main/java/hivemall/mf/FactorizedModel.java
index b92a5d8..a4bea00 100644
--- a/core/src/main/java/hivemall/mf/FactorizedModel.java
+++ b/core/src/main/java/hivemall/mf/FactorizedModel.java
@@ -18,7 +18,7 @@
*/
package hivemall.mf;
-import hivemall.utils.collections.IntOpenHashMap;
+import hivemall.utils.collections.maps.IntOpenHashMap;
import hivemall.utils.math.MathUtils;
import java.util.Random;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/model/AbstractPredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/AbstractPredictionModel.java b/core/src/main/java/hivemall/model/AbstractPredictionModel.java
index 37b69da..b48282b 100644
--- a/core/src/main/java/hivemall/model/AbstractPredictionModel.java
+++ b/core/src/main/java/hivemall/model/AbstractPredictionModel.java
@@ -21,8 +21,8 @@ package hivemall.model;
import hivemall.mix.MixedWeight;
import hivemall.mix.MixedWeight.WeightWithCovar;
import hivemall.mix.MixedWeight.WeightWithDelta;
-import hivemall.utils.collections.IntOpenHashMap;
-import hivemall.utils.collections.OpenHashMap;
+import hivemall.utils.collections.maps.IntOpenHashMap;
+import hivemall.utils.collections.maps.OpenHashMap;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/model/SparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java
index 96e1d5a..a2b4708 100644
--- a/core/src/main/java/hivemall/model/SparseModel.java
+++ b/core/src/main/java/hivemall/model/SparseModel.java
@@ -22,7 +22,7 @@ import hivemall.model.WeightValueWithClock.WeightValueParamsF1Clock;
import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock;
import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock;
import hivemall.utils.collections.IMapIterator;
-import hivemall.utils.collections.OpenHashMap;
+import hivemall.utils.collections.maps.OpenHashMap;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/ModelType.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/ModelType.java b/core/src/main/java/hivemall/smile/ModelType.java
deleted file mode 100644
index 8925075..0000000
--- a/core/src/main/java/hivemall/smile/ModelType.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile;
-
-public enum ModelType {
-
- // not compressed
- opscode(1, false), javascript(2, false), serialization(3, false),
- // compressed
- opscode_compressed(-1, true), javascript_compressed(-2, true),
- serialization_compressed(-3, true);
-
- private final int id;
- private final boolean compressed;
-
- private ModelType(int id, boolean compressed) {
- this.id = id;
- this.compressed = compressed;
- }
-
- public int getId() {
- return id;
- }
-
- public boolean isCompressed() {
- return compressed;
- }
-
- public static ModelType resolve(String name, boolean compressed) {
- name = name.toLowerCase();
- if ("opscode".equals(name) || "vm".equals(name)) {
- return compressed ? opscode_compressed : opscode;
- } else if ("javascript".equals(name) || "js".equals(name)) {
- return compressed ? javascript_compressed : javascript;
- } else if ("serialization".equals(name) || "ser".equals(name)) {
- return compressed ? serialization_compressed : serialization;
- } else {
- throw new IllegalStateException("Unexpected output type: " + name);
- }
- }
-
- public static ModelType resolve(final int id) {
- final ModelType type;
- switch (id) {
- case 1:
- type = opscode;
- break;
- case -1:
- type = opscode_compressed;
- break;
- case 2:
- type = javascript;
- break;
- case -2:
- type = javascript_compressed;
- break;
- case 3:
- type = serialization;
- break;
- case -3:
- type = serialization_compressed;
- break;
- default:
- throw new IllegalStateException("Unexpected ID for ModelType: " + id);
- }
- return type;
- }
-
-}
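For reference, the enum deleted here encoded compression in the sign of the
id, so the two resolve() overloads round-trip; a quick sketch against the
code above:

    // Positive ids are plain model types; negated ids are the compressed twins.
    assert ModelType.resolve("js", true) == ModelType.javascript_compressed;  // id == -2
    assert ModelType.resolve(-2) == ModelType.javascript_compressed;
    assert ModelType.resolve(-2).isCompressed();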
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/DecisionTree.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
index 6b22473..2d086b9 100644
--- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
+++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -33,100 +33,94 @@
*/
package hivemall.smile.classification;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.data.Attribute;
import hivemall.smile.data.Attribute.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.ObjectUtils;
-import hivemall.utils.lang.StringUtils;
+import hivemall.utils.sampling.IntReservoirSampler;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.roaringbitmap.IntConsumer;
+import org.roaringbitmap.RoaringBitmap;
import smile.classification.Classifier;
import smile.math.Math;
-import smile.math.Random;
/**
- * Decision tree for classification. A decision tree can be learned by splitting the training set
- * into subsets based on an attribute value test. This process is repeated on each derived subset in
- * a recursive manner called recursive partitioning. The recursion is completed when the subset at a
- * node all has the same value of the target variable, or when splitting no longer adds value to the
- * predictions.
+ * Decision tree for classification. A decision tree can be learned by splitting the training set into subsets based on an attribute value test. This
+ * process is repeated on each derived subset in a recursive manner called recursive partitioning. The recursion is completed when the subset at a
+ * node all has the same value of the target variable, or when splitting no longer adds value to the predictions.
* <p>
- * The algorithms that are used for constructing decision trees usually work top-down by choosing a
- * variable at each step that is the next best variable to use in splitting the set of items. "Best"
- * is defined by how well the variable splits the set into homogeneous subsets that have the same
- * value of the target variable. Different algorithms use different formulae for measuring "best".
- * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen element
- * from the set would be incorrectly labeled if it were randomly labeled according to the
- * distribution of labels in the subset. Gini impurity can be computed by summing the probability of
- * each item being chosen times the probability of a mistake in categorizing that item. It reaches
- * its minimum (zero) when all cases in the node fall into a single target category. Information
- * gain is another popular measure, used by the ID3, C4.5 and C5.0 algorithms. Information gain is
- * based on the concept of entropy used in information theory. For categorical variables with
- * different number of levels, however, information gain are biased in favor of those attributes
- * with more levels. Instead, one may employ the information gain ratio, which solves the drawback
- * of information gain.
+ * The algorithms that are used for constructing decision trees usually work top-down by choosing a variable at each step that is the next best
+ * variable to use in splitting the set of items. "Best" is defined by how well the variable splits the set into homogeneous subsets that have the
+ * same value of the target variable. Different algorithms use different formulae for measuring "best". Used by the CART algorithm, Gini impurity is a
+ * measure of how often a randomly chosen element from the set would be incorrectly labeled if it were randomly labeled according to the distribution
+ * of labels in the subset. Gini impurity can be computed by summing the probability of each item being chosen times the probability of a mistake in
+ * categorizing that item. It reaches its minimum (zero) when all cases in the node fall into a single target category. Information gain is another
+ * popular measure, used by the ID3, C4.5 and C5.0 algorithms. Information gain is based on the concept of entropy used in information theory. For
+ * categorical variables with different numbers of levels, however, information gain is biased in favor of those attributes with more levels. Instead,
+ * one may employ the information gain ratio, which solves the drawback of information gain.
* <p>
- * Classification and Regression Tree techniques have a number of advantages over many of those
- * alternative techniques.
+ * Classification and Regression Tree techniques have a number of advantages over many of those alternative techniques.
* <dl>
* <dt>Simple to understand and interpret.</dt>
- * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This
- * simplicity is useful not only for purposes of rapid classification of new observations, but can
- * also often yield a much simpler "model" for explaining why observations are classified or
- * predicted in a particular manner.</dd>
+ * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This simplicity is useful not only for purposes of rapid
+ * classification of new observations, but can also often yield a much simpler "model" for explaining why observations are classified or predicted in
+ * a particular manner.</dd>
* <dt>Able to handle both numerical and categorical data.</dt>
- * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of
- * variable.</dd>
+ * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of variable.</dd>
* <dt>Tree methods are nonparametric and nonlinear.</dt>
- * <dd>The final results of using tree methods for classification or regression can be summarized in
- * a series of (usually few) logical if-then conditions (tree nodes). Therefore, there is no
- * implicit assumption that the underlying relationships between the predictor variables and the
- * dependent variable are linear, follow some specific non-linear link function, or that they are
- * even monotonic in nature. Thus, tree methods are particularly well suited for data mining tasks,
- * where there is often little a priori knowledge nor any coherent set of theories or predictions
- * regarding which variables are related and how. In those types of data analytics, tree methods can
- * often reveal simple relationships between just a few variables that could have easily gone
- * unnoticed using other analytic techniques.</dd>
+ * <dd>The final results of using tree methods for classification or regression can be summarized in a series of (usually few) logical if-then
+ * conditions (tree nodes). Therefore, there is no implicit assumption that the underlying relationships between the predictor variables and the
+ * dependent variable are linear, follow some specific non-linear link function, or that they are even monotonic in nature. Thus, tree methods are
+ * particularly well suited for data mining tasks, where there is often little a priori knowledge nor any coherent set of theories or predictions
+ * regarding which variables are related and how. In those types of data analytics, tree methods can often reveal simple relationships between just a
+ * few variables that could have easily gone unnoticed using other analytic techniques.</dd>
* </dl>
- * One major problem with classification and regression trees is their high variance. Often a small
- * change in the data can result in a very different series of splits, making interpretation
- * somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause
- * over-fitting. Mechanisms such as pruning are necessary to avoid this problem. Another limitation
- * of trees is the lack of smoothness of the prediction surface.
+ * One major problem with classification and regression trees is their high variance. Often a small change in the data can result in a very different
+ * series of splits, making interpretation somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause over-fitting.
+ * Mechanisms such as pruning are necessary to avoid this problem. Another limitation of trees is the lack of smoothness of the prediction surface.
* <p>
- * Some techniques such as bagging, boosting, and random forest use more than one decision tree for
- * their analysis.
+ * Some techniques such as bagging, boosting, and random forest use more than one decision tree for their analysis.
*/
-public final class DecisionTree implements Classifier<double[]> {
+public final class DecisionTree implements Classifier<Vector> {
/**
* The attributes of independent variable.
*/
+ @Nonnull
private final Attribute[] _attributes;
private final boolean _hasNumericType;
/**
- * Variable importance. Every time a split of a node is made on variable the (GINI, information
- * gain, etc.) impurity criterion for the two descendant nodes is less than the parent node.
- * Adding up the decreases for each individual variable over the tree gives a simple measure of
+ * Variable importance. Every time a node is split on a variable, the (GINI, information gain, etc.) impurity criterion for the two
+ * descendant nodes is less than that of the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of
* variable importance.
*/
- private final double[] _importance;
+ @Nonnull
+ private final Vector _importance;
/**
* The root of the regression tree
*/
+ @Nonnull
private final Node _root;
/**
* The maximum number of the tree depth
@@ -135,6 +129,7 @@ public final class DecisionTree implements Classifier<double[]> {
/**
* The splitting rule.
*/
+ @Nonnull
private final SplitRule _rule;
/**
* The number of classes.
@@ -153,24 +148,23 @@ public final class DecisionTree implements Classifier<double[]> {
*/
private final int _minLeafSize;
/**
- * The index of training values in ascending order. Note that only numeric attributes will be
- * sorted.
+ * The index of training values in ascending order. Note that only numeric attributes will be sorted.
*/
- private final int[][] _order;
+ @Nonnull
+ private final ColumnMajorIntMatrix _order;
- private final Random _rnd;
+ @Nonnull
+ private final PRNG _rnd;
/**
* The criterion to choose variable to split instances.
*/
public static enum SplitRule {
/**
- * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen
- * element from the set would be incorrectly labeled if it were randomly labeled according
- * to the distribution of labels in the subset. Gini impurity can be computed by summing the
- * probability of each item being chosen times the probability of a mistake in categorizing
- * that item. It reaches its minimum (zero) when all cases in the node fall into a single
- * target category.
+ * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen element from the set would be incorrectly labeled if
+ * it were randomly labeled according to the distribution of labels in the subset. Gini impurity can be computed by summing the probability of
+ * each item being chosen times the probability of a mistake in categorizing that item. It reaches its minimum (zero) when all cases in the
+ * node fall into a single target category.
*/
GINI,
/**
@@ -193,6 +187,11 @@ public final class DecisionTree implements Classifier<double[]> {
*/
int output = -1;
/**
+ * Posterior probability based on sample ratios in this node.
+ */
+ @Nullable
+ double[] posteriori = null;
+ /**
* The split feature for this node.
*/
int splitFeature = -1;
@@ -227,28 +226,35 @@ public final class DecisionTree implements Classifier<double[]> {
public Node() {}// for Externalizable
- /**
- * Constructor.
- */
- public Node(int output) {
+ public Node(int output, @Nonnull double[] posteriori) {
this.output = output;
+ this.posteriori = posteriori;
+ }
+
+ private boolean isLeaf() {
+ return posteriori != null;
+ }
+
+ @VisibleForTesting
+ public int predict(@Nonnull final double[] x) {
+ return predict(new DenseVector(x));
}
/**
* Evaluate the regression tree over an instance.
*/
- public int predict(final double[] x) {
+ public int predict(@Nonnull final Vector x) {
if (trueChild == null && falseChild == null) {
return output;
} else {
if (splitFeatureType == AttributeType.NOMINAL) {
- if (x[splitFeature] == splitValue) {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
} else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x[splitFeature] <= splitValue) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
@@ -260,6 +266,32 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
+ /**
+ * Evaluate the regression tree over an instance.
+ */
+ public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) {
+ if (trueChild == null && falseChild == null) {
+ handler.handle(output, posteriori);
+ } else {
+ if (splitFeatureType == AttributeType.NOMINAL) {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
+ trueChild.predict(x, handler);
+ } else {
+ falseChild.predict(x, handler);
+ }
+ } else if (splitFeatureType == AttributeType.NUMERIC) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
+ trueChild.predict(x, handler);
+ } else {
+ falseChild.predict(x, handler);
+ }
+ } else {
+ throw new IllegalStateException("Unsupported attribute type: "
+ + splitFeatureType);
+ }
+ }
+ }
+
public void jsCodegen(@Nonnull final StringBuilder builder, final int depth) {
if (trueChild == null && falseChild == null) {
indent(builder, depth);
@@ -298,99 +330,71 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- public int opCodegen(final List<String> scripts, int depth) {
- int selfDepth = 0;
- final StringBuilder buf = new StringBuilder();
- if (trueChild == null && falseChild == null) {
- buf.append("push ").append(output);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("goto last");
- scripts.add(buf.toString());
- selfDepth += 2;
- } else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- buf.append("push ").append("x[").append(splitFeature).append("]");
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("push ").append(splitValue);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("ifeq ");
- scripts.add(buf.toString());
- depth += 3;
- selfDepth += 3;
- int trueDepth = trueChild.opCodegen(scripts, depth);
- selfDepth += trueDepth;
- scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
- int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
- selfDepth += falseDepth;
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- buf.append("push ").append("x[").append(splitFeature).append("]");
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("push ").append(splitValue);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("ifle ");
- scripts.add(buf.toString());
- depth += 3;
- selfDepth += 3;
- int trueDepth = trueChild.opCodegen(scripts, depth);
- selfDepth += trueDepth;
- scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
- int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
- selfDepth += falseDepth;
- } else {
- throw new IllegalStateException("Unsupported attribute type: "
- + splitFeatureType);
- }
- }
- return selfDepth;
- }
-
@Override
public void writeExternal(ObjectOutput out) throws IOException {
- out.writeInt(output);
out.writeInt(splitFeature);
if (splitFeatureType == null) {
- out.writeInt(-1);
+ out.writeByte(-1);
} else {
- out.writeInt(splitFeatureType.getTypeId());
+ out.writeByte(splitFeatureType.getTypeId());
}
out.writeDouble(splitValue);
- if (trueChild == null) {
- out.writeBoolean(false);
- } else {
+
+ if (isLeaf()) {
out.writeBoolean(true);
- trueChild.writeExternal(out);
- }
- if (falseChild == null) {
- out.writeBoolean(false);
+
+ out.writeInt(output);
+ out.writeInt(posteriori.length);
+ for (int i = 0; i < posteriori.length; i++) {
+ out.writeDouble(posteriori[i]);
+ }
} else {
- out.writeBoolean(true);
- falseChild.writeExternal(out);
+ out.writeBoolean(false);
+
+ if (trueChild == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ trueChild.writeExternal(out);
+ }
+ if (falseChild == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ falseChild.writeExternal(out);
+ }
}
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this.output = in.readInt();
this.splitFeature = in.readInt();
- int typeId = in.readInt();
+ byte typeId = in.readByte();
if (typeId == -1) {
this.splitFeatureType = null;
} else {
this.splitFeatureType = AttributeType.resolve(typeId);
}
this.splitValue = in.readDouble();
- if (in.readBoolean()) {
- this.trueChild = new Node();
- trueChild.readExternal(in);
- }
- if (in.readBoolean()) {
- this.falseChild = new Node();
- falseChild.readExternal(in);
+
+ if (in.readBoolean()) {//isLeaf
+ this.output = in.readInt();
+
+ final int size = in.readInt();
+ final double[] posteriori = new double[size];
+ for (int i = 0; i < size; i++) {
+ posteriori[i] = in.readDouble();
+ }
+ this.posteriori = posteriori;
+ } else {
+ if (in.readBoolean()) {
+ this.trueChild = new Node();
+ trueChild.readExternal(in);
+ }
+ if (in.readBoolean()) {
+ this.falseChild = new Node();
+ falseChild.readExternal(in);
+ }
}
}
@@ -413,7 +417,7 @@ public final class DecisionTree implements Classifier<double[]> {
/**
* Training dataset.
*/
- final double[][] x;
+ final Matrix x;
/**
* class labels.
*/
@@ -426,7 +430,7 @@ public final class DecisionTree implements Classifier<double[]> {
/**
* Constructor.
*/
- public TrainNode(Node node, double[][] x, int[] y, int[] bags, int depth) {
+ public TrainNode(Node node, Matrix x, int[] y, int[] bags, int depth) {
this.node = node;
this.x = x;
this.y = y;
@@ -466,21 +470,12 @@ public final class DecisionTree implements Classifier<double[]> {
final double impurity = impurity(count, numSamples, _rule);
- final int p = _attributes.length;
- final int[] variableIndex = new int[p];
- for (int i = 0; i < p; i++) {
- variableIndex[i] = i;
- }
- if (_numVars < p) {
- SmileExtUtils.shuffle(variableIndex, _rnd);
- }
-
- final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length)
+ final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows())
: null;
final int[] falseCount = new int[_k];
- for (int j = 0; j < _numVars; j++) {
- Node split = findBestSplit(numSamples, count, falseCount, impurity,
- variableIndex[j], samples);
+ for (int varJ : variableIndex(x, bags)) {
+ final Node split = findBestSplit(numSamples, count, falseCount, impurity, varJ,
+ samples);
if (split.splitScore > node.splitScore) {
node.splitFeature = split.splitFeature;
node.splitFeatureType = split.splitFeatureType;
@@ -491,7 +486,33 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- return (node.splitFeature != -1);
+ return node.splitFeature != -1;
+ }
+
+ @Nonnull
+ private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) {
+ final IntReservoirSampler sampler = new IntReservoirSampler(_numVars, _rnd.nextLong());
+ if (x.isSparse()) {
+ final RoaringBitmap cols = new RoaringBitmap();
+ final VectorProcedure proc = new VectorProcedure() {
+ public void apply(final int col) {
+ cols.add(col);
+ }
+ };
+ for (final int row : bags) {
+ x.eachColumnIndexInRow(row, proc);
+ }
+ cols.forEach(new IntConsumer() {
+ public void accept(final int k) {
+ sampler.add(k);
+ }
+ });
+ } else {
+ for (int i = 0, size = _attributes.length; i < size; i++) {
+ sampler.add(i);
+ }
+ }
+ return sampler.getSample();
}
private boolean sampleCount(@Nonnull final int[] count) {
@@ -530,7 +551,11 @@ public final class DecisionTree implements Classifier<double[]> {
for (int i = 0, size = bags.length; i < size; i++) {
int index = bags[i];
- int x_ij = (int) x[index][j];
+ final double v = x.get(index, j, Double.NaN);
+ if (Double.isNaN(v)) {
+ continue;
+ }
+ int x_ij = (int) v;
trueCount[x_ij][y[index]]++;
}
@@ -563,21 +588,28 @@ public final class DecisionTree implements Classifier<double[]> {
}
} else if (_attributes[j].type == AttributeType.NUMERIC) {
final int[] trueCount = new int[_k];
- double prevx = Double.NaN;
- int prevy = -1;
-
- assert (samples != null);
- for (final int i : _order[j]) {
- final int sample = samples[i];
- if (sample > 0) {
- final double x_ij = x[i][j];
+
+ _order.eachNonNullInColumn(j, new VectorProcedure() {
+ double prevx = Double.NaN;
+ int prevy = -1;
+
+ public void apply(final int row, final int i) {
+ final int sample = samples[i];
+ if (sample == 0) {
+ return;
+ }
+
+ final double x_ij = x.get(i, j, Double.NaN);
+ if (Double.isNaN(x_ij)) {
+ return;
+ }
final int y_i = y[i];
if (Double.isNaN(prevx) || x_ij == prevx || y_i == prevy) {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += sample;
- continue;
+ return;
}
final int tc = Math.sum(trueCount);
@@ -588,7 +620,7 @@ public final class DecisionTree implements Classifier<double[]> {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += sample;
- continue;
+ return;
}
for (int l = 0; l < _k; l++) {
@@ -612,8 +644,8 @@ public final class DecisionTree implements Classifier<double[]> {
prevx = x_ij;
prevy = y_i;
trueCount[y_i] += sample;
- }
- }
+ }//apply()
+ });
} else {
throw new IllegalStateException("Unsupported attribute type: "
+ _attributes[j].type);
@@ -634,7 +666,9 @@ public final class DecisionTree implements Classifier<double[]> {
int childBagSize = (int) (bags.length * 0.4);
IntArrayList trueBags = new IntArrayList(childBagSize);
IntArrayList falseBags = new IntArrayList(childBagSize);
- int tc = splitSamples(trueBags, falseBags);
+ double[] trueChildPosteriori = new double[_k];
+ double[] falseChildPosteriori = new double[_k];
+ int tc = splitSamples(trueBags, falseBags, trueChildPosteriori, falseChildPosteriori);
int fc = bags.length - tc;
this.bags = null; // help GC for recursive call
@@ -647,7 +681,12 @@ public final class DecisionTree implements Classifier<double[]> {
return false;
}
- node.trueChild = new Node(node.trueChildOutput);
+ for (int i = 0; i < _k; i++) {
+ trueChildPosteriori[i] /= tc;
+ falseChildPosteriori[i] /= fc;
+ }
+
+ node.trueChild = new Node(node.trueChildOutput, trueChildPosteriori);
TrainNode trueChild = new TrainNode(node.trueChild, x, y, trueBags.toArray(), depth + 1);
trueBags = null; // help GC for recursive call
if (tc >= _minSplit && trueChild.findBestSplit()) {
@@ -658,7 +697,7 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- node.falseChild = new Node(node.falseChildOutput);
+ node.falseChild = new Node(node.falseChildOutput, falseChildPosteriori);
TrainNode falseChild = new TrainNode(node.falseChild, x, y, falseBags.toArray(),
depth + 1);
falseBags = null; // help GC for recursive call
@@ -670,27 +709,33 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- _importance[node.splitFeature] += node.splitScore;
+ _importance.incr(node.splitFeature, node.splitScore);
+ node.posteriori = null; // posteriori is not needed for non-leaf nodes
return true;
}
/**
+ * @param falseChildPosteriori
+ * @param trueChildPosteriori
* @return the number of true samples
*/
private int splitSamples(@Nonnull final IntArrayList trueBags,
- @Nonnull final IntArrayList falseBags) {
+ @Nonnull final IntArrayList falseBags, @Nonnull final double[] trueChildPosteriori,
+ @Nonnull final double[] falseChildPosteriori) {
int tc = 0;
if (node.splitFeatureType == AttributeType.NOMINAL) {
final int splitFeature = node.splitFeature;
final double splitValue = node.splitValue;
for (int i = 0, size = bags.length; i < size; i++) {
final int index = bags[i];
- if (x[index][splitFeature] == splitValue) {
+ if (x.get(index, splitFeature, Double.NaN) == splitValue) {
trueBags.add(index);
+ trueChildPosteriori[y[index]]++;
tc++;
} else {
falseBags.add(index);
+ falseChildPosteriori[y[index]]++;
}
}
} else if (node.splitFeatureType == AttributeType.NUMERIC) {
@@ -698,11 +743,13 @@ public final class DecisionTree implements Classifier<double[]> {
final double splitValue = node.splitValue;
for (int i = 0, size = bags.length; i < size; i++) {
final int index = bags[i];
- if (x[index][splitFeature] <= splitValue) {
+ if (x.get(index, splitFeature, Double.NaN) <= splitValue) {
trueBags.add(index);
+ trueChildPosteriori[y[index]]++;
tc++;
} else {
falseBags.add(index);
+ falseChildPosteriori[y[index]]++;
}
}
} else {
@@ -714,7 +761,6 @@ public final class DecisionTree implements Classifier<double[]> {
}
-
/**
* Returns the impurity of a node.
*
@@ -731,8 +777,9 @@ public final class DecisionTree implements Classifier<double[]> {
case GINI: {
impurity = 1.0;
for (int i = 0; i < count.length; i++) {
- if (count[i] > 0) {
- double p = (double) count[i] / n;
+ final int count_i = count[i];
+ if (count_i > 0) {
+ double p = (double) count_i / n;
impurity -= p * p;
}
}
@@ -740,8 +787,9 @@ public final class DecisionTree implements Classifier<double[]> {
}
case ENTROPY: {
for (int i = 0; i < count.length; i++) {
- if (count[i] > 0) {
- double p = (double) count[i] / n;
+ final int count_i = count[i];
+ if (count_i > 0) {
+ double p = (double) count_i / n;
impurity -= p * Math.log2(p);
}
}
@@ -750,8 +798,9 @@ public final class DecisionTree implements Classifier<double[]> {
case CLASSIFICATION_ERROR: {
impurity = 0.d;
for (int i = 0; i < count.length; i++) {
- if (count[i] > 0) {
- impurity = Math.max(impurity, (double) count[i] / n);
+ final int count_i = count[i];
+ if (count_i > 0) {
+ impurity = Math.max(impurity, (double) count_i / n);
}
}
impurity = Math.abs(1.d - impurity);
@@ -762,14 +811,14 @@ public final class DecisionTree implements Classifier<double[]> {
return impurity;
}
- public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y,
+ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
int numLeafs) {
- this(attributes, x, y, x[0].length, Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null);
+ this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null);
}
- public DecisionTree(@Nullable Attribute[] attributes, @Nullable double[][] x,
- @Nullable int[] y, int numLeafs, @Nullable smile.math.Random rand) {
- this(attributes, x, y, x[0].length, Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, rand);
+ public DecisionTree(@Nullable Attribute[] attributes, @Nullable Matrix x, @Nullable int[] y,
+ int numLeafs, @Nullable PRNG rand) {
+ this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, rand);
}
/**
@@ -778,21 +827,20 @@ public final class DecisionTree implements Classifier<double[]> {
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
- * @param numVars the number of input variables to pick to split on at each node. It seems that
- * dim/3 give generally good performance, where dim is the number of variables.
+ * @param numVars the number of input variables to pick to split on at each node. It seems that dim/3 gives generally good performance, where dim
+ * is the number of variables.
* @param maxLeafs the maximum number of leaf nodes in the tree.
* @param minSplits the number of minimum elements in a node to split
* @param minLeafSize the minimum size of leaf nodes.
- * @param order the index of training values in ascending order. Note that only numeric
- * attributes need be sorted.
+ * @param order the index of training values in ascending order. Note that only numeric attributes need be sorted.
* @param bags the sample set of instances for stochastic learning.
* @param rule the splitting rule.
* @param seed
*/
- public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y,
+ public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
- @Nullable int[] bags, @Nullable int[][] order, @Nonnull SplitRule rule,
- @Nullable smile.math.Random rand) {
+ @Nullable int[] bags, @Nullable ColumnMajorIntMatrix order, @Nonnull SplitRule rule,
+ @Nullable PRNG rand) {
checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
this._k = Math.max(y) + 1;
@@ -801,7 +849,7 @@ public final class DecisionTree implements Classifier<double[]> {
}
this._attributes = SmileExtUtils.attributeTypes(attributes, x);
- if (attributes.length != x[0].length) {
+ if (attributes.length != x.numColumns()) {
throw new IllegalArgumentException("-attrs option is invliad: "
+ Arrays.toString(attributes));
}
@@ -813,8 +861,8 @@ public final class DecisionTree implements Classifier<double[]> {
this._minLeafSize = minLeafSize;
this._rule = rule;
this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
- this._importance = new double[_attributes.length];
- this._rnd = (rand == null) ? new smile.math.Random() : rand;
+ this._importance = x.isSparse() ? new SparseVector() : new DenseVector(_attributes.length);
+ this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
final int n = y.length;
final int[] count = new int[_k];
@@ -825,13 +873,17 @@ public final class DecisionTree implements Classifier<double[]> {
count[y[i]]++;
}
} else {
- for (int i = 0; i < n; i++) {
+ for (int i = 0, size = bags.length; i < size; i++) {
int index = bags[i];
count[y[index]]++;
}
}
- this._root = new Node(Math.whichMax(count));
+ final double[] posteriori = new double[_k];
+ for (int i = 0; i < _k; i++) {
+ posteriori[i] = (double) count[i] / n;
+ }
+ this._root = new Node(Math.whichMax(count), posteriori);
final TrainNode trainRoot = new TrainNode(_root, x, y, bags, 1);
if (maxLeafs == Integer.MAX_VALUE) {
@@ -858,13 +910,13 @@ public final class DecisionTree implements Classifier<double[]> {
}
}
- private static void checkArgument(@Nonnull double[][] x, @Nonnull int[] y, int numVars,
+ private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars,
int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
- if (x.length != y.length) {
+ if (x.numRows() != y.length) {
throw new IllegalArgumentException(String.format(
- "The sizes of X and Y don't match: %d != %d", x.length, y.length));
+ "The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
}
- if (numVars <= 0 || numVars > x[0].length) {
+ if (numVars <= 0 || numVars > x.numColumns()) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
}
@@ -885,28 +937,31 @@ public final class DecisionTree implements Classifier<double[]> {
}
/**
- * Returns the variable importance. Every time a split of a node is made on variable the (GINI,
- * information gain, etc.) impurity criterion for the two descendent nodes is less than the
- * parent node. Adding up the decreases for each individual variable over the tree gives a
- * simple measure of variable importance.
+ * Returns the variable importance. Every time a split of a node is made on a variable, the (GINI, information gain, etc.) impurity criterion for
+ * the two descendant nodes is less than that of the parent node. Adding up the decreases for each individual variable over the tree gives a simple
+ * measure of variable importance.
*
* @return the variable importance
*/
- public double[] importance() {
+ @Nonnull
+ public Vector importance() {
return _importance;
}
+ @VisibleForTesting
+ public int predict(@Nonnull final double[] x) {
+ return predict(new DenseVector(x));
+ }
+
@Override
- public int predict(final double[] x) {
+ public int predict(@Nonnull final Vector x) {
return _root.predict(x);
}
/**
- * Predicts the class label of an instance and also calculate a posteriori probabilities. Not
- * supported.
+ * Predicts the class label of an instance and also calculates the a posteriori probabilities. Not supported.
*/
- @Override
- public int predict(double[] x, double[] posteriori) {
+ public int predict(Vector x, double[] posteriori) {
throw new UnsupportedOperationException("Not supported.");
}
@@ -916,14 +971,6 @@ public final class DecisionTree implements Classifier<double[]> {
return buf.toString();
}
- public String predictOpCodegen(String sep) {
- List<String> opslist = new ArrayList<String>();
- _root.opCodegen(opslist, 0);
- opslist.add("call end");
- String scripts = StringUtils.concat(opslist, sep);
- return scripts;
- }
-
@Nonnull
public byte[] predictSerCodegen(boolean compress) throws HiveException {
try {
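For reference, the reworked constructor and predict(Vector) entry points above can be exercised as in the following minimal sketch. It relies only on behavior visible in this commit (RowMajorDenseMatrixBuilder's nextColumn/nextRow/buildMatrix, and the short constructor whose nulls fall back to inferred attribute types and a factory-created PRNG); it is illustrative, not code from the patch.

import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.builders.MatrixBuilder;
import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
import hivemall.smile.classification.DecisionTree;

public final class DecisionTreeSketch {
    public static void main(String[] args) {
        // Two dense training rows with three numeric features each.
        final MatrixBuilder builder = new RowMajorDenseMatrixBuilder(2);
        builder.nextColumn(0, 1.0d);
        builder.nextColumn(1, 0.5d);
        builder.nextColumn(2, 0.0d);
        builder.nextRow();
        builder.nextColumn(0, 0.0d);
        builder.nextColumn(1, 1.5d);
        builder.nextColumn(2, 1.0d);
        builder.nextRow();
        final Matrix x = builder.buildMatrix();
        final int[] y = new int[] {0, 1};

        // attributes == null -> types inferred via SmileExtUtils.attributeTypes();
        // rand == null -> RandomNumberGeneratorFactory.createPRNG() is used.
        final DecisionTree tree = new DecisionTree(null, x, y, 4, null);
        // The double[] overload wraps the probe in a DenseVector internally.
        System.out.println(tree.predict(new double[] {0.0d, 1.5d, 1.0d}));
    }
}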
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index 3a0924e..a380a11 100644
--- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -19,24 +19,27 @@
package hivemall.smile.classification;
import hivemall.UDTFWithOptions;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.Vector;
import hivemall.smile.data.Attribute;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.smile.vm.StackMachine;
import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
-import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Primitives;
+import hivemall.utils.math.MathUtils;
-import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
-import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -63,7 +66,7 @@ import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
@Description(name = "train_gradient_tree_boosting_classifier",
- value = "_FUNC_(double[] features, int label [, string options]) - "
+ value = "_FUNC_(array<double|string> features, int label [, string options]) - "
+ "Returns a relation consists of "
+ "<int iteration, int model_type, array<string> pred_models, double intercept, "
+ "double shrinkage, array<double> var_importance, float oob_error_rate>")
@@ -74,7 +77,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
private PrimitiveObjectInspector featureElemOI;
private PrimitiveObjectInspector labelOI;
- private List<double[]> featuresList;
+ private boolean denseInput;
+ private MatrixBuilder matrixBuilder;
private IntArrayList labels;
/**
* The number of trees for each task
@@ -104,7 +108,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
private int _minSamplesLeaf;
private long _seed;
private Attribute[] _attributes;
- private ModelType _outputType;
@Nullable
private Reporter _progressReporter;
@@ -134,10 +137,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
- opts.addOption("output", "output_type", true,
- "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
- opts.addOption("disable_compression", false,
- "Whether to disable compression of the output script [default: false]");
return opts;
}
@@ -149,8 +148,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
double eta = 0.05d, subsample = 0.7d;
Attribute[] attrs = null;
long seed = -1L;
- String output = "serialization";
- boolean compress = true;
CommandLine cl = null;
if (argOIs.length >= 3) {
@@ -171,10 +168,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
- output = cl.getOptionValue("output", output);
- if (cl.hasOption("disable_compression")) {
- compress = false;
- }
}
this._numTrees = trees;
@@ -187,7 +180,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
this._attributes = attrs;
- this._outputType = ModelType.resolve(output, compress);
return cl;
}
@@ -197,19 +189,29 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
if (argOIs.length != 2 && argOIs.length != 3) {
throw new UDFArgumentException(
getClass().getSimpleName()
- + " takes 2 or 3 arguments: double[] features, int label [, const string options]: "
+ + " takes 2 or 3 arguments: array<double|string> features, int label [, const string options]: "
+ argOIs.length);
}
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
- this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ if (HiveUtils.isNumberOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ this.denseInput = true;
+ this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
+ } else if (HiveUtils.isStringOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asStringOI(elemOI);
+ this.denseInput = false;
+ this.matrixBuilder = new CSRMatrixBuilder(8192);
+ } else {
+ throw new UDFArgumentException(
+ "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
+ }
this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
processOptions(argOIs);
- this.featuresList = new ArrayList<double[]>(1024);
this.labels = new IntArrayList(1024);
ArrayList<String> fieldNames = new ArrayList<String>(6);
@@ -217,8 +219,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
fieldNames.add("iteration");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
- fieldNames.add("model_type");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("pred_models");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector));
fieldNames.add("intercept");
@@ -238,13 +238,36 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
if (args[0] == null) {
throw new HiveException("array<double> features was null");
}
- double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
+ parseFeatures(args[0], matrixBuilder);
int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
-
- featuresList.add(features);
labels.add(label);
}
+ private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
+ if (denseInput) {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+ builder.nextColumn(i, v);
+ }
+ } else {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ String fv = o.toString();
+ builder.nextColumn(fv);
+ }
+ }
+ builder.nextRow();
+ }
+
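In the sparse branch of parseFeatures() above, each string element is handed to CSRMatrixBuilder#nextColumn(String) unparsed. Hivemall's string feature convention is "index" or "index:value"; the standalone helper below is a hypothetical mirror of that contract, for illustration only, since the actual parsing lives inside the builder and is not shown in this diff.

import java.util.Map;

final class SparseFeatureParseSketch {
    // Hypothetical equivalent of nextColumn(String); format assumed, not from the patch.
    static void addSparseFeature(final String fv, final Map<Integer, Double> row) {
        final int pos = fv.indexOf(':');
        if (pos == -1) {
            row.put(Integer.valueOf(fv), Double.valueOf(1.d)); // bare index => weight 1.0 (assumed)
        } else {
            final Integer index = Integer.valueOf(fv.substring(0, pos));
            final Double value = Double.valueOf(fv.substring(pos + 1));
            row.put(index, value);
        }
    }
}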
@Override
public void close() throws HiveException {
this._progressReporter = getReporter();
@@ -252,14 +275,15 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
"hivemall.smile.GradientTreeBoostingClassifier$Counter", "iteration");
reportProgress(_progressReporter);
- int numExamples = featuresList.size();
- double[][] x = featuresList.toArray(new double[numExamples][]);
- this.featuresList = null;
- int[] y = labels.toArray();
- this.labels = null;
+ if (!labels.isEmpty()) {
+ Matrix x = matrixBuilder.buildMatrix();
+ this.matrixBuilder = null;
+ int[] y = labels.toArray();
+ this.labels = null;
- // run training
- train(x, y);
+ // run training
+ train(x, y);
+ }
// clean up
this.featureListOI = null;
@@ -287,25 +311,25 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
* @param x features
* @param y label
*/
- private void train(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException {
- if (x.length != y.length) {
+ private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException {
+ final int numRows = x.numRows();
+ if (numRows != y.length) {
throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
- x.length, y.length));
+ numRows, y.length));
}
checkOptions();
this._attributes = SmileExtUtils.attributeTypes(_attributes, x);
// Shuffle training samples
- SmileExtUtils.shuffle(x, y, _seed);
+ x = SmileExtUtils.shuffle(x, y, _seed);
final int k = smile.math.Math.max(y) + 1;
if (k < 2) {
throw new UDFArgumentException("Only one class or negative class labels.");
}
if (k == 2) {
- int n = x.length;
- final int[] y2 = new int[n];
- for (int i = 0; i < n; i++) {
+ final int[] y2 = new int[numRows];
+ for (int i = 0; i < numRows; i++) {
if (y[i] == 1) {
y2[i] = 1;
} else {
@@ -318,7 +342,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
}
}
- private void train2(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException {
+ private void train2(@Nonnull final Matrix x, @Nonnull final int[] y) throws HiveException {
final int numVars = SmileExtUtils.computeNumInputVars(_numVars, x);
if (logger.isInfoEnabled()) {
logger.info("k: " + 2 + ", numTrees: " + _numTrees + ", shirinkage: " + _eta
@@ -327,7 +351,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
+ _maxLeafNodes + ", seed: " + _seed);
}
- final int numInstances = x.length;
+ final int numInstances = x.numRows();
final int numSamples = (int) Math.round(numInstances * _subsample);
final double[] h = new double[numInstances]; // current F(x_i)
@@ -340,7 +364,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
h[i] = intercept;
}
- final int[][] order = SmileExtUtils.sort(_attributes, x);
+ final ColumnMajorIntMatrix order = SmileExtUtils.sort(_attributes, x);
final RegressionTree.NodeOutput output = new L2NodeOutput(response);
final BitSet sampled = new BitSet(numInstances);
@@ -351,10 +375,11 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
}
long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
- : new smile.math.Random(_seed).nextLong();
- final smile.math.Random rnd1 = new smile.math.Random(s);
- final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
+ : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+ final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+ final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+ final Vector xProbe = x.rowVector();
for (int m = 0; m < _numTrees; m++) {
reportProgress(_progressReporter);
@@ -373,7 +398,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
_maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, order, bag, output, rnd2);
for (int i = 0; i < numInstances; i++) {
- h[i] += _eta * tree.predict(x[i]);
+ x.getRow(i, xProbe);
+ h[i] += _eta * tree.predict(xProbe);
}
// out-of-bag error estimate
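The xProbe idiom introduced above is the commit's replacement for indexing double[][] rows: Matrix#rowVector() allocates one reusable row buffer and getRow(i, xProbe) refills it in place, so the scoring loop performs no per-row allocation. A minimal sketch of the same pattern, assuming only the Matrix/Vector calls used in this hunk:

import hivemall.math.matrix.Matrix;
import hivemall.math.vector.Vector;
import hivemall.smile.regression.RegressionTree;

final class RowProbeSketch {
    static double[] scoreAll(final Matrix x, final RegressionTree tree) {
        final double[] scores = new double[x.numRows()];
        final Vector probe = x.rowVector();  // single reusable buffer
        for (int i = 0; i < scores.length; i++) {
            x.getRow(i, probe);              // refilled in place, no new allocation
            scores[i] = tree.predict(probe);
        }
        return scores;
    }
}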
@@ -398,7 +424,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
/**
* Train L-k tree boost.
*/
- private void traink(final double[][] x, final int[] y, final int k) throws HiveException {
+ private void traink(final Matrix x, final int[] y, final int k) throws HiveException {
final int numVars = SmileExtUtils.computeNumInputVars(_numVars, x);
if (logger.isInfoEnabled()) {
logger.info("k: " + k + ", numTrees: " + _numTrees + ", shirinkage: " + _eta
@@ -407,14 +433,14 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
+ ", maxLeafs: " + _maxLeafNodes + ", seed: " + _seed);
}
- final int numInstances = x.length;
+ final int numInstances = x.numRows();
final int numSamples = (int) Math.round(numInstances * _subsample);
final double[][] h = new double[k][numInstances]; // boost tree output.
final double[][] p = new double[k][numInstances]; // posteriori probabilities.
final double[][] response = new double[k][numInstances]; // pseudo response.
- final int[][] order = SmileExtUtils.sort(_attributes, x);
+ final ColumnMajorIntMatrix order = SmileExtUtils.sort(_attributes, x);
final RegressionTree.NodeOutput[] output = new LKNodeOutput[k];
for (int i = 0; i < k; i++) {
output[i] = new LKNodeOutput(response[i], k);
@@ -422,19 +448,16 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
final BitSet sampled = new BitSet(numInstances);
final int[] bag = new int[numSamples];
- final int[] perm = new int[numSamples];
- for (int i = 0; i < numSamples; i++) {
- perm[i] = i;
- }
+ final int[] perm = MathUtils.permutation(numInstances);
long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
- : new smile.math.Random(_seed).nextLong();
- final smile.math.Random rnd1 = new smile.math.Random(s);
- final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
+ : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+ final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+ final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
// out-of-bag prediction
final int[] prediction = new int[numInstances];
-
+ final Vector xProbe = x.rowVector();
for (int m = 0; m < _numTrees; m++) {
for (int i = 0; i < numInstances; i++) {
double max = Double.NEGATIVE_INFINITY;
@@ -490,7 +513,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
trees[j] = tree;
for (int i = 0; i < numInstances; i++) {
- double h_ji = h_j[i] + _eta * tree.predict(x[i]);
+ x.getRow(i, xProbe);
+ double h_ji = h_j[i] + _eta * tree.predict(xProbe);
h_j[i] = h_ji; // h_ji already includes the old h_j[i]; "+=" would double-count it
if (h_ji > max_h) {
max_h = h_ji;
@@ -524,7 +548,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
*/
private void forward(final int m, final double intercept, final double shrinkage,
final float oobErrorRate, @Nonnull final RegressionTree... trees) throws HiveException {
- Text[] models = getModel(trees, _outputType);
+ Text[] models = getModel(trees);
double[] importance = new double[_attributes.length];
for (RegressionTree tree : trees) {
@@ -534,14 +558,13 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
}
}
- Object[] forwardObjs = new Object[7];
+ Object[] forwardObjs = new Object[6];
forwardObjs[0] = new IntWritable(m);
- forwardObjs[1] = new IntWritable(_outputType.getId());
- forwardObjs[2] = models;
- forwardObjs[3] = new DoubleWritable(intercept);
- forwardObjs[4] = new DoubleWritable(shrinkage);
- forwardObjs[5] = WritableUtils.toWritableList(importance);
- forwardObjs[6] = new FloatWritable(oobErrorRate);
+ forwardObjs[1] = models;
+ forwardObjs[2] = new DoubleWritable(intercept);
+ forwardObjs[3] = new DoubleWritable(shrinkage);
+ forwardObjs[4] = WritableUtils.toWritableList(importance);
+ forwardObjs[5] = new FloatWritable(oobErrorRate);
forward(forwardObjs);
@@ -551,67 +574,14 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
logger.info("Forwarded the output of " + m + "-th Boosting iteration out of " + _numTrees);
}
- private static Text[] getModel(@Nonnull final RegressionTree[] trees,
- @Nonnull final ModelType outputType) throws HiveException {
+ @Nonnull
+ private static Text[] getModel(@Nonnull final RegressionTree[] trees) throws HiveException {
final int m = trees.length;
final Text[] models = new Text[m];
- switch (outputType) {
- case serialization:
- case serialization_compressed: {
- for (int i = 0; i < m; i++) {
- byte[] b = trees[i].predictSerCodegen(outputType.isCompressed());
- b = Base91.encode(b);
- models[i] = new Text(b);
- }
- break;
- }
- case opscode:
- case opscode_compressed: {
- for (int i = 0; i < m; i++) {
- String s = trees[i].predictOpCodegen(StackMachine.SEP);
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- models[i] = new Text(b);
- } else {
- models[i] = new Text(s);
- }
- }
- break;
- }
- case javascript:
- case javascript_compressed: {
- for (int i = 0; i < m; i++) {
- String s = trees[i].predictJsCodegen();
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- models[i] = new Text(b);
- } else {
- models[i] = new Text(s);
- }
- }
- break;
- }
- default:
- throw new HiveException("Unexpected output type: " + outputType
- + ". Use javascript for the output instead");
+ for (int i = 0; i < m; i++) {
+ byte[] b = trees[i].predictSerCodegen(true);
+ b = Base91.encode(b);
+ models[i] = new Text(b);
}
return models;
}
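With the opscode and javascript outputs removed, every forwarded model is now the compressed serialized tree, Base91-encoded into a Text. The decode side would be roughly the mirror image; note that Base91#decode is an assumption here (only encode(byte[]) appears in this diff), so treat this as a sketch rather than the actual prediction-side code.

import hivemall.utils.codec.Base91;
import org.apache.hadoop.io.Text;

final class ModelCodecSketch {
    static byte[] decodeModel(final Text model) {
        // Text's backing array may be over-allocated, so copy only getLength() bytes.
        final byte[] b = new byte[model.getLength()];
        System.arraycopy(model.getBytes(), 0, b, 0, b.length);
        return Base91.decode(b); // assumed inverse of Base91.encode(byte[])
    }
}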
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
new file mode 100644
index 0000000..84ef244
--- /dev/null
+++ b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.classification;
+
+import javax.annotation.Nonnull;
+
+public interface PredictionHandler {
+
+ void handle(int output, @Nonnull double[] posteriori);
+
+}
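The new PredictionHandler interface above only fixes the callback contract; nothing in this commit implements it yet. A hypothetical implementation that averages per-tree posteriors, the kind of aggregation a forest-level predictor could plug in, might look like:

import javax.annotation.Nonnull;

public final class AveragingHandler implements PredictionHandler {

    private final double[] sumPosteriori;

    public AveragingHandler(int numClasses) {
        this.sumPosteriori = new double[numClasses];
    }

    @Override
    public void handle(int output, @Nonnull double[] posteriori) {
        for (int i = 0; i < posteriori.length; i++) {
            sumPosteriori[i] += posteriori[i];
        }
    }

    // Class with the highest summed posterior over the trees seen so far.
    public int currentPrediction() {
        int best = 0;
        for (int i = 1; i < sumPosteriori.length; i++) {
            if (sumPosteriori[i] > sumPosteriori[best]) {
                best = i;
            }
        }
        return best;
    }
}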
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 03db65c..5a831df 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -19,30 +19,43 @@
package hivemall.smile.classification;
import hivemall.UDTFWithOptions;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.MatrixUtils;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.matrix.ints.DoKIntMatrix;
+import hivemall.math.matrix.ints.IntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.classification.DecisionTree.SplitRule;
import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
-import hivemall.smile.vm.StackMachine;
import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
-import hivemall.utils.io.IOUtils;
+import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.RandomUtils;
-import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.BitSet;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
@@ -52,7 +65,9 @@ import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
@@ -67,9 +82,9 @@ import org.apache.hadoop.mapred.Reporter;
@Description(
name = "train_randomforest_classifier",
- value = "_FUNC_(double[] features, int label [, string options]) - "
+ value = "_FUNC_(array<double|string> features, int label [, const array<double> classWeights, const string options]) - "
+ "Returns a relation consists of "
- + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
+ + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests, double weight>")
public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(RandomForestClassifierUDTF.class);
@@ -77,8 +92,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private PrimitiveObjectInspector featureElemOI;
private PrimitiveObjectInspector labelOI;
- private List<double[]> featuresList;
+ private boolean denseInput;
+ private MatrixBuilder matrixBuilder;
private IntArrayList labels;
+
/**
* The number of trees for each task
*/
@@ -99,8 +116,12 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
private int _minSamplesLeaf;
private long _seed;
private Attribute[] _attributes;
- private ModelType _outputType;
private SplitRule _splitRule;
+ private boolean _stratifiedSampling;
+ private double _subsample;
+
+ @Nullable
+ private double[] _classWeight;
@Nullable
private Reporter _progressReporter;
@@ -126,11 +147,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
- opts.addOption("output", "output_type", true,
- "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
opts.addOption("rule", "split_rule", true, "Split algorithm [default: GINI, ENTROPY]");
- opts.addOption("disable_compression", false,
- "Whether to disable compression of the output script [default: false]");
+ opts.addOption("stratified", "stratified_sampling", false,
+ "Enable Stratified sampling for unbalanced data");
+ opts.addOption("subsample", true, "Sampling rate in range (0.0,1.0]");
return opts;
}
@@ -141,9 +161,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
float numVars = -1.f;
Attribute[] attrs = null;
long seed = -1L;
- String output = "serialization";
SplitRule splitRule = SplitRule.GINI;
- boolean compress = true;
+ double[] classWeight = null;
+ boolean stratifiedSampling = false;
+ double subsample = 1.0d;
CommandLine cl = null;
if (argOIs.length >= 3) {
@@ -162,10 +183,26 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
- output = cl.getOptionValue("output", output);
splitRule = SmileExtUtils.resolveSplitRule(cl.getOptionValue("split_rule", "GINI"));
- if (cl.hasOption("disable_compression")) {
- compress = false;
+ stratifiedSampling = cl.hasOption("stratified_sampling");
+ subsample = Primitives.parseDouble(cl.getOptionValue("subsample"), 1.0d);
+ Preconditions.checkArgument(subsample > 0.d && subsample <= 1.0d,
+ UDFArgumentException.class, "Invalid -subsample value: " + subsample);
+
+ if (argOIs.length >= 4) {
+ classWeight = HiveUtils.getConstDoubleArray(argOIs[3]);
+ if (classWeight != null) {
+ for (int i = 0; i < classWeight.length; i++) {
+ double v = classWeight[i];
+ if (Double.isNaN(v)) {
+ classWeight[i] = 1.0d;
+ } else if (v <= 0.d) {
+ throw new UDFArgumentTypeException(3,
+ "each classWeight must be greather than 0: "
+ + Arrays.toString(classWeight));
+ }
+ }
+ }
}
}
@@ -177,43 +214,60 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
this._attributes = attrs;
- this._outputType = ModelType.resolve(output, compress);
this._splitRule = splitRule;
+ this._stratifiedSampling = stratifiedSampling;
+ this._subsample = subsample;
+ this._classWeight = classWeight;
return cl;
}
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- if (argOIs.length != 2 && argOIs.length != 3) {
+ if (argOIs.length < 2 || argOIs.length > 4) {
throw new UDFArgumentException(
- getClass().getSimpleName()
- + " takes 2 or 3 arguments: double[] features, int label [, const string options]: "
+ "_FUNC_ takes 2 ~ 4 arguments: array<double|string> features, int label [, const string options, const array<double> classWeight]: "
+ argOIs.length);
}
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
- this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ if (HiveUtils.isNumberOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ this.denseInput = true;
+ this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
+ } else if (HiveUtils.isStringOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asStringOI(elemOI);
+ this.denseInput = false;
+ this.matrixBuilder = new CSRMatrixBuilder(8192);
+ } else {
+ throw new UDFArgumentException(
+ "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
+ }
this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
processOptions(argOIs);
- this.featuresList = new ArrayList<double[]>(1024);
this.labels = new IntArrayList(1024);
- ArrayList<String> fieldNames = new ArrayList<String>(6);
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
+ final ArrayList<String> fieldNames = new ArrayList<String>(6);
+ final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
fieldNames.add("model_id");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
- fieldNames.add("model_type");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
- fieldNames.add("pred_model");
+ fieldNames.add("model_weight");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("model");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
fieldNames.add("var_importance");
- fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ if (denseInput) {
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ } else {
+ fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.writableIntObjectInspector,
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ }
fieldNames.add("oob_errors");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("oob_tests");
@@ -227,13 +281,36 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
if (args[0] == null) {
throw new HiveException("array<double> features was null");
}
- double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
+ parseFeatures(args[0], matrixBuilder);
int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
-
- featuresList.add(features);
labels.add(label);
}
+ private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
+ if (denseInput) {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+ builder.nextColumn(i, v);
+ }
+ } else {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ String fv = o.toString();
+ builder.nextColumn(fv);
+ }
+ }
+ builder.nextRow();
+ }
+
@Override
public void close() throws HiveException {
this._progressReporter = getReporter();
@@ -242,10 +319,9 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
"finishedTreeBuildTasks");
reportProgress(_progressReporter);
- int numExamples = featuresList.size();
- if (numExamples > 0) {
- double[][] x = featuresList.toArray(new double[numExamples][]);
- this.featuresList = null;
+ if (!labels.isEmpty()) {
+ Matrix x = matrixBuilder.buildMatrix();
+ this.matrixBuilder = null;
int[] y = labels.toArray();
this.labels = null;
@@ -277,15 +353,16 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
* @param numVars The number of variables to pick up in each node.
* @param seed The seed number for Random Forest
*/
- private void train(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException {
- if (x.length != y.length) {
+ private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException {
+ final int numExamples = x.numRows();
+ if (numExamples != y.length) {
throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
- x.length, y.length));
+ numExamples, y.length));
}
checkOptions();
- // Shuffle training samples
- SmileExtUtils.shuffle(x, y, _seed);
+ // Shuffle training samples
+ x = SmileExtUtils.shuffle(x, y, _seed);
int[] labels = SmileExtUtils.classLables(y);
Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
@@ -297,9 +374,8 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
+ _maxLeafNodes + ", splitRule: " + _splitRule + ", seed: " + _seed);
}
- final int numExamples = x.length;
- int[][] prediction = new int[numExamples][labels.length]; // placeholder for out-of-bag prediction
- int[][] order = SmileExtUtils.sort(attributes, x);
+ IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction
+ ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x);
AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
List<TrainingTask> tasks = new ArrayList<TrainingTask>();
for (int i = 0; i < _numTrees; i++) {
@@ -321,17 +397,19 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
/**
* Synchronized because {@link #forward(Object)} should be called from a single thread.
+ *
+ * @param accuracy the out-of-bag accuracy of the tree trained by this task
*/
synchronized void forward(final int taskId, @Nonnull final Text model,
- @Nonnull final double[] importance, final int[] y, final int[][] prediction,
- final boolean lastTask) throws HiveException {
+ @Nonnull final Vector importance, @Nonnegative final double accuracy, final int[] y,
+ @Nonnull final IntMatrix prediction, final boolean lastTask) throws HiveException {
int oobErrors = 0;
int oobTests = 0;
if (lastTask) {
// out-of-bag error estimate
for (int i = 0; i < y.length; i++) {
- final int pred = smile.math.Math.whichMax(prediction[i]);
- if (prediction[i][pred] > 0) {
+ final int pred = MatrixUtils.whichMax(prediction, i);
+ if (pred != -1 && prediction.get(i, pred) > 0) {
oobTests++;
if (pred != y[i]) {
oobErrors++;
@@ -340,12 +418,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
}
}
- String modelId = RandomUtils.getUUID();
final Object[] forwardObjs = new Object[6];
+ String modelId = RandomUtils.getUUID();
forwardObjs[0] = new Text(modelId);
- forwardObjs[1] = new IntWritable(_outputType.getId());
+ forwardObjs[1] = new DoubleWritable(accuracy);
forwardObjs[2] = model;
- forwardObjs[3] = WritableUtils.toWritableList(importance);
+ if (denseInput) {
+ forwardObjs[3] = WritableUtils.toWritableList(importance.toArray());
+ } else {
+ final Map<IntWritable, DoubleWritable> map = new HashMap<IntWritable, DoubleWritable>(
+ importance.size());
+ importance.each(new VectorProcedure() {
+ public void apply(int i, double value) {
+ map.put(new IntWritable(i), new DoubleWritable(value));
+ }
+ });
+ forwardObjs[3] = map;
+ }
forwardObjs[4] = new IntWritable(oobErrors);
forwardObjs[5] = new IntWritable(oobTests);
forward(forwardObjs);
@@ -363,20 +452,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
/**
* Attribute properties.
*/
+ @Nonnull
private final Attribute[] _attributes;
/**
* Training instances.
*/
- private final double[][] _x;
+ @Nonnull
+ private final Matrix _x;
/**
* Training sample labels.
*/
+ @Nonnull
private final int[] _y;
/**
- * The index of training values in ascending order. Note that only numeric attributes will
- * be sorted.
+ * The index of training values in ascending order. Note that only numeric attributes will be sorted.
*/
- private final int[][] _order;
+ @Nonnull
+ private final ColumnMajorIntMatrix _order;
/**
* The number of variables to pick up in each node.
*/
@@ -384,16 +476,21 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
/**
* The out-of-bag predictions.
*/
- private final int[][] _prediction;
+ @Nonnull
+ @GuardedBy("_udtf")
+ private final IntMatrix _prediction;
+ @Nonnull
private final RandomForestClassifierUDTF _udtf;
private final int _taskId;
private final long _seed;
+ @Nonnull
private final AtomicInteger _remainingTasks;
- TrainingTask(RandomForestClassifierUDTF udtf, int taskId, Attribute[] attributes,
- double[][] x, int[] y, int numVars, int[][] order, int[][] prediction, long seed,
- AtomicInteger remainingTasks) {
+ TrainingTask(@Nonnull RandomForestClassifierUDTF udtf, int taskId,
+ @Nonnull Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y, int numVars,
+ @Nonnull ColumnMajorIntMatrix order, @Nonnull IntMatrix prediction, long seed,
+ @Nonnull AtomicInteger remainingTasks) {
this._udtf = udtf;
this._taskId = taskId;
this._attributes = attributes;
@@ -408,98 +505,107 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
@Override
public Integer call() throws HiveException {
- long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() : new smile.math.Random(
- _seed).nextLong();
- final smile.math.Random rnd1 = new smile.math.Random(s);
- final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
- final int N = _x.length;
+ long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
+ : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+ final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+ final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+ final int N = _x.numRows();
// Training samples draw with replacement.
- final int[] bags = new int[N];
final BitSet sampled = new BitSet(N);
- for (int i = 0; i < N; i++) {
- int index = rnd1.nextInt(N);
- bags[i] = index;
- sampled.set(index);
- }
+ final int[] bags = sampling(sampled, N, rnd1);
DecisionTree tree = new DecisionTree(_attributes, _x, _y, _numVars, _udtf._maxDepth,
_udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf, bags, _order,
_udtf._splitRule, rnd2);
// out-of-bag prediction
+ int oob = 0;
+ int correct = 0;
+ final Vector xProbe = _x.rowVector();
for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
- final int p = tree.predict(_x[i]);
- synchronized (_prediction[i]) {
- _prediction[i][p]++;
+ oob++;
+ _x.getRow(i, xProbe);
+ final int p = tree.predict(xProbe);
+ if (p == _y[i]) {
+ correct++;
+ }
+ synchronized (_udtf) {
+ _prediction.incr(i, p);
}
}
- Text model = getModel(tree, _udtf._outputType);
- double[] importance = tree.importance();
+ Text model = getModel(tree);
+ Vector importance = tree.importance();
+ double accuracy = (oob == 0) ? 1.0d : (double) correct / oob;
int remain = _remainingTasks.decrementAndGet();
boolean lastTask = (remain == 0);
- _udtf.forward(_taskId + 1, model, importance, _y, _prediction, lastTask);
+ _udtf.forward(_taskId + 1, model, importance, accuracy, _y, _prediction, lastTask);
return Integer.valueOf(remain);
}
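Note the two-level PRNG derivation at the top of call(): rnd1 (used for bagging) is seeded from s, and rnd2 (used inside DecisionTree) from rnd1.nextLong(), so the two uses are decorrelated yet fully reproducible for a fixed -seed. The pattern, extracted as a sketch using only calls that appear in this hunk:

import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.smile.utils.SmileExtUtils;

final class SeedChainSketch {
    static PRNG[] chainedGenerators(final long seed) {
        final long s = (seed == -1L) ? SmileExtUtils.generateSeed()
                : RandomNumberGeneratorFactory.createPRNG(seed).nextLong();
        final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);               // bagging
        final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong()); // tree building
        return new PRNG[] {rnd1, rnd2};
    }
}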
- private static Text getModel(@Nonnull final DecisionTree tree,
- @Nonnull final ModelType outputType) throws HiveException {
- final Text model;
- switch (outputType) {
- case serialization:
- case serialization_compressed: {
- byte[] b = tree.predictSerCodegen(outputType.isCompressed());
- b = Base91.encode(b);
- model = new Text(b);
- break;
- }
- case opscode:
- case opscode_compressed: {
- String s = tree.predictOpCodegen(StackMachine.SEP);
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- model = new Text(b);
- } else {
- model = new Text(s);
+ @Nonnull
+ private int[] sampling(@Nonnull final BitSet sampled, final int N, @Nonnull PRNG rnd) {
+ return _udtf._stratifiedSampling ? stratifiedSampling(sampled, N, _udtf._subsample, rnd)
+ : uniformSampling(sampled, N, _udtf._subsample, rnd);
+ }
+
+ @Nonnull
+ private static int[] uniformSampling(@Nonnull final BitSet sampled, final int N,
+ final double subsample, final PRNG rnd) {
+ final int size = (int) Math.round(N * subsample);
+ final int[] bags = new int[size]; // sized to the draw count; int[N] would pad with index-0 entries when subsample < 1
+ for (int i = 0; i < size; i++) {
+ int index = rnd.nextInt(N);
+ bags[i] = index;
+ sampled.set(index);
+ }
+ return bags;
+ }
+
+ /**
+ * Stratified sampling for unbalanced data.
+ *
+ * @see <a href="https://en.wikipedia.org/wiki/Stratified_sampling">Stratified sampling</a>
+ */
+ @Nonnull
+ private int[] stratifiedSampling(@Nonnull final BitSet sampled, final int N,
+ final double subsample, final PRNG rnd) {
+ final IntArrayList bagsList = new IntArrayList(N);
+ final int k = smile.math.Math.max(_y) + 1;
+ final IntArrayList cj = new IntArrayList(N / k);
+ for (int l = 0; l < k; l++) {
+ int nj = 0;
+ for (int i = 0; i < N; i++) {
+ if (_y[i] == l) {
+ cj.add(i);
+ nj++;
}
- break;
}
- case javascript:
- case javascript_compressed: {
- String s = tree.predictJsCodegen();
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- model = new Text(b);
- } else {
- model = new Text(s);
- }
- break;
+ if (subsample != 1.0d) {
+ nj = (int) Math.round(nj * subsample);
+ }
+ final int size = (_udtf._classWeight == null) ? nj : (int) Math.round(nj
+ * _udtf._classWeight[l]);
+ for (int j = 0; j < size; j++) {
+ int xi = rnd.nextInt(nj);
+ int index = cj.get(xi);
+ bagsList.add(index);
+ sampled.set(index);
}
- default:
- throw new HiveException("Unexpected output type: " + outputType
- + ". Use javascript for the output instead");
+ cj.clear();
}
- return model;
+ int[] bags = bagsList.toArray(true);
+ SmileExtUtils.shuffle(bags, rnd);
+ return bags;
+ }
+
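A worked example of the stratified path above: with N = 1000 rows split 900/100 between classes 0 and 1, -subsample 0.8 and classWeight = [1.0, 3.0], class 0 contributes round(900 * 0.8) = 720 bag entries and class 1 contributes round(round(100 * 0.8) * 3.0) = 240, each drawn with replacement from that class's index list; the combined bag is then shuffled before tree construction.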
+ @Nonnull
+ private static Text getModel(@Nonnull final DecisionTree tree) throws HiveException {
+ byte[] b = tree.predictSerCodegen(true);
+ b = Base91.encode(b);
+ return new Text(b);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/data/Attribute.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/data/Attribute.java b/core/src/main/java/hivemall/smile/data/Attribute.java
index be6651a..6569726 100644
--- a/core/src/main/java/hivemall/smile/data/Attribute.java
+++ b/core/src/main/java/hivemall/smile/data/Attribute.java
@@ -18,6 +18,9 @@
*/
package hivemall.smile.data;
+import hivemall.annotations.Immutable;
+import hivemall.annotations.Mutable;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
@@ -25,11 +28,9 @@ import java.io.ObjectOutput;
public abstract class Attribute {
public final AttributeType type;
- public final int attrIndex;
- Attribute(AttributeType type, int attrIndex) {
+ Attribute(AttributeType type) {
this.type = type;
- this.attrIndex = attrIndex;
}
public void setSize(int size) {
@@ -44,24 +45,23 @@ public abstract class Attribute {
}
public void writeTo(ObjectOutput out) throws IOException {
- out.writeInt(type.getTypeId());
- out.writeInt(attrIndex);
+ out.writeByte(type.getTypeId());
}
public enum AttributeType {
- NUMERIC(1), NOMINAL(2);
+ NUMERIC((byte) 1), NOMINAL((byte) 2);
- private final int id;
+ private final byte id;
- private AttributeType(int id) {
+ private AttributeType(byte id) {
this.id = id;
}
- public int getTypeId() {
+ public byte getTypeId() {
return id;
}
- public static AttributeType resolve(int id) {
+ public static AttributeType resolve(byte id) {
final AttributeType type;
switch (id) {
case 1:
@@ -78,25 +78,27 @@ public abstract class Attribute {
}
+ @Immutable
public static final class NumericAttribute extends Attribute {
- public NumericAttribute(int attrIndex) {
- super(AttributeType.NUMERIC, attrIndex);
+ public NumericAttribute() {
+ super(AttributeType.NUMERIC);
}
@Override
public String toString() {
- return "NumericAttribute [type=" + type + ", attrIndex=" + attrIndex + "]";
+ return "NumericAttribute [type=" + type + "]";
}
}
+ @Mutable
public static final class NominalAttribute extends Attribute {
private int size;
- public NominalAttribute(int attrIndex) {
- super(AttributeType.NOMINAL, attrIndex);
+ public NominalAttribute() {
+ super(AttributeType.NOMINAL);
this.size = -1;
}
@@ -118,25 +120,23 @@ public abstract class Attribute {
@Override
public String toString() {
- return "NominalAttribute [size=" + size + ", type=" + type + ", attrIndex=" + attrIndex
- + "]";
+ return "NominalAttribute [size=" + size + ", type=" + type + "]";
}
}
public static Attribute readFrom(ObjectInput in) throws IOException {
- int typeId = in.readInt();
- int attrIndex = in.readInt();
-
final Attribute attr;
+
+ byte typeId = in.readByte();
final AttributeType type = AttributeType.resolve(typeId);
switch (type) {
case NUMERIC: {
- attr = new NumericAttribute(attrIndex);
+ attr = new NumericAttribute();
break;
}
case NOMINAL: {
- attr = new NominalAttribute(attrIndex);
+ attr = new NominalAttribute();
int size = in.readInt();
attr.setSize(size);
break;
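Shrinking the type id from writeInt to writeByte saves 3 bytes per attribute, and dropping attrIndex saves another 4. A round-trip sketch of the new wire format, assuming NominalAttribute#writeTo also emits its size as readFrom() expects (that override is not shown in this hunk):

import hivemall.smile.data.Attribute;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

final class AttributeSerdeSketch {
    static Attribute roundTrip(final Attribute attr) throws IOException {
        final ByteArrayOutputStream bos = new ByteArrayOutputStream();
        final ObjectOutputStream oos = new ObjectOutputStream(bos);
        attr.writeTo(oos); // writes the one-byte type id (+ size for NOMINAL)
        oos.flush();
        final ObjectInputStream ois = new ObjectInputStream(
            new ByteArrayInputStream(bos.toByteArray()));
        return Attribute.readFrom(ois);
    }
}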
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index ebb58c6..557df21 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -19,22 +19,25 @@
package hivemall.smile.regression;
import hivemall.UDTFWithOptions;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.Vector;
import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
-import hivemall.smile.vm.StackMachine;
import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
-import hivemall.utils.collections.DoubleArrayList;
+import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
-import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.RandomUtils;
-import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
@@ -42,6 +45,7 @@ import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -69,7 +73,7 @@ import org.apache.hadoop.mapred.Reporter;
@Description(
name = "train_randomforest_regression",
- value = "_FUNC_(double[] features, double target [, string options]) - "
+ value = "_FUNC_(array<double|string> features, double target [, string options]) - "
+ "Returns a relation consists of "
+ "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>")
public final class RandomForestRegressionUDTF extends UDTFWithOptions {
@@ -79,7 +83,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private PrimitiveObjectInspector featureElemOI;
private PrimitiveObjectInspector targetOI;
- private List<double[]> featuresList;
+ private boolean denseInput;
+ private MatrixBuilder matrixBuilder;
private DoubleArrayList targets;
/**
* The number of trees for each task
@@ -101,7 +106,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private int _minSamplesLeaf;
private long _seed;
private Attribute[] _attributes;
- private ModelType _outputType;
@Nullable
private Reporter _progressReporter;
@@ -131,10 +135,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
+ "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
- opts.addOption("output", "output_type", true,
- "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
- opts.addOption("disable_compression", false,
- "Whether to disable compression of the output script [default: false]");
return opts;
}
@@ -145,8 +145,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
float numVars = -1.f;
Attribute[] attrs = null;
long seed = -1L;
- String output = "serialization";
- boolean compress = true;
CommandLine cl = null;
if (argOIs.length >= 3) {
@@ -165,10 +163,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
minSamplesLeaf);
seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
- output = cl.getOptionValue("output", output);
- if (cl.hasOption("disable_compression")) {
- compress = false;
- }
}
this._numTrees = trees;
@@ -179,7 +173,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
this._minSamplesLeaf = minSamplesLeaf;
this._seed = seed;
this._attributes = attrs;
- this._outputType = ModelType.resolve(output, compress);
return cl;
}
@@ -189,19 +182,29 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
if (argOIs.length != 2 && argOIs.length != 3) {
throw new UDFArgumentException(
getClass().getSimpleName()
- + " takes 2 or 3 arguments: double[] features, double target [, const string options]: "
+ + " takes 2 or 3 arguments: array<double|string> features, double target [, const string options]: "
+ argOIs.length);
}
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
- this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ if (HiveUtils.isNumberOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ this.denseInput = true;
+ this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
+ } else if (HiveUtils.isStringOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asStringOI(elemOI);
+ this.denseInput = false;
+ this.matrixBuilder = new CSRMatrixBuilder(8192);
+ } else {
+ throw new UDFArgumentException(
+ "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
+ }
this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
processOptions(argOIs);
- this.featuresList = new ArrayList<double[]>(1024);
this.targets = new DoubleArrayList(1024);
ArrayList<String> fieldNames = new ArrayList<String>(5);
@@ -209,8 +212,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
fieldNames.add("model_id");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
- fieldNames.add("model_type");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("model_err");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("pred_model");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
fieldNames.add("var_importance");
@@ -228,13 +231,36 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
if (args[0] == null) {
throw new HiveException("array<double> features was null");
}
- double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
+ parseFeatures(args[0], matrixBuilder);
double target = PrimitiveObjectInspectorUtils.getDouble(args[1], targetOI);
-
- featuresList.add(features);
targets.add(target);
}
+ private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
+ if (denseInput) {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+ builder.nextColumn(i, v);
+ }
+ } else {
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ String fv = o.toString();
+ builder.nextColumn(fv);
+ }
+ }
+ builder.nextRow();
+ }
+
@Override
public void close() throws HiveException {
this._progressReporter = getReporter();
@@ -250,10 +276,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
reportProgress(_progressReporter);
- int numExamples = featuresList.size();
- if (numExamples > 0) {
- double[][] x = featuresList.toArray(new double[numExamples][]);
- this.featuresList = null;
+ if (!targets.isEmpty()) {
+ Matrix x = matrixBuilder.buildMatrix();
+ this.matrixBuilder = null;
double[] y = targets.toArray();
this.targets = null;
@@ -285,15 +310,16 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
* @param _numVars The number of variables to pick up in each node.
* @param _seed The seed number for Random Forest
*/
- private void train(@Nonnull final double[][] x, @Nonnull final double[] y) throws HiveException {
- if (x.length != y.length) {
+ private void train(@Nonnull Matrix x, @Nonnull final double[] y) throws HiveException {
+ final int numExamples = x.numRows();
+ if (numExamples != y.length) {
throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
- x.length, y.length));
+ numExamples, y.length));
}
checkOptions();
- // Shuffle training samples
- SmileExtUtils.shuffle(x, y, _seed);
+ // Shuffle training samples
+ x = SmileExtUtils.shuffle(x, y, _seed);
Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);
@@ -305,10 +331,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
+ ", seed: " + _seed);
}
- int numExamples = x.length;
double[] prediction = new double[numExamples]; // placeholder for out-of-bag prediction
int[] oob = new int[numExamples];
- int[][] order = SmileExtUtils.sort(attributes, x);
+ ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x);
AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
List<TrainingTask> tasks = new ArrayList<TrainingTask>();
for (int i = 0; i < _numTrees; i++) {
@@ -330,10 +355,13 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
/**
* Synchronized because {@link #forward(Object)} should be called from a single thread.
+ *
+ * @param error the mean absolute error of this tree's out-of-bag predictions
*/
synchronized void forward(final int taskId, @Nonnull final Text model,
- @Nonnull final double[] importance, final double[] y, final double[] prediction,
- final int[] oob, final boolean lastTask) throws HiveException {
+ @Nonnull final double[] importance, @Nonnegative final double error, final double[] y,
+ final double[] prediction, final int[] oob, final boolean lastTask)
+ throws HiveException {
double oobErrors = 0.d;
int oobTests = 0;
if (lastTask) {
@@ -349,7 +377,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
String modelId = RandomUtils.getUUID();
final Object[] forwardObjs = new Object[6];
forwardObjs[0] = new Text(modelId);
- forwardObjs[1] = new IntWritable(_outputType.getId());
+ forwardObjs[1] = new DoubleWritable(error);
forwardObjs[2] = model;
forwardObjs[3] = WritableUtils.toWritableList(importance);
forwardObjs[4] = new DoubleWritable(oobErrors);
@@ -373,16 +401,15 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
/**
* Training instances.
*/
- private final double[][] _x;
+ private final Matrix _x;
/**
* Training sample labels.
*/
private final double[] _y;
/**
- * The index of training values in ascending order. Note that only numeric attributes will
- * be sorted.
+ * The index of training values in ascending order. Note that only numeric attributes will be sorted.
*/
- private final int[][] _order;
+ private final ColumnMajorIntMatrix _order;
/**
* The number of variables to pick up in each node.
*/
@@ -401,8 +428,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
private final long _seed;
private final AtomicInteger _remainingTasks;
- TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes,
- double[][] x, double[] y, int numVars, int[][] order, double[] prediction,
+ TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes, Matrix x,
+ double[] y, int numVars, ColumnMajorIntMatrix order, double[] prediction,
int[] oob, long seed, AtomicInteger remainingTasks) {
this._udtf = udtf;
this._taskId = taskId;
@@ -419,11 +446,11 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
@Override
public Integer call() throws HiveException {
- long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() : new smile.math.Random(
- _seed).nextLong();
- final smile.math.Random rnd1 = new smile.math.Random(s);
- final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());
- final int N = _x.length;
+ long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
+ : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong();
+ final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
+ final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
+ final int N = _x.numRows();
// Training samples draw with replacement.
final int[] bags = new int[N];
@@ -441,82 +468,40 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
incrCounter(_udtf._treeConstuctionTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));
// out-of-bag prediction
+ int oob = 0;
+ double error = 0.d;
+ final Vector xProbe = _x.rowVector();
for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
- double pred = tree.predict(_x[i]);
- synchronized (_x[i]) {
+ oob++;
+ _x.getRow(i, xProbe);
+ final double pred = tree.predict(xProbe);
+ synchronized (_prediction) {
_prediction[i] += pred;
_oob[i]++;
}
+ error += Math.abs(pred - _y[i]);
+ }
+ if (oob != 0) {
+ error /= oob;
}
stopwatch.reset().start();
- Text model = getModel(tree, _udtf._outputType);
+ Text model = getModel(tree);
double[] importance = tree.importance();
tree = null; // help GC
int remain = _remainingTasks.decrementAndGet();
boolean lastTask = (remain == 0);
- _udtf.forward(_taskId + 1, model, importance, _y, _prediction, _oob, lastTask);
+ _udtf.forward(_taskId + 1, model, importance, error, _y, _prediction, _oob, lastTask);
incrCounter(_udtf._treeSerializationTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));
return Integer.valueOf(remain);
}
- private static Text getModel(@Nonnull final RegressionTree tree,
- @Nonnull final ModelType outputType) throws HiveException {
- final Text model;
- switch (outputType) {
- case serialization:
- case serialization_compressed: {
- byte[] b = tree.predictSerCodegen(outputType.isCompressed());
- b = Base91.encode(b);
- model = new Text(b);
- break;
- }
- case opscode:
- case opscode_compressed: {
- String s = tree.predictOpCodegen(StackMachine.SEP);
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- model = new Text(b);
- } else {
- model = new Text(s);
- }
- break;
- }
- case javascript:
- case javascript_compressed: {
- String s = tree.predictJsCodegen();
- if (outputType.isCompressed()) {
- byte[] b = s.getBytes();
- final DeflateCodec codec = new DeflateCodec(true, false);
- try {
- b = codec.compress(b);
- } catch (IOException e) {
- throw new HiveException("Failed to compressing a model", e);
- } finally {
- IOUtils.closeQuietly(codec);
- }
- b = Base91.encode(b);
- model = new Text(b);
- } else {
- model = new Text(s);
- }
- break;
- }
- default:
- throw new HiveException("Unexpected output type: " + outputType
- + ". Use javascript for the output instead");
- }
- return model;
+ @Nonnull
+ private static Text getModel(@Nonnull final RegressionTree tree) throws HiveException {
+ byte[] b = tree.predictSerCodegen(true);
+ b = Base91.encode(b);
+ return new Text(b);
}
}
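For reference, a minimal, self-contained sketch of the per-tree out-of-bag error that each TrainingTask above now forwards: the mean absolute error over the rows the tree never drew into its bag. The arrays below are illustrative stand-ins, not Hivemall API.

public final class OobMaeSketch {
    // Mean absolute error over out-of-bag rows; mirrors the loop in call().
    static double oobMae(double[] pred, double[] y, boolean[] inBag) {
        double error = 0.d;
        int oob = 0;
        for (int i = 0; i < y.length; i++) {
            if (inBag[i]) {
                continue; // only score rows this tree was not trained on
            }
            error += Math.abs(pred[i] - y[i]);
            oob++;
        }
        return (oob == 0) ? 0.d : error / oob;
    }

    public static void main(String[] args) {
        double[] pred = {1.0, 2.5, 0.0};
        double[] y = {1.2, 2.0, 0.0};
        boolean[] inBag = {false, false, true};
        System.out.println(oobMae(pred, y, inBag)); // ~0.35 = (0.2 + 0.5) / 2
    }
}

The guard against oob == 0 mirrors the diff above: with bootstrap sampling it is possible, if unlikely, that every row lands in the bag.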
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/regression/RegressionTree.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index 07887c1..da7e80b 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -33,20 +33,28 @@
*/
package hivemall.smile.regression;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.data.Attribute;
import hivemall.smile.data.Attribute.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
+import hivemall.utils.collections.sets.IntArraySet;
+import hivemall.utils.collections.sets.IntSet;
import hivemall.utils.lang.ObjectUtils;
-import hivemall.utils.lang.StringUtils;
+import hivemall.utils.math.MathUtils;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.List;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
@@ -55,60 +63,48 @@ import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import smile.math.Math;
-import smile.math.Random;
import smile.regression.GradientTreeBoost;
import smile.regression.RandomForest;
import smile.regression.Regression;
/**
- * Decision tree for regression. A decision tree can be learned by splitting the training set into
- * subsets based on an attribute value test. This process is repeated on each derived subset in a
- * recursive manner called recursive partitioning.
+ * Decision tree for regression. A decision tree can be learned by splitting the training set into subsets based on an attribute value test. This
+ * process is repeated on each derived subset in a recursive manner called recursive partitioning.
* <p>
- * Classification and Regression Tree techniques have a number of advantages over many of those
- * alternative techniques.
+ * Classification and Regression Tree techniques have a number of advantages over many of those alternative techniques.
* <dl>
* <dt>Simple to understand and interpret.</dt>
- * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This
- * simplicity is useful not only for purposes of rapid classification of new observations, but can
- * also often yield a much simpler "model" for explaining why observations are classified or
- * predicted in a particular manner.</dd>
+ * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This simplicity is useful not only for purposes of rapid
+ * classification of new observations, but can also often yield a much simpler "model" for explaining why observations are classified or predicted in
+ * a particular manner.</dd>
* <dt>Able to handle both numerical and categorical data.</dt>
- * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of
- * variable.</dd>
+ * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of variable.</dd>
* <dt>Tree methods are nonparametric and nonlinear.</dt>
- * <dd>The final results of using tree methods for classification or regression can be summarized in
- * a series of (usually few) logical if-then conditions (tree nodes). Therefore, there is no
- * implicit assumption that the underlying relationships between the predictor variables and the
- * dependent variable are linear, follow some specific non-linear link function, or that they are
- * even monotonic in nature. Thus, tree methods are particularly well suited for data mining tasks,
- * where there is often little a priori knowledge nor any coherent set of theories or predictions
- * regarding which variables are related and how. In those types of data analytics, tree methods can
- * often reveal simple relationships between just a few variables that could have easily gone
- * unnoticed using other analytic techniques.</dd>
+ * <dd>The final results of using tree methods for classification or regression can be summarized in a series of (usually few) logical if-then
+ * conditions (tree nodes). Therefore, there is no implicit assumption that the underlying relationships between the predictor variables and the
+ * dependent variable are linear, follow some specific non-linear link function, or that they are even monotonic in nature. Thus, tree methods are
+ * particularly well suited for data mining tasks, where there is often little a priori knowledge nor any coherent set of theories or predictions
+ * regarding which variables are related and how. In those types of data analytics, tree methods can often reveal simple relationships between just a
+ * few variables that could have easily gone unnoticed using other analytic techniques.</dd>
* </dl>
- * One major problem with classification and regression trees is their high variance. Often a small
- * change in the data can result in a very different series of splits, making interpretation
- * somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause
- * over-fitting. Mechanisms such as pruning are necessary to avoid this problem. Another limitation
- * of trees is the lack of smoothness of the prediction surface.
+ * One major problem with classification and regression trees is their high variance. Often a small change in the data can result in a very different
+ * series of splits, making interpretation somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause over-fitting.
+ * Mechanisms such as pruning are necessary to avoid this problem. Another limitation of trees is the lack of smoothness of the prediction surface.
* <p>
- * Some techniques such as bagging, boosting, and random forest use more than one decision tree for
- * their analysis.
+ * Some techniques such as bagging, boosting, and random forest use more than one decision tree for their analysis.
*
* @see GradientTreeBoost
* @see RandomForest
*/
-public final class RegressionTree implements Regression<double[]> {
+public final class RegressionTree implements Regression<Vector> {
/**
* The attributes of independent variable.
*/
private final Attribute[] _attributes;
private final boolean _hasNumericType;
/**
- * Variable importance. Every time a split of a node is made on variable the impurity criterion
- * for the two descendant nodes is less than the parent node. Adding up the decreases for each
- * individual variable over the tree gives a simple measure of variable importance.
+ * Variable importance. Every time a split of a node is made on variable the impurity criterion for the two descendant nodes is less than the
+ * parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance.
*/
private final double[] _importance;
/**
@@ -120,8 +116,7 @@ public final class RegressionTree implements Regression<double[]> {
*/
private final int _maxDepth;
/**
- * The number of instances in a node below which the tree will not split, setting S = 5
- * generally gives good results.
+ * The number of instances in a node below which the tree will not split, setting S = 5 generally gives good results.
*/
private final int _minSplit;
/**
@@ -133,19 +128,17 @@ public final class RegressionTree implements Regression<double[]> {
*/
private final int _numVars;
/**
- * The index of training values in ascending order. Note that only numeric attributes will be
- * sorted.
+ * The index of training values in ascending order. Note that only numeric attributes will be sorted.
*/
- private final int[][] _order;
+ private final ColumnMajorIntMatrix _order;
- private final Random _rnd;
+ private final PRNG _rnd;
private final NodeOutput _nodeOutput;
/**
- * An interface to calculate node output. Note that samples[i] is the number of sampling of
- * dataset[i]. 0 means that the datum is not included and values of greater than 1 are possible
- * because of sampling with replacement.
+ * An interface to calculate node output. Note that samples[i] is the number of sampling of dataset[i]. 0 means that the datum is not included and
+ * values of greater than 1 are possible because of sampling with replacement.
*/
public interface NodeOutput {
/**
@@ -205,22 +198,30 @@ public final class RegressionTree implements Regression<double[]> {
this.output = output;
}
+ private boolean isLeaf() {
+ return trueChild == null && falseChild == null;
+ }
+
+ @VisibleForTesting
+ public double predict(@Nonnull final double[] x) {
+ return predict(new DenseVector(x));
+ }
+
/**
* Evaluate the regression tree over an instance.
*/
- public double predict(final double[] x) {
+ public double predict(@Nonnull final Vector x) {
if (trueChild == null && falseChild == null) {
return output;
} else {
if (splitFeatureType == AttributeType.NOMINAL) {
- // REVIEWME if(Math.equals(x[splitFeature], splitValue)) {
- if (x[splitFeature] == splitValue) {
+ if (x.get(splitFeature, Double.NaN) == splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
}
} else if (splitFeatureType == AttributeType.NUMERIC) {
- if (x[splitFeature] <= splitValue) {
+ if (x.get(splitFeature, Double.NaN) <= splitValue) {
return trueChild.predict(x);
} else {
return falseChild.predict(x);
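One behavioral consequence of the x.get(splitFeature, Double.NaN) default above is worth spelling out: for sparse vectors an absent feature yields NaN, and every comparison against NaN is false, so missing features always fall through to falseChild. A minimal sketch in plain Java, with no Hivemall types:

public final class NanRoutingSketch {
    public static void main(String[] args) {
        double missing = Double.NaN; // what a sparse Vector returns for an absent column
        System.out.println(missing == 0.5);  // false -> falseChild on a NOMINAL split
        System.out.println(missing <= 0.5);  // false -> falseChild on a NUMERIC split
    }
}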
@@ -283,99 +284,58 @@ public final class RegressionTree implements Regression<double[]> {
}
}
- public int opCodegen(final List<String> scripts, int depth) {
- int selfDepth = 0;
- final StringBuilder buf = new StringBuilder();
- if (trueChild == null && falseChild == null) {
- buf.append("push ").append(output);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("goto last");
- scripts.add(buf.toString());
- selfDepth += 2;
- } else {
- if (splitFeatureType == AttributeType.NOMINAL) {
- buf.append("push ").append("x[").append(splitFeature).append("]");
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("push ").append(splitValue);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("ifeq ");
- scripts.add(buf.toString());
- depth += 3;
- selfDepth += 3;
- int trueDepth = trueChild.opCodegen(scripts, depth);
- selfDepth += trueDepth;
- scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth));
- int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
- selfDepth += falseDepth;
- } else if (splitFeatureType == AttributeType.NUMERIC) {
- buf.append("push ").append("x[").append(splitFeature).append("]");
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("push ").append(splitValue);
- scripts.add(buf.toString());
- buf.setLength(0);
- buf.append("ifle ");
- scripts.add(buf.toString());
- depth += 3;
- selfDepth += 3;
- int trueDepth = trueChild.opCodegen(scripts, depth);
- selfDepth += trueDepth;
- scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth));
- int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth);
- selfDepth += falseDepth;
- } else {
- throw new IllegalStateException("Unsupported attribute type: "
- + splitFeatureType);
- }
- }
- return selfDepth;
- }
-
@Override
public void writeExternal(ObjectOutput out) throws IOException {
- out.writeDouble(output);
out.writeInt(splitFeature);
if (splitFeatureType == null) {
- out.writeInt(-1);
+ out.writeByte(-1);
} else {
- out.writeInt(splitFeatureType.getTypeId());
+ out.writeByte(splitFeatureType.getTypeId());
}
out.writeDouble(splitValue);
- if (trueChild == null) {
- out.writeBoolean(false);
- } else {
+
+ if (isLeaf()) {
out.writeBoolean(true);
- trueChild.writeExternal(out);
- }
- if (falseChild == null) {
- out.writeBoolean(false);
+ out.writeDouble(output);
} else {
- out.writeBoolean(true);
- falseChild.writeExternal(out);
+ out.writeBoolean(false);
+ if (trueChild == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ trueChild.writeExternal(out);
+ }
+ if (falseChild == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ falseChild.writeExternal(out);
+ }
}
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this.output = in.readDouble();
this.splitFeature = in.readInt();
- int typeId = in.readInt();
+ byte typeId = in.readByte();
if (typeId == -1) {
this.splitFeatureType = null;
} else {
this.splitFeatureType = AttributeType.resolve(typeId);
}
this.splitValue = in.readDouble();
- if (in.readBoolean()) {
- this.trueChild = new Node();
- trueChild.readExternal(in);
- }
- if (in.readBoolean()) {
- this.falseChild = new Node();
- falseChild.readExternal(in);
+
+ if (in.readBoolean()) {// isLeaf()
+ this.output = in.readDouble();
+ } else {
+ if (in.readBoolean()) {
+ this.trueChild = new Node();
+ trueChild.readExternal(in);
+ }
+ if (in.readBoolean()) {
+ this.falseChild = new Node();
+ falseChild.readExternal(in);
+ }
}
}
}
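The rewritten node (de)serialization above is also a size optimization: the type id shrinks from an int to a byte, and output is written only for leaves. A back-of-the-envelope check using java.io.DataOutput field sizes (this sketch is not Hivemall API):

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public final class LeafFormatSketch {
    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bos);
        out.writeInt(-1);       // splitFeature (unused in a leaf)
        out.writeByte(-1);      // splitFeatureType id: now a byte, was an int
        out.writeDouble(0.d);   // splitValue (unused in a leaf)
        out.writeBoolean(true); // isLeaf
        out.writeDouble(3.14);  // output: written for leaves only
        System.out.println(bos.size()); // 22 bytes; the old layout took 26 per leaf
    }
}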
@@ -406,7 +366,7 @@ public final class RegressionTree implements Regression<double[]> {
/**
* Training dataset.
*/
- final double[][] x;
+ final Matrix x;
/**
* Training data response value.
*/
@@ -419,7 +379,7 @@ public final class RegressionTree implements Regression<double[]> {
/**
* Constructor.
*/
- public TrainNode(Node node, double[][] x, double[] y, int[] bags, int depth) {
+ public TrainNode(Node node, Matrix x, double[] y, int[] bags, int depth) {
this.node = node;
this.x = x;
this.y = y;
@@ -452,8 +412,7 @@ public final class RegressionTree implements Regression<double[]> {
}
/**
- * Finds the best attribute to split on at the current node. Returns true if a split exists
- * to reduce squared error, false otherwise.
+ * Finds the best attribute to split on at the current node. Returns true if a split exists to reduce squared error, false otherwise.
*/
public boolean findBestSplit() {
// avoid split if tree depth is larger than threshold
@@ -467,22 +426,14 @@ public final class RegressionTree implements Regression<double[]> {
}
final double sum = node.output * numSamples;
- final int p = _attributes.length;
- final int[] variables = new int[p];
- for (int i = 0; i < p; i++) {
- variables[i] = i;
- }
- if (_numVars < p) {
- SmileExtUtils.shuffle(variables, _rnd);
- }
// Loop through features and compute the reduction of squared error,
// which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2
- final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length)
+ final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows())
: null;
- for (int j = 0; j < _numVars; j++) {
- Node split = findBestSplit(numSamples, sum, variables[j], samples);
+ for (int varJ : variableIndex(x, bags)) {
+ final Node split = findBestSplit(numSamples, sum, varJ, samples);
if (split.splitScore > node.splitScore) {
node.splitFeature = split.splitFeature;
node.splitFeatureType = split.splitFeatureType;
@@ -496,6 +447,31 @@ public final class RegressionTree implements Regression<double[]> {
return node.splitFeature != -1;
}
+ private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) {
+ final int[] variableIndex;
+ if (x.isSparse()) {
+ final IntSet cols = new IntArraySet(_numVars);
+ final VectorProcedure proc = new VectorProcedure() {
+ public void apply(int col, double value) {
+ cols.add(col);
+ }
+ };
+ for (final int row : bags) {
+ x.eachNonNullInRow(row, proc);
+ }
+ variableIndex = cols.toArray(false);
+ } else {
+ variableIndex = MathUtils.permutation(_attributes.length);
+ }
+
+ if (_numVars < variableIndex.length) {
+ SmileExtUtils.shuffle(variableIndex, _rnd);
+ return Arrays.copyOf(variableIndex, _numVars);
+
+ }
+ return variableIndex;
+ }
+
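+ // Note: for a sparse matrix the candidate set above contains only columns
+ // that actually occur in the bagged rows, so all-empty columns are never
+ // tried as splits; a random subset of at most _numVars is then taken by
+ // shuffling and truncating, matching the dense-path permutation.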
/**
* Finds the best split cutoff for attribute j at the current node.
*
@@ -517,7 +493,11 @@ public final class RegressionTree implements Regression<double[]> {
// For each true feature of this datum increment the
// sufficient statistics for the "true" branch to evaluate
// splitting on this feature.
- int index = (int) x[i][j];
+ final double v = x.get(i, j, Double.NaN);
+ if (Double.isNaN(v)) {
+ continue;
+ }
+ int index = (int) v;
trueSum[index] += y[i];
++trueCount[index];
}
@@ -548,28 +528,38 @@ public final class RegressionTree implements Regression<double[]> {
}
}
} else if (_attributes[j].type == AttributeType.NUMERIC) {
- double trueSum = 0.0;
- int trueCount = 0;
- double prevx = Double.NaN;
-
- for (int i : _order[j]) {
- final int sample = samples[i];
- if (sample > 0) {
- if (Double.isNaN(prevx) || x[i][j] == prevx) {
- prevx = x[i][j];
- trueSum += sample * y[i];
+
+ _order.eachNonNullInColumn(j, new VectorProcedure() {
+ double trueSum = 0.0;
+ int trueCount = 0;
+ double prevx = Double.NaN;
+
+ public void apply(final int row, final int i) {
+ final int sample = samples[i];
+ if (sample == 0) {
+ return;
+ }
+ final double x_ij = x.get(i, j, Double.NaN);
+ if (Double.isNaN(x_ij)) {
+ return;
+ }
+ final double y_i = y[i];
+
+ if (Double.isNaN(prevx) || x_ij == prevx) {
+ prevx = x_ij;
+ trueSum += sample * y_i;
trueCount += sample;
- continue;
+ return;
}
final double falseCount = n - trueCount;
// If either side is empty, skip this feature.
if (trueCount < _minSplit || falseCount < _minSplit) {
- prevx = x[i][j];
- trueSum += sample * y[i];
+ prevx = x_ij;
+ trueSum += sample * y_i;
trueCount += sample;
- continue;
+ return;
}
// compute penalized means
@@ -586,17 +576,18 @@ public final class RegressionTree implements Regression<double[]> {
// new best split
split.splitFeature = j;
split.splitFeatureType = AttributeType.NUMERIC;
- split.splitValue = (x[i][j] + prevx) / 2;
+ split.splitValue = (x_ij + prevx) / 2;
split.splitScore = gain;
split.trueChildOutput = trueMean;
split.falseChildOutput = falseMean;
}
- prevx = x[i][j];
- trueSum += sample * y[i];
+ prevx = x_ij;
+ trueSum += sample * y_i;
trueCount += sample;
- }
- }
+ }//apply
+ });
+
} else {
throw new IllegalStateException("Unsupported attribute type: "
+ _attributes[j].type);
@@ -672,7 +663,7 @@ public final class RegressionTree implements Regression<double[]> {
final double splitValue = node.splitValue;
for (int i = 0, size = bags.length; i < size; i++) {
final int index = bags[i];
- if (x[index][splitFeature] == splitValue) {
+ if (x.get(index, splitFeature, Double.NaN) == splitValue) {
trueBags.add(index);
tc++;
} else {
@@ -684,7 +675,7 @@ public final class RegressionTree implements Regression<double[]> {
final double splitValue = node.splitValue;
for (int i = 0, size = bags.length; i < size; i++) {
final int index = bags[i];
- if (x[index][splitFeature] <= splitValue) {
+ if (x.get(index, splitFeature, Double.NaN) <= splitValue) {
trueBags.add(index);
tc++;
} else {
@@ -700,20 +691,19 @@ public final class RegressionTree implements Regression<double[]> {
}
- public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
- @Nonnull double[] y, int maxLeafs) {
- this(attributes, x, y, x[0].length, Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null);
+ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+ int maxLeafs) {
+ this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null);
}
- public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
- @Nonnull double[] y, int maxLeafs, @Nullable smile.math.Random rand) {
- this(attributes, x, y, x[0].length, Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand);
+ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+ int maxLeafs, @Nullable PRNG rand) {
+ this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand);
}
- public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
- @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
- int minLeafSize, @Nullable int[][] order, @Nullable int[] bags,
- @Nullable smile.math.Random rand) {
+ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+ int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
+ @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags, @Nullable PRNG rand) {
this(attributes, x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize, order, bags, null, rand);
}
@@ -723,24 +713,22 @@ public final class RegressionTree implements Regression<double[]> {
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
- * @param numVars the number of input variables to pick to split on at each node. It seems that
- * dim/3 give generally good performance, where dim is the number of variables.
+ * @param numVars the number of input variables to pick to split on at each node. It seems that dim/3 give generally good performance, where dim
+ * is the number of variables.
* @param maxLeafs the maximum number of leaf nodes in the tree.
- * @param minSplits number of instances in a node below which the tree will not split, setting S
- * = 5 generally gives good results.
- * @param order the index of training values in ascending order. Note that only numeric
- * attributes need be sorted.
+ * @param minSplits number of instances in a node below which the tree will not split, setting S = 5 generally gives good results.
+ * @param order the index of training values in ascending order. Note that only numeric attributes need be sorted.
* @param bags the sample set of instances for stochastic learning.
* @param output An interface to calculate node output.
*/
- public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x,
- @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits,
- int minLeafSize, @Nullable int[][] order, @Nullable int[] bags,
- @Nullable NodeOutput output, @Nullable smile.math.Random rand) {
+ public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y,
+ int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
+ @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags,
+ @Nullable NodeOutput output, @Nullable PRNG rand) {
checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);
this._attributes = SmileExtUtils.attributeTypes(attributes, x);
- if (_attributes.length != x[0].length) {
+ if (_attributes.length != x.numColumns()) {
throw new IllegalArgumentException("-attrs option is invalid: "
+ Arrays.toString(attributes));
}
@@ -752,7 +740,7 @@ public final class RegressionTree implements Regression<double[]> {
this._minLeafSize = minLeafSize;
this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order;
this._importance = new double[_attributes.length];
- this._rnd = (rand == null) ? new smile.math.Random() : rand;
+ this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;
this._nodeOutput = output;
int n = 0;
@@ -803,13 +791,13 @@ public final class RegressionTree implements Regression<double[]> {
}
}
- private static void checkArgument(@Nonnull double[][] x, @Nonnull double[] y, int numVars,
+ private static void checkArgument(@Nonnull Matrix x, @Nonnull double[] y, int numVars,
int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
- if (x.length != y.length) {
+ if (x.numRows() != y.length) {
throw new IllegalArgumentException(String.format(
- "The sizes of X and Y don't match: %d != %d", x.length, y.length));
+ "The sizes of X and Y don't match: %d != %d", x.numRows(), y.length));
}
- if (numVars <= 0 || numVars > x[0].length) {
+ if (numVars <= 0 || numVars > x.numColumns()) {
throw new IllegalArgumentException(
"Invalid number of variables to split on at a node of the tree: " + numVars);
}
@@ -830,10 +818,8 @@ public final class RegressionTree implements Regression<double[]> {
}
/**
- * Returns the variable importance. Every time a split of a node is made on variable the
- * impurity criterion for the two descendent nodes is less than the parent node. Adding up the
- * decreases for each individual variable over the tree gives a simple measure of variable
- * importance.
+ * Returns the variable importance. Every time a split of a node is made on variable the impurity criterion for the two descendent nodes is less
+ * than the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance.
*
* @return the variable importance
*/
@@ -841,8 +827,13 @@ public final class RegressionTree implements Regression<double[]> {
return _importance;
}
+ @VisibleForTesting
+ public double predict(@Nonnull final double[] x) {
+ return predict(new DenseVector(x));
+ }
+
@Override
- public double predict(double[] x) {
+ public double predict(@Nonnull final Vector x) {
return _root.predict(x);
}
@@ -852,14 +843,6 @@ public final class RegressionTree implements Regression<double[]> {
return buf.toString();
}
- public String predictOpCodegen(@Nonnull String sep) {
- List<String> opslist = new ArrayList<String>();
- _root.opCodegen(opslist, 0);
- opslist.add("call end");
- String scripts = StringUtils.concat(opslist, sep);
- return scripts;
- }
-
@Nonnull
public byte[] predictSerCodegen(boolean compress) throws HiveException {
try {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
new file mode 100644
index 0000000..6a7f39f
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
+ */
+public final class Long2FloatOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+ protected float defaultReturnValue = 0.f;
+
+ protected long[] _keys;
+ protected float[] _values;
+ protected byte[] _states;
+
+ protected Long2FloatOpenHashTable(int size, float loadFactor, float growFactor,
+ boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new long[actualSize];
+ this._values = new float[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Long2FloatOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public Long2FloatOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ public Long2FloatOpenHashTable() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public void defaultReturnValue(float v) {
+ this.defaultReturnValue = v;
+ }
+
+ public boolean containsKey(final long key) {
+ return _findKey(key) >= 0;
+ }
+
+ /**
+ * @return defaultReturnValue if not found
+ */
+ public float get(final long key) {
+ return get(key, defaultReturnValue);
+ }
+
+ public float get(final long key, final float defaultValue) {
+ final int i = _findKey(key);
+ if (i < 0) {
+ return defaultValue;
+ }
+ return _values[i];
+ }
+
+ public float _get(final int index) {
+ if (index < 0) {
+ return defaultReturnValue;
+ }
+ return _values[index];
+ }
+
+ public float put(final long key, final float value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final long[] keys = _keys;
+ final float[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
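+ // decr lies in [1, keyLength - 2]; the capacity is kept prime via
+ // Primes.findLeastPrimeNumber, so decr is coprime to the table size and
+ // the probe sequence visits every slot.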
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Returns whether the required slot is free for a new entry */
+ protected boolean isFree(final int index, final long key) {
+ final byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// too filled
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * @return -1 if not found
+ */
+ public int _findKey(final long key) {
+ final long[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public float remove(final long key) {
+ final long[] keys = _keys;
+ final float[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(final int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(final int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final long[] newkeys = new long[newCapacity];
+ final float[] newValues = new float[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ long k = _keys[i];
+ float v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final long key) {
+ return (int) (key ^ (key >>> 32)) & 0x7FFFFFFF;
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ out.writeLong(i.getKey());
+ out.writeFloat(i.getValue());
+ }
+ }
+
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ final int keylen = in.readInt();
+ final long[] keys = new long[keylen];
+ final float[] values = new float[keylen];
+ final byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ long k = in.readLong();
+ float v = in.readFloat();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return -1 if not found
+ */
+ public int next();
+
+ public long getKey();
+
+ public float getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public long getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public float getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
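A short usage sketch for the table above, using only methods defined in this file; the keys and values are illustrative:

import hivemall.utils.collections.maps.Long2FloatOpenHashTable;

public final class Long2FloatUsageSketch {
    public static void main(String[] args) {
        Long2FloatOpenHashTable table = new Long2FloatOpenHashTable(16);
        table.defaultReturnValue(Float.NaN); // reported for absent keys
        table.put(42L, 0.5f);
        System.out.println(table.get(42L));  // 0.5
        System.out.println(table.get(7L));   // NaN (the default)
        table.remove(42L);
        System.out.println(table.size());    // 0
    }
}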
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
new file mode 100644
index 0000000..51b8f12
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
@@ -0,0 +1,473 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
+ */
+public final class Long2IntOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+ protected int defaultReturnValue = -1;
+
+ protected long[] _keys;
+ protected int[] _values;
+ protected byte[] _states;
+
+ protected Long2IntOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new long[actualSize];
+ this._values = new int[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Long2IntOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public Long2IntOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ public Long2IntOpenHashTable() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public void defaultReturnValue(int v) {
+ this.defaultReturnValue = v;
+ }
+
+ public boolean containsKey(final long key) {
+ return _findKey(key) >= 0;
+ }
+
+ /**
+ * @return defaultReturnValue if not found
+ */
+ public int get(final long key) {
+ return get(key, defaultReturnValue);
+ }
+
+ public int get(final long key, final int defaultValue) {
+ final int i = _findKey(key);
+ if (i < 0) {
+ return defaultValue;
+ }
+ return _values[i];
+ }
+
+ public int _get(final int index) {
+ if (index < 0) {
+ return defaultReturnValue;
+ }
+ return _values[index];
+ }
+
+ public int put(final long key, final int value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final long[] keys = _keys;
+ final int[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ public int incr(final long key, final int delta) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final long[] keys = _keys;
+ final int[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] += delta;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] += delta;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] += delta;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Returns whether the required slot is free for a new entry */
+ protected boolean isFree(final int index, final long key) {
+ final byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// too filled
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * @return -1 if not found
+ */
+ public int _findKey(final long key) {
+ final long[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public int remove(final long key) {
+ final long[] keys = _keys;
+ final int[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(final int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(final int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final long[] newkeys = new long[newCapacity];
+ final int[] newValues = new int[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ long k = _keys[i];
+ int v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final long key) {
+ return (int) (key ^ (key >>> 32)) & 0x7FFFFFFF;
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ out.writeLong(i.getKey());
+ out.writeInt(i.getValue());
+ }
+ }
+
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ final int keylen = in.readInt();
+ final long[] keys = new long[keylen];
+ final int[] values = new int[keylen];
+ final byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ long k = in.readLong();
+ int v = in.readInt();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return -1 if not found
+ */
+ public int next();
+
+ public long getKey();
+
+ public int getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public long getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public int getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
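Unlike the float variant, this table also offers incr(), which adds a delta in place and installs the key on first touch. A hedged counting sketch, again using only methods from this file:

import hivemall.utils.collections.maps.Long2IntOpenHashTable;

public final class Long2IntUsageSketch {
    public static void main(String[] args) {
        Long2IntOpenHashTable counts = new Long2IntOpenHashTable(16);
        for (long k : new long[] {1L, 2L, 1L}) {
            counts.incr(k, 1); // absent keys start from an empty slot (0)
        }
        System.out.println(counts.get(1L)); // 2
        System.out.println(counts.get(2L)); // 1
        System.out.println(counts.get(3L)); // -1 (defaultReturnValue)
    }
}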
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
new file mode 100644
index 0000000..152447a
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+//
+// Copyright (C) 2010 catchpole.net
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.math.MathUtils;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * An optimized hash map implementation.
+ * <p>
+ * This map does not allow null to be used as a key or a value.
+ * <p>
+ * It uses single open-hashing arrays sized to binary powers (256, 512, etc.) rather than
+ * prime-sized tables, which allows the hash offset calculation to be a simple binary masking
+ * operation.
+ */
+public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
+ private K[] keys;
+ private V[] values;
+
+ // total number of entries in this table
+ private int size;
+ // number of bits for the value table (e.g., 8 bits = 256 entries)
+ private int bits;
+ // the number of bits in each sweep zone.
+ private int sweepbits;
+ // the size of a sweep (2 to the power of sweepbits)
+ private int sweep;
+ // the sweepmask used to create sweep zone offsets
+ private int sweepmask;
+
+ public OpenHashMap() {}// for Externalizable
+
+ public OpenHashMap(int size) {
+ resize(MathUtils.bitsRequired(size < 256 ? 256 : size));
+ }
+
+ public V put(K key, V value) {
+ if (key == null) {
+ throw new NullPointerException(this.getClass().getName() + " does not allow null keys");
+ }
+
+ for (;;) {
+ int off = getBucketOffset(key);
+ int end = off + sweep;
+ for (; off < end; off++) {
+ K searchKey = keys[off];
+ if (searchKey == null) {
+ // insert
+ keys[off] = key;
+ size++;
+
+ V previous = values[off];
+ values[off] = value;
+ return previous;
+ } else if (compare(searchKey, key)) {
+ // replace
+ V previous = values[off];
+ values[off] = value;
+ return previous;
+ }
+ }
+ resize(this.bits + 1);
+ }
+ }
+
+ public V get(Object key) {
+ int off = getBucketOffset(key);
+ int end = sweep + off;
+ for (; off < end; off++) {
+ if (keys[off] != null && compare(keys[off], key)) {
+ return values[off];
+ }
+ }
+ return null;
+ }
+
+ public V remove(Object key) {
+ int off = getBucketOffset(key);
+ int end = sweep + off;
+ for (; off < end; off++) {
+ if (keys[off] != null && compare(keys[off], key)) {
+ keys[off] = null;
+ V previous = values[off];
+ values[off] = null;
+ size--;
+ return previous;
+ }
+ }
+ return null;
+ }
+
+ public int size() {
+ return size;
+ }
+
+ public void putAll(Map<? extends K, ? extends V> m) {
+ for (K key : m.keySet()) {
+ put(key, m.get(key));
+ }
+ }
+
+ public boolean isEmpty() {
+ return size == 0;
+ }
+
+ public boolean containsKey(Object key) {
+ return get(key) != null;
+ }
+
+ public boolean containsValue(Object value) {
+ for (V v : values) {
+ if (v != null && compare(v, value)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public void clear() {
+ Arrays.fill(keys, null);
+ Arrays.fill(values, null);
+ size = 0;
+ }
+
+ public Set<K> keySet() {
+ Set<K> set = new HashSet<K>();
+ for (K key : keys) {
+ if (key != null) {
+ set.add(key);
+ }
+ }
+ return set;
+ }
+
+ public Collection<V> values() {
+ Collection<V> list = new ArrayList<V>();
+ for (V value : values) {
+ if (value != null) {
+ list.add(value);
+ }
+ }
+ return list;
+ }
+
+ public Set<Entry<K, V>> entrySet() {
+ Set<Entry<K, V>> set = new HashSet<Entry<K, V>>();
+ for (K key : keys) {
+ if (key != null) {
+ set.add(new MapEntry<K, V>(this, key));
+ }
+ }
+ return set;
+ }
+
+ private static final class MapEntry<K, V> implements Map.Entry<K, V> {
+ private final Map<K, V> map;
+ private final K key;
+
+ public MapEntry(Map<K, V> map, K key) {
+ this.map = map;
+ this.key = key;
+ }
+
+ public K getKey() {
+ return key;
+ }
+
+ public V getValue() {
+ return map.get(key);
+ }
+
+ public V setValue(V value) {
+ return map.put(key, value);
+ }
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // remember the number of bits
+ out.writeInt(this.bits);
+ // remember the total number of entries
+ out.writeInt(this.size);
+ // write all entries
+ for (int x = 0; x < this.keys.length; x++) {
+ if (keys[x] != null) {
+ out.writeObject(keys[x]);
+ out.writeObject(values[x]);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ // resize to old bit size
+ int bitSize = in.readInt();
+ if (bitSize != bits) {
+ resize(bitSize);
+ }
+ // read all entries
+ int size = in.readInt();
+ for (int x = 0; x < size; x++) {
+ this.put((K) in.readObject(), (V) in.readObject());
+ }
+ }
+
+ @Override
+ public String toString() {
+ return this.getClass().getSimpleName() + ' ' + this.size;
+ }
+
+ @SuppressWarnings("unchecked")
+ private void resize(int bits) {
+ this.bits = bits;
+ this.sweepbits = bits / 4;
+ this.sweep = MathUtils.powerOf(2, sweepbits) * 4;
+ this.sweepmask = MathUtils.bitMask(bits - this.sweepbits) << sweepbits;
+
+ // remember old values so we can recreate the entries
+ K[] existingKeys = this.keys;
+ V[] existingValues = this.values;
+
+ // create the arrays
+ this.values = (V[]) new Object[MathUtils.powerOf(2, bits) + sweep];
+ this.keys = (K[]) new Object[values.length];
+ this.size = 0;
+
+ // re-add the previous entries if resizing
+ if (existingKeys != null) {
+ for (int x = 0; x < existingKeys.length; x++) {
+ if (existingKeys[x] != null) {
+ put(existingKeys[x], existingValues[x]);
+ }
+ }
+ }
+ }
+
+ private int getBucketOffset(Object key) {
+ return (key.hashCode() << this.sweepbits) & this.sweepmask;
+ }
+
+ private static boolean compare(final Object v1, final Object v2) {
+ return v1 == v2 || v1.equals(v2);
+ }
+
+ public IMapIterator<K, V> entries() {
+ return new MapIterator();
+ }
+
+ private final class MapIterator implements IMapIterator<K, V> {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of the next full entry */
+ int nextEntry(int index) {
+ while (index < keys.length && keys[index] == null) {
+ index++;
+ }
+ return index;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return nextEntry < keys.length;
+ }
+
+ @Override
+ public int next() {
+ free(lastEntry); // note: iteration is destructive; each visited entry is released
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ @Override
+ public K getKey() {
+ return keys[lastEntry];
+ }
+
+ @Override
+ public V getValue() {
+ return values[lastEntry];
+ }
+
+ @Override
+ public <T extends Copyable<V>> void getValue(T probe) {
+ probe.copyFrom(getValue());
+ }
+
+ private void free(int index) {
+ if (index >= 0) {
+ keys[index] = null;
+ values[index] = null;
+ }
+ }
+
+ }
+}
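
A brief usage sketch of the OpenHashMap above (hypothetical demo values; the map resizes itself whenever a sweep zone fills, so the caller never manages capacity):

    OpenHashMap<String, Integer> map = new OpenHashMap<String, Integer>(256);
    map.put("alpha", 1);
    map.put("beta", 2);
    assert map.get("alpha").intValue() == 1;
    assert map.size() == 2;
    map.remove("beta");
    assert map.get("beta") == null; // null is never stored, so null always means absent
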
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
new file mode 100644
index 0000000..7fec9b0
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+import java.util.HashMap;
+
+import javax.annotation.Nonnull;
+
+/**
+ * An open-addressing hash table with double hashing that requires less memory than {@link HashMap}.
+ */
+public final class OpenHashTable<K, V> implements Externalizable {
+
+ public static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ public static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ protected/* final */float _loadFactor;
+ protected/* final */float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+
+ protected K[] _keys;
+ protected V[] _values;
+ protected byte[] _states;
+
+ public OpenHashTable() {} // for Externalizable
+
+ public OpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR);
+ }
+
+ @SuppressWarnings("unchecked")
+ public OpenHashTable(int size, float loadFactor, float growFactor) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = Primes.findLeastPrimeNumber(size);
+ this._keys = (K[]) new Object[actualSize];
+ this._values = (V[]) new Object[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = Math.round(actualSize * _loadFactor);
+ }
+
+ public OpenHashTable(@Nonnull K[] keys, @Nonnull V[] values, @Nonnull byte[] states, int used) {
+ this._used = used;
+ this._threshold = keys.length;
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public Object[] getKeys() {
+ return _keys;
+ }
+
+ public Object[] getValues() {
+ return _values;
+ }
+
+ public byte[] getStates() {
+ return _states;
+ }
+
+ public boolean containsKey(final K key) {
+ return findKey(key) >= 0;
+ }
+
+ public V get(final K key) {
+ final int i = findKey(key);
+ if (i < 0) {
+ return null;
+ }
+ return _values[i];
+ }
+
+ public V put(final K key, final V value) {
+ int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ K[] keys = _keys;
+ V[] values = _values;
+ byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {
+ if (equals(keys[keyIdx], key)) {
+ V old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ V old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return null;
+ }
+
+ private static boolean equals(final Object k1, final Object k2) {
+ return k1 == k2 || k1.equals(k2);
+ }
+
+ /** Returns whether the slot at the given index is free for a new entry */
+ protected boolean isFree(int index, K key) {
+ byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && equals(_keys[index], key)) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return true if the table was expanded */
+ protected boolean preAddEntry(int index) {
+ if ((_used + 1) >= _threshold) {// filled enough
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ protected int findKey(final K key) {
+ K[] keys = _keys;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public V remove(final K key) {
+ K[] keys = _keys;
+ V[] values = _values;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ V old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return null;
+ }
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ V old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return null;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator<K, V> entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ final StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ final IMapIterator<K, V> i = entries();
+ while (i.next() != -1) {
+ String key = i.getKey().toString();
+ buf.append(key);
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ @SuppressWarnings("unchecked")
+ private void rehash(int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final K[] newkeys = (K[]) new Object[newCapacity];
+ final V[] newValues = (V[]) new Object[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ K k = _keys[i];
+ V v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newStates[keyIdx] = FULL;
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final Object key) {
+ int hash = key.hashCode();
+ return hash & 0x7fffffff;
+ }
+
+ private final class MapIterator implements IMapIterator<K, V> {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of the next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = nextEntry;
+ this.nextEntry = nextEntry(nextEntry + 1);
+ return curEntry;
+ }
+
+ public K getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public V getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+
+ @Override
+ public <T extends Copyable<V>> void getValue(T probe) {
+ probe.copyFrom(getValue());
+ }
+ }
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeFloat(_loadFactor);
+ out.writeFloat(_growFactor);
+ out.writeInt(_used);
+
+ final int size = _keys.length;
+ out.writeInt(size);
+
+ for (int i = 0; i < size; i++) {
+ out.writeObject(_keys[i]);
+ out.writeObject(_values[i]);
+ out.writeByte(_states[i]);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._loadFactor = in.readFloat();
+ this._growFactor = in.readFloat();
+ this._used = in.readInt();
+
+ final int size = in.readInt();
+ final Object[] keys = new Object[size];
+ final Object[] values = new Object[size];
+ final byte[] states = new byte[size];
+ for (int i = 0; i < size; i++) {
+ keys[i] = in.readObject();
+ values[i] = in.readObject();
+ states[i] = in.readByte();
+ }
+ this._threshold = size;
+ this._keys = (K[]) keys;
+ this._values = (V[]) values;
+ this._states = states;
+ }
+
+}
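
A brief usage sketch of the OpenHashTable above. Note that remove() leaves a REMOVED tombstone rather than clearing the slot, so later double-hash probes still traverse the slot correctly (demo values are hypothetical):

    OpenHashTable<String, Double> table = new OpenHashTable<String, Double>(64);
    Double prev = table.put("lr", 0.01);  // null: no previous value
    prev = table.put("lr", 0.05);         // 0.01: the previous value is returned
    assert table.containsKey("lr");
    table.remove("lr");                   // slot state becomes REMOVED, not FREE
    assert table.get("lr") == null && table.size() == 0;
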
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java b/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java
new file mode 100644
index 0000000..06b6a15
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.sets;
+
+import hivemall.utils.lang.ArrayUtils;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+public final class IntArraySet implements IntSet {
+
+ @Nonnull
+ private int[] mKeys;
+ private int mSize;
+
+ public IntArraySet() {
+ this(0);
+ }
+
+ public IntArraySet(int initSize) {
+ this.mKeys = new int[initSize];
+ this.mSize = 0;
+ }
+
+ @Override
+ public boolean add(final int k) {
+ final int i = Arrays.binarySearch(mKeys, 0, mSize, k);
+ if (i >= 0) {
+ return false;
+ }
+ mKeys = ArrayUtils.insert(mKeys, mSize, ~i, k);
+ mSize++;
+ return true;
+ }
+
+ @Override
+ public boolean remove(final int k) {
+ final int i = Arrays.binarySearch(mKeys, 0, mSize, k);
+ if (i < 0) {
+ return false;
+ }
+ System.arraycopy(mKeys, i + 1, mKeys, i, mSize - (i + 1));
+ mSize--;
+ return true;
+ }
+
+ @Override
+ public boolean contains(final int k) {
+ return Arrays.binarySearch(mKeys, 0, mSize, k) >= 0;
+ }
+
+ @Override
+ public int size() {
+ return mSize;
+ }
+
+ @Override
+ public void clear() {
+ this.mSize = 0;
+ }
+
+ @Override
+ public int[] toArray(final boolean copy) {
+ if (copy == false && mKeys.length == mSize) {
+ return mKeys;
+ }
+
+ return Arrays.copyOf(mKeys, mSize);
+ }
+
+}
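
IntArraySet keeps its keys in sorted order and leans on Arrays.binarySearch: a negative return value encodes the insertion point as -(insertionPoint) - 1, which the ~i in add() recovers. A usage sketch (an explicit initial capacity is passed because ArrayUtils.insert doubles from the current array length):

    IntArraySet set = new IntArraySet(8);
    set.add(7);
    set.add(3);
    boolean added = set.add(7);        // duplicate: returns false
    assert !added && set.size() == 2;
    int[] sorted = set.toArray(true);  // {3, 7}: always ascending
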
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/sets/IntSet.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/sets/IntSet.java b/core/src/main/java/hivemall/utils/collections/sets/IntSet.java
new file mode 100644
index 0000000..398955c
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/sets/IntSet.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.sets;
+
+import javax.annotation.Nonnull;
+
+public interface IntSet {
+
+ public boolean add(int k);
+
+ public boolean remove(int k);
+
+ public boolean contains(int k);
+
+ public int size();
+
+ public void clear();
+
+ @Nonnull
+ public int[] toArray(boolean copy);
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 5423c9d..b3a2de1 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -56,6 +56,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardConstantListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
@@ -256,6 +258,10 @@ public final class HiveUtils {
&& isNumberListOI(((ListObjectInspector) oi).getListElementObjectInspector());
}
+ public static boolean isConstString(@Nonnull final ObjectInspector oi) {
+ return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi);
+ }
+
public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) {
return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE;
}
@@ -308,20 +314,43 @@ public final class HiveUtils {
}
}
- public static boolean isStringTypeInfo(@Nonnull TypeInfo typeInfo) {
+ public static boolean isIntTypeInfo(@Nonnull TypeInfo typeInfo) {
+ if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) {
+ return false;
+ }
+ return ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory() == PrimitiveCategory.INT;
+ }
+
+ public static boolean isFloatingPointTypeInfo(@Nonnull TypeInfo typeInfo) {
if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) {
return false;
}
switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) {
- case STRING:
+ case DOUBLE:
+ case FLOAT:
return true;
default:
return false;
}
}
- public static boolean isConstString(@Nonnull final ObjectInspector oi) {
- return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi);
+ public static boolean isStringTypeInfo(@Nonnull TypeInfo typeInfo) {
+ if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) {
+ return false;
+ }
+ return ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory() == PrimitiveCategory.STRING;
+ }
+
+ public static boolean isListTypeInfo(@Nonnull TypeInfo typeInfo) {
+ return typeInfo.getCategory() == Category.LIST;
+ }
+
+ public static boolean isFloatingPointListTypeInfo(@Nonnull TypeInfo typeInfo) {
+ if (typeInfo.getCategory() != Category.LIST) {
+ return false;
+ }
+ TypeInfo elemTypeInfo = ((ListTypeInfo) typeInfo).getListElementTypeInfo();
+ return isFloatingPointTypeInfo(elemTypeInfo);
}
@Nonnull
@@ -387,6 +416,38 @@ public final class HiveUtils {
return ary;
}
+ @Nullable
+ public static double[] getConstDoubleArray(@Nonnull final ObjectInspector oi)
+ throws UDFArgumentException {
+ if (!ObjectInspectorUtils.isConstantObjectInspector(oi)) {
+ throw new UDFArgumentException("argument must be a constant value: "
+ + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
+ }
+ ConstantObjectInspector constOI = (ConstantObjectInspector) oi;
+ if (constOI.getCategory() != Category.LIST) {
+ throw new UDFArgumentException("argument must be an array: "
+ + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
+ }
+ StandardConstantListObjectInspector listOI = (StandardConstantListObjectInspector) constOI;
+ PrimitiveObjectInspector elemOI = HiveUtils.asDoubleCompatibleOI(listOI.getListElementObjectInspector());
+
+ final List<?> lst = listOI.getWritableConstantValue();
+ if (lst == null) {
+ return null;
+ }
+ final int size = lst.size();
+ final double[] ary = new double[size];
+ for (int i = 0; i < size; i++) {
+ Object o = lst.get(i);
+ if (o == null) {
+ ary[i] = Double.NaN;
+ } else {
+ ary[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
+ }
+ }
+ return ary;
+ }
+
public static String getConstString(@Nonnull final ObjectInspector oi)
throws UDFArgumentException {
if (!isStringOI(oi)) {
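
The new getConstDoubleArray is meant to be called from a UDF's initialize(), where constant array arguments expose their values through the ObjectInspector. A hypothetical call site (the argument index and error message are illustrative only):

    // Inside GenericUDF#initialize(ObjectInspector[] argOIs) -- sketch only.
    double[] weights = HiveUtils.getConstDoubleArray(argOIs[1]);
    if (weights == null) {
        throw new UDFArgumentException("weights must be a non-null constant array");
    }
    // null elements inside the constant list come back as Double.NaN (see above)
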
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 24ed7fc..e8e337d 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -170,6 +170,24 @@ public final class ArrayUtils {
arr[j] = tmp;
}
+ public static void swap(@Nonnull final long[] arr, final int i, final int j) {
+ long tmp = arr[i];
+ arr[i] = arr[j];
+ arr[j] = tmp;
+ }
+
+ public static void swap(@Nonnull final float[] arr, final int i, final int j) {
+ float tmp = arr[i];
+ arr[i] = arr[j];
+ arr[j] = tmp;
+ }
+
+ public static void swap(@Nonnull final double[] arr, final int i, final int j) {
+ double tmp = arr[i];
+ arr[i] = arr[j];
+ arr[j] = tmp;
+ }
+
@Nullable
public static Object[] subarray(@Nullable final Object[] array, int startIndexInclusive,
int endIndexExclusive) {
@@ -198,7 +216,7 @@ public final class ArrayUtils {
}
}
- public static int indexOf(@Nonnull final int[] array, final int valueToFind,
+ public static int indexOf(@Nullable final int[] array, final int valueToFind,
final int startIndex, final int endIndex) {
if (array == null) {
return INDEX_NOT_FOUND;
@@ -215,6 +233,36 @@ public final class ArrayUtils {
return INDEX_NOT_FOUND;
}
+ public static int lastIndexOf(@Nullable final int[] array, final int valueToFind, int startIndex) {
+ if (array == null) {
+ return INDEX_NOT_FOUND;
+ }
+ return lastIndexOf(array, valueToFind, startIndex, array.length);
+ }
+
+ /**
+ * @param startIndex inclusive start index
+ * @param endIndex exclusive end index
+ */
+ public static int lastIndexOf(@Nullable final int[] array, final int valueToFind,
+ int startIndex, int endIndex) {
+ if (array == null) {
+ return INDEX_NOT_FOUND;
+ }
+ if (startIndex < 0) {
+ throw new IllegalArgumentException("startIndex out of bounds: " + startIndex);
+ }
+ if (endIndex > array.length) {// endIndex is exclusive, so array.length is allowed
+ throw new IllegalArgumentException("endIndex out of bounds: " + endIndex);
+ }
+ for (int i = endIndex - 1; i >= startIndex; i--) {
+ if (valueToFind == array[i]) {
+ return i;
+ }
+ }
+ return INDEX_NOT_FOUND;
+ }
+
@Nonnull
public static byte[] copyOf(@Nonnull final byte[] original, final int newLength) {
final byte[] copy = new byte[newLength];
@@ -249,6 +297,17 @@ public final class ArrayUtils {
}
@Nonnull
+ public static float[] append(@Nonnull float[] array, final int currentSize, final float element) {
+ if (currentSize + 1 > array.length) {
+ float[] newArray = new float[currentSize * 2];
+ System.arraycopy(array, 0, newArray, 0, currentSize);
+ array = newArray;
+ }
+ array[currentSize] = element;
+ return array;
+ }
+
+ @Nonnull
public static double[] append(@Nonnull double[] array, final int currentSize,
final double element) {
if (currentSize + 1 > array.length) {
@@ -268,7 +327,22 @@ public final class ArrayUtils {
array[index] = element;
return array;
}
- int[] newArray = new int[currentSize * 2];
+ final int[] newArray = new int[currentSize * 2];
+ System.arraycopy(array, 0, newArray, 0, index);
+ newArray[index] = element;
+ System.arraycopy(array, index, newArray, index + 1, array.length - index);
+ return newArray;
+ }
+
+ @Nonnull
+ public static float[] insert(@Nonnull final float[] array, final int currentSize,
+ final int index, final float element) {
+ if (currentSize + 1 <= array.length) {
+ System.arraycopy(array, index, array, index + 1, currentSize - index);
+ array[index] = element;
+ return array;
+ }
+ final float[] newArray = new float[currentSize * 2];
System.arraycopy(array, 0, newArray, 0, index);
newArray[index] = element;
System.arraycopy(array, index, newArray, index + 1, array.length - index);
@@ -283,7 +357,7 @@ public final class ArrayUtils {
array[index] = element;
return array;
}
- double[] newArray = new double[currentSize * 2];
+ final double[] newArray = new double[currentSize * 2];
System.arraycopy(array, 0, newArray, 0, index);
newArray[index] = element;
System.arraycopy(array, index, newArray, index + 1, array.length - index);
@@ -314,4 +388,331 @@ public final class ArrayUtils {
return true;
}
+ public static void copy(@Nonnull final float[] src, @Nonnull final double[] dst) {
+ final int size = Math.min(src.length, dst.length);
+ for (int i = 0; i < size; i++) {
+ dst[i] = src[i];
+ }
+ }
+
+ public static void sort(final long[] arr, final double[] brr) {
+ sort(arr, brr, arr.length);
+ }
+
+ public static void sort(final long[] arr, final double[] brr, final int n) {
+ final int NSTACK = 64;
+ final int M = 7;
+ final int[] istack = new int[NSTACK];
+
+ int jstack = -1;
+ int l = 0;
+ int ir = n - 1;
+
+ int i, j, k;
+ long a;
+ double b;
+ for (;;) {
+ if (ir - l < M) {
+ for (j = l + 1; j <= ir; j++) {
+ a = arr[j];
+ b = brr[j];
+ for (i = j - 1; i >= l; i--) {
+ if (arr[i] <= a) {
+ break;
+ }
+ arr[i + 1] = arr[i];
+ brr[i + 1] = brr[i];
+ }
+ arr[i + 1] = a;
+ brr[i + 1] = b;
+ }
+ if (jstack < 0) {
+ break;
+ }
+ ir = istack[jstack--];
+ l = istack[jstack--];
+ } else {
+ k = (l + ir) >> 1;
+ swap(arr, k, l + 1);
+ swap(brr, k, l + 1);
+ if (arr[l] > arr[ir]) {
+ swap(arr, l, ir);
+ swap(brr, l, ir);
+ }
+ if (arr[l + 1] > arr[ir]) {
+ swap(arr, l + 1, ir);
+ swap(brr, l + 1, ir);
+ }
+ if (arr[l] > arr[l + 1]) {
+ swap(arr, l, l + 1);
+ swap(brr, l, l + 1);
+ }
+ i = l + 1;
+ j = ir;
+ a = arr[l + 1];
+ b = brr[l + 1];
+ for (;;) {
+ do {
+ i++;
+ } while (arr[i] < a);
+ do {
+ j--;
+ } while (arr[j] > a);
+ if (j < i) {
+ break;
+ }
+ swap(arr, i, j);
+ swap(brr, i, j);
+ }
+ arr[l + 1] = arr[j];
+ arr[j] = a;
+ brr[l + 1] = brr[j];
+ brr[j] = b;
+ jstack += 2;
+
+ if (jstack >= NSTACK) {
+ throw new IllegalStateException("NSTACK too small in sort.");
+ }
+
+ if (ir - i + 1 >= j - l) {
+ istack[jstack] = ir;
+ istack[jstack - 1] = i;
+ ir = j - 1;
+ } else {
+ istack[jstack] = j - 1;
+ istack[jstack - 1] = l;
+ l = i;
+ }
+ }
+ }
+ }
+
+ public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr,
+ @Nonnull final double[] crr) {
+ sort(arr, brr, crr, arr.length);
+ }
+
+ public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr,
+ @Nonnull final double[] crr, final int n) {
+ Preconditions.checkArgument(arr.length >= n);
+ Preconditions.checkArgument(brr.length >= n);
+ Preconditions.checkArgument(crr.length >= n);
+
+ final int NSTACK = 64;
+ final int M = 7;
+ final int[] istack = new int[NSTACK];
+
+ int jstack = -1;
+ int l = 0;
+ int ir = n - 1;
+
+ int i, j, k;
+ int a, b;
+ double c;
+ for (;;) {
+ if (ir - l < M) {
+ for (j = l + 1; j <= ir; j++) {
+ a = arr[j];
+ b = brr[j];
+ c = crr[j];
+ for (i = j - 1; i >= l; i--) {
+ if (arr[i] <= a) {
+ break;
+ }
+ arr[i + 1] = arr[i];
+ brr[i + 1] = brr[i];
+ crr[i + 1] = crr[i];
+ }
+ arr[i + 1] = a;
+ brr[i + 1] = b;
+ crr[i + 1] = c;
+ }
+ if (jstack < 0) {
+ break;
+ }
+ ir = istack[jstack--];
+ l = istack[jstack--];
+ } else {
+ k = (l + ir) >> 1;
+ swap(arr, k, l + 1);
+ swap(brr, k, l + 1);
+ swap(crr, k, l + 1);
+ if (arr[l] > arr[ir]) {
+ swap(arr, l, ir);
+ swap(brr, l, ir);
+ swap(crr, l, ir);
+ }
+ if (arr[l + 1] > arr[ir]) {
+ swap(arr, l + 1, ir);
+ swap(brr, l + 1, ir);
+ swap(crr, l + 1, ir);
+ }
+ if (arr[l] > arr[l + 1]) {
+ swap(arr, l, l + 1);
+ swap(brr, l, l + 1);
+ swap(crr, l, l + 1);
+ }
+ i = l + 1;
+ j = ir;
+ a = arr[l + 1];
+ b = brr[l + 1];
+ c = crr[l + 1];
+ for (;;) {
+ do {
+ i++;
+ } while (arr[i] < a);
+ do {
+ j--;
+ } while (arr[j] > a);
+ if (j < i) {
+ break;
+ }
+ swap(arr, i, j);
+ swap(brr, i, j);
+ swap(crr, i, j);
+ }
+ arr[l + 1] = arr[j];
+ arr[j] = a;
+ brr[l + 1] = brr[j];
+ brr[j] = b;
+ crr[l + 1] = crr[j];
+ crr[j] = c;
+ jstack += 2;
+
+ if (jstack >= NSTACK) {
+ throw new IllegalStateException("NSTACK too small in sort.");
+ }
+
+ if (ir - i + 1 >= j - l) {
+ istack[jstack] = ir;
+ istack[jstack - 1] = i;
+ ir = j - 1;
+ } else {
+ istack[jstack] = j - 1;
+ istack[jstack - 1] = l;
+ l = i;
+ }
+ }
+ }
+ }
+
+ public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr,
+ @Nonnull final float[] crr) {
+ sort(arr, brr, crr, arr.length);
+ }
+
+ public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr,
+ @Nonnull final float[] crr, final int n) {
+ Preconditions.checkArgument(arr.length >= n);
+ Preconditions.checkArgument(brr.length >= n);
+ Preconditions.checkArgument(crr.length >= n);
+
+ final int NSTACK = 64;
+ final int M = 7;
+ final int[] istack = new int[NSTACK];
+
+ int jstack = -1;
+ int l = 0;
+ int ir = n - 1;
+
+ int i, j, k;
+ int a, b;
+ float c;
+ for (;;) {
+ if (ir - l < M) {
+ for (j = l + 1; j <= ir; j++) {
+ a = arr[j];
+ b = brr[j];
+ c = crr[j];
+ for (i = j - 1; i >= l; i--) {
+ if (arr[i] <= a) {
+ break;
+ }
+ arr[i + 1] = arr[i];
+ brr[i + 1] = brr[i];
+ crr[i + 1] = crr[i];
+ }
+ arr[i + 1] = a;
+ brr[i + 1] = b;
+ crr[i + 1] = c;
+ }
+ if (jstack < 0) {
+ break;
+ }
+ ir = istack[jstack--];
+ l = istack[jstack--];
+ } else {
+ k = (l + ir) >> 1;
+ swap(arr, k, l + 1);
+ swap(brr, k, l + 1);
+ swap(crr, k, l + 1);
+ if (arr[l] > arr[ir]) {
+ swap(arr, l, ir);
+ swap(brr, l, ir);
+ swap(crr, l, ir);
+ }
+ if (arr[l + 1] > arr[ir]) {
+ swap(arr, l + 1, ir);
+ swap(brr, l + 1, ir);
+ swap(crr, l + 1, ir);
+ }
+ if (arr[l] > arr[l + 1]) {
+ swap(arr, l, l + 1);
+ swap(brr, l, l + 1);
+ swap(crr, l, l + 1);
+ }
+ i = l + 1;
+ j = ir;
+ a = arr[l + 1];
+ b = brr[l + 1];
+ c = crr[l + 1];
+ for (;;) {
+ do {
+ i++;
+ } while (arr[i] < a);
+ do {
+ j--;
+ } while (arr[j] > a);
+ if (j < i) {
+ break;
+ }
+ swap(arr, i, j);
+ swap(brr, i, j);
+ swap(crr, i, j);
+ }
+ arr[l + 1] = arr[j];
+ arr[j] = a;
+ brr[l + 1] = brr[j];
+ brr[j] = b;
+ crr[l + 1] = crr[j];
+ crr[j] = c;
+ jstack += 2;
+
+ if (jstack >= NSTACK) {
+ throw new IllegalStateException("NSTACK too small in sort.");
+ }
+
+ if (ir - i + 1 >= j - l) {
+ istack[jstack] = ir;
+ istack[jstack - 1] = i;
+ ir = j - 1;
+ } else {
+ istack[jstack] = j - 1;
+ istack[jstack - 1] = l;
+ l = i;
+ }
+ }
+ }
+ }
+
+ public static int count(@Nonnull final int[] values, final int valueToFind) {
+ int cnt = 0;
+ for (int i = 0; i < values.length; i++) {
+ if (values[i] == valueToFind) {
+ cnt++;
+ }
+ }
+ return cnt;
+ }
+
}
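
The new sort overloads co-sort parallel arrays with an in-place quicksort (falling back to insertion sort for small subranges): arr supplies the ordering and every move is mirrored into brr (and crr), so each key keeps its payload columns. A small usage sketch:

    long[] keys   = { 42L, 7L, 19L };
    double[] vals = { 0.42, 0.07, 0.19 };
    ArrayUtils.sort(keys, vals);
    // keys -> {7, 19, 42}, vals -> {0.07, 0.19, 0.42}: pairs stay aligned
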
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/lang/Primitives.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java
index 8f018f0..31cd8a8 100644
--- a/core/src/main/java/hivemall/utils/lang/Primitives.java
+++ b/core/src/main/java/hivemall/utils/lang/Primitives.java
@@ -18,6 +18,8 @@
*/
package hivemall.utils.lang;
+import javax.annotation.Nonnull;
+
public final class Primitives {
public static final int INT_BYTES = Integer.SIZE / Byte.SIZE;
public static final int DOUBLE_BYTES = Double.SIZE / Byte.SIZE;
@@ -99,4 +101,30 @@ public final class Primitives {
return result;
}
+ public static long toLong(final int high, final int low) {
+ return ((long) high << 32) | ((long) low & 0xffffffffL);
+ }
+
+ public static int getHigh(final long key) {
+ return (int) (key >>> 32) & 0xffffffff;
+ }
+
+ public static int getLow(final long key) {
+ return (int) key & 0xffffffff;
+ }
+
+ @Nonnull
+ public static byte[] toBytes(long l) {
+ final byte[] retVal = new byte[8];
+ for (int i = 0; i < 8; i++) {
+ retVal[i] = (byte) l;
+ l >>= 8;
+ }
+ return retVal;
+ }
+
+ public static int hashCode(final long value) {
+ return (int) (value ^ (value >>> 32));
+ }
+
}
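
toLong packs two 32-bit ints into a single long (high word, low word); the 0xffffffffL mask keeps a negative low int from sign-extending over the high word, and getHigh/getLow invert the packing. A round-trip sketch:

    long packed = Primitives.toLong(123, -456);
    assert Primitives.getHigh(packed) == 123;
    assert Primitives.getLow(packed) == -456;
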
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 252ccf6..b71d165 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -36,6 +36,7 @@ package hivemall.utils.math;
import java.util.Random;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
public final class MathUtils {
@@ -250,6 +251,9 @@ public final class MathUtils {
}
public static boolean equals(@Nonnull final float value, final float expected, final float delta) {
+ if (Double.isNaN(value)) {
+ return false;
+ }
if (Math.abs(expected - value) > delta) {
return false;
}
@@ -258,19 +262,20 @@ public final class MathUtils {
public static boolean equals(@Nonnull final double value, final double expected,
final double delta) {
+ if (Double.isNaN(value)) {
+ return false;
+ }
if (Math.abs(expected - value) > delta) {
return false;
}
return true;
}
- public static boolean almostEquals(@Nonnull final float value, final float expected,
- final float delta) {
+ public static boolean almostEquals(@Nonnull final float value, final float expected) {
return equals(value, expected, 1E-15f);
}
- public static boolean almostEquals(@Nonnull final double value, final double expected,
- final double delta) {
+ public static boolean almostEquals(@Nonnull final double value, final double expected) {
return equals(value, expected, 1E-15d);
}
@@ -297,4 +302,13 @@ public final class MathUtils {
return 0; // 0 or NaN
}
+ @Nonnull
+ public static int[] permutation(@Nonnegative final int size) {
+ final int[] perm = new int[size];
+ for (int i = 0; i < size; i++) {
+ perm[i] = i;
+ }
+ return perm;
+ }
+
}
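
permutation(size) simply returns the identity permutation {0, 1, ..., size - 1}; callers typically shuffle it afterwards to visit indices in random order. For illustration, a plain Fisher-Yates shuffle over the result (the shuffle is standard Java, not a Hivemall API):

    final int[] perm = MathUtils.permutation(5);  // {0, 1, 2, 3, 4}
    final java.util.Random rnd = new java.util.Random(31L);
    for (int i = perm.length - 1; i > 0; i--) {   // Fisher-Yates
        final int j = rnd.nextInt(i + 1);
        final int tmp = perm[i];
        perm[i] = perm[j];
        perm[j] = tmp;
    }
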
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/math/MatrixUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MatrixUtils.java b/core/src/main/java/hivemall/utils/math/MatrixUtils.java
index 66d6e8c..a0e5fc7 100644
--- a/core/src/main/java/hivemall/utils/math/MatrixUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MatrixUtils.java
@@ -18,7 +18,7 @@
*/
package hivemall.utils.math;
-import hivemall.utils.collections.DoubleArrayList;
+import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.lang.Preconditions;
import java.util.Arrays;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java b/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java
new file mode 100644
index 0000000..f86a788
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.sampling;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Vitter's reservoir sampling implementation that randomly chooses k items from a list containing n items.
+ *
+ * @link http://en.wikipedia.org/wiki/Reservoir_sampling
+ * @link http://portal.acm.org/citation.cfm?id=3165
+ */
+public final class IntReservoirSampler {
+
+ private final int[] samples;
+ private final int numSamples;
+ private int position;
+
+ private final Random rand;
+
+ public IntReservoirSampler(int sampleSize) {
+ if (sampleSize <= 0) {
+ throw new IllegalArgumentException("sampleSize must be greater than 1: " + sampleSize);
+ }
+ this.samples = new int[sampleSize];
+ this.numSamples = sampleSize;
+ this.position = 0;
+ this.rand = new Random();
+ }
+
+ public IntReservoirSampler(int sampleSize, long seed) {
+ this.samples = new int[sampleSize];
+ this.numSamples = sampleSize;
+ this.position = 0;
+ this.rand = new Random(seed);
+ }
+
+ public IntReservoirSampler(int[] samples) {
+ this.samples = samples;
+ this.numSamples = samples.length;
+ this.position = 0;
+ this.rand = new Random();
+ }
+
+ public IntReservoirSampler(int[] samples, long seed) {
+ this.samples = samples;
+ this.numSamples = samples.length;
+ this.position = 0;
+ this.rand = new Random(seed);
+ }
+
+ public int size() {
+ return position;
+ }
+
+ @Nonnull
+ public int[] getSample() {
+ if (position >= numSamples) {
+ return samples;
+ }
+ return Arrays.copyOf(samples, position);
+ }
+
+ public void add(final int item) {
+ if (position < numSamples) {// reservoir not yet full, just append
+ samples[position] = item;
+ } else {// find an item to replace
+ int replaceIndex = rand.nextInt(position + 1);
+ if (replaceIndex < numSamples) {
+ samples[replaceIndex] = item;
+ }
+ }
+ position++;
+ }
+
+ public void clear() {
+ Arrays.fill(samples, 0);
+ this.position = 0;
+ }
+}
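
A usage sketch of the sampler above: stream items through add() and the reservoir retains a uniform random sample of at most sampleSize of everything seen so far (demo values are hypothetical):

    IntReservoirSampler sampler = new IntReservoirSampler(10, 42L);
    for (int i = 0; i < 1000; i++) {
        sampler.add(i);                   // each item is kept with probability 10/1000
    }
    int[] sample = sampler.getSample();   // 10 elements, chosen without bias
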
[06/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75]
Support Sparse Vector Format as the input of RandomForest
Posted by my...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/OpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/OpenHashTable.java
deleted file mode 100644
index 1a3dff7..0000000
--- a/core/src/main/java/hivemall/utils/collections/OpenHashTable.java
+++ /dev/null
@@ -1,412 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.Copyable;
-import hivemall.utils.math.Primes;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-import java.util.HashMap;
-
-import javax.annotation.Nonnull;
-
-/**
- * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}.
- */
-public final class OpenHashTable<K, V> implements Externalizable {
-
- public static final float DEFAULT_LOAD_FACTOR = 0.7f;
- public static final float DEFAULT_GROW_FACTOR = 2.0f;
-
- protected static final byte FREE = 0;
- protected static final byte FULL = 1;
- protected static final byte REMOVED = 2;
-
- protected/* final */float _loadFactor;
- protected/* final */float _growFactor;
-
- protected int _used = 0;
- protected int _threshold;
-
- protected K[] _keys;
- protected V[] _values;
- protected byte[] _states;
-
- public OpenHashTable() {} // for Externalizable
-
- public OpenHashTable(int size) {
- this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR);
- }
-
- @SuppressWarnings("unchecked")
- public OpenHashTable(int size, float loadFactor, float growFactor) {
- if (size < 1) {
- throw new IllegalArgumentException();
- }
- this._loadFactor = loadFactor;
- this._growFactor = growFactor;
- int actualSize = Primes.findLeastPrimeNumber(size);
- this._keys = (K[]) new Object[actualSize];
- this._values = (V[]) new Object[actualSize];
- this._states = new byte[actualSize];
- this._threshold = Math.round(actualSize * _loadFactor);
- }
-
- public OpenHashTable(@Nonnull K[] keys, @Nonnull V[] values, @Nonnull byte[] states, int used) {
- this._used = used;
- this._threshold = keys.length;
- this._keys = keys;
- this._values = values;
- this._states = states;
- }
-
- public Object[] getKeys() {
- return _keys;
- }
-
- public Object[] getValues() {
- return _values;
- }
-
- public byte[] getStates() {
- return _states;
- }
-
- public boolean containsKey(final K key) {
- return findKey(key) >= 0;
- }
-
- public V get(final K key) {
- final int i = findKey(key);
- if (i < 0) {
- return null;
- }
- return _values[i];
- }
-
- public V put(final K key, final V value) {
- int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- K[] keys = _keys;
- V[] values = _values;
- byte[] states = _states;
-
- if (states[keyIdx] == FULL) {
- if (equals(keys[keyIdx], key)) {
- V old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
- V old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- ++_used;
- return null;
- }
-
- private static boolean equals(final Object k1, final Object k2) {
- return k1 == k2 || k1.equals(k2);
- }
-
- /** Return weather the required slot is free for new entry */
- protected boolean isFree(int index, K key) {
- byte stat = _states[index];
- if (stat == FREE) {
- return true;
- }
- if (stat == REMOVED && equals(_keys[index], key)) {
- return true;
- }
- return false;
- }
-
- /** @return expanded or not */
- protected boolean preAddEntry(int index) {
- if ((_used + 1) >= _threshold) {// filled enough
- int newCapacity = Math.round(_keys.length * _growFactor);
- ensureCapacity(newCapacity);
- return true;
- }
- return false;
- }
-
- protected int findKey(final K key) {
- K[] keys = _keys;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
- return keyIdx;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- return -1;
- }
- if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
- return keyIdx;
- }
- }
- }
- return -1;
- }
-
- public V remove(final K key) {
- K[] keys = _keys;
- V[] values = _values;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
- V old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- // second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (states[keyIdx] == FREE) {
- return null;
- }
- if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
- V old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- }
- }
- return null;
- }
-
- public int size() {
- return _used;
- }
-
- public void clear() {
- Arrays.fill(_states, FREE);
- this._used = 0;
- }
-
- public IMapIterator<K, V> entries() {
- return new MapIterator();
- }
-
- @Override
- public String toString() {
- int len = size() * 10 + 2;
- final StringBuilder buf = new StringBuilder(len);
- buf.append('{');
- final IMapIterator<K, V> i = entries();
- while (i.next() != -1) {
- String key = i.getKey().toString();
- buf.append(key);
- buf.append('=');
- buf.append(i.getValue());
- if (i.hasNext()) {
- buf.append(',');
- }
- }
- buf.append('}');
- return buf.toString();
- }
-
- protected void ensureCapacity(int newCapacity) {
- int prime = Primes.findLeastPrimeNumber(newCapacity);
- rehash(prime);
- this._threshold = Math.round(prime * _loadFactor);
- }
-
- @SuppressWarnings("unchecked")
- private void rehash(int newCapacity) {
- int oldCapacity = _keys.length;
- if (newCapacity <= oldCapacity) {
- throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
- }
- final K[] newkeys = (K[]) new Object[newCapacity];
- final V[] newValues = (V[]) new Object[newCapacity];
- final byte[] newStates = new byte[newCapacity];
- int used = 0;
- for (int i = 0; i < oldCapacity; i++) {
- if (_states[i] == FULL) {
- used++;
- K k = _keys[i];
- V v = _values[i];
- int hash = keyHash(k);
- int keyIdx = hash % newCapacity;
- if (newStates[keyIdx] == FULL) {// second hashing
- int decr = 1 + (hash % (newCapacity - 2));
- while (newStates[keyIdx] != FREE) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += newCapacity;
- }
- }
- }
- newStates[keyIdx] = FULL;
- newkeys[keyIdx] = k;
- newValues[keyIdx] = v;
- }
- }
- this._keys = newkeys;
- this._values = newValues;
- this._states = newStates;
- this._used = used;
- }
-
- private static int keyHash(final Object key) {
- int hash = key.hashCode();
- return hash & 0x7fffffff;
- }
-
- private final class MapIterator implements IMapIterator<K, V> {
-
- int nextEntry;
- int lastEntry = -1;
-
- MapIterator() {
- this.nextEntry = nextEntry(0);
- }
-
- /** find the index of next full entry */
- int nextEntry(int index) {
- while (index < _keys.length && _states[index] != FULL) {
- index++;
- }
- return index;
- }
-
- public boolean hasNext() {
- return nextEntry < _keys.length;
- }
-
- public int next() {
- if (!hasNext()) {
- return -1;
- }
- int curEntry = nextEntry;
- this.lastEntry = nextEntry;
- this.nextEntry = nextEntry(nextEntry + 1);
- return curEntry;
- }
-
- public K getKey() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _keys[lastEntry];
- }
-
- public V getValue() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _values[lastEntry];
- }
-
- @Override
- public <T extends Copyable<V>> void getValue(T probe) {
- probe.copyFrom(getValue());
- }
- }
-
- @Override
- public void writeExternal(ObjectOutput out) throws IOException {
- out.writeFloat(_loadFactor);
- out.writeFloat(_growFactor);
- out.writeInt(_used);
-
- final int size = _keys.length;
- out.writeInt(size);
-
- for (int i = 0; i < size; i++) {
- out.writeObject(_keys[i]);
- out.writeObject(_values[i]);
- out.writeByte(_states[i]);
- }
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this._loadFactor = in.readFloat();
- this._growFactor = in.readFloat();
- this._used = in.readInt();
-
- final int size = in.readInt();
- final Object[] keys = new Object[size];
- final Object[] values = new Object[size];
- final byte[] states = new byte[size];
- for (int i = 0; i < size; i++) {
- keys[i] = in.readObject();
- values[i] = in.readObject();
- states[i] = in.readByte();
- }
- this._threshold = size;
- this._keys = (K[]) keys;
- this._values = (V[]) values;
- this._states = states;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/SparseDoubleArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/SparseDoubleArray.java b/core/src/main/java/hivemall/utils/collections/SparseDoubleArray.java
deleted file mode 100644
index c4dbbb5..0000000
--- a/core/src/main/java/hivemall/utils/collections/SparseDoubleArray.java
+++ /dev/null
@@ -1,213 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.ArrayUtils;
-import hivemall.utils.lang.Preconditions;
-
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-
-public final class SparseDoubleArray implements DoubleArray {
- private static final long serialVersionUID = -2814248784231540118L;
-
- @Nonnull
- private int[] mKeys;
- @Nonnull
- private double[] mValues;
- private int mSize;
-
- public SparseDoubleArray() {
- this(10);
- }
-
- public SparseDoubleArray(int initialCapacity) {
- mKeys = new int[initialCapacity];
- mValues = new double[initialCapacity];
- mSize = 0;
- }
-
- private SparseDoubleArray(@Nonnull int[] mKeys, @Nonnull double[] mValues, int mSize) {
- this.mKeys = mKeys;
- this.mValues = mValues;
- this.mSize = mSize;
- }
-
- @Nonnull
- public SparseDoubleArray deepCopy() {
- int[] newKeys = new int[mSize];
- double[] newValues = new double[mSize];
- System.arraycopy(mKeys, 0, newKeys, 0, mSize);
- System.arraycopy(mValues, 0, newValues, 0, mSize);
- return new SparseDoubleArray(newKeys, newValues, mSize);
- }
-
- @Override
- public double get(int key) {
- return get(key, 0);
- }
-
- @Override
- public double get(int key, double valueIfKeyNotFound) {
- int i = Arrays.binarySearch(mKeys, 0, mSize, key);
- if (i < 0) {
- return valueIfKeyNotFound;
- } else {
- return mValues[i];
- }
- }
-
- public void delete(int key) {
- int i = Arrays.binarySearch(mKeys, 0, mSize, key);
- if (i >= 0) {
- removeAt(i);
- }
- }
-
- public void removeAt(int index) {
- System.arraycopy(mKeys, index + 1, mKeys, index, mSize - (index + 1));
- System.arraycopy(mValues, index + 1, mValues, index, mSize - (index + 1));
- mSize--;
- }
-
- @Override
- public void put(int key, double value) {
- int i = Arrays.binarySearch(mKeys, 0, mSize, key);
- if (i >= 0) {
- mValues[i] = value;
- } else {
- i = ~i;
- mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
- mValues = ArrayUtils.insert(mValues, mSize, i, value);
- mSize++;
- }
- }
-
- public void increment(int key, double value) {
- int i = Arrays.binarySearch(mKeys, 0, mSize, key);
- if (i >= 0) {
- mValues[i] += value;
- } else {
- i = ~i;
- mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
- mValues = ArrayUtils.insert(mValues, mSize, i, value);
- mSize++;
- }
- }
-
- @Override
- public int size() {
- return mSize;
- }
-
- @Override
- public int keyAt(int index) {
- return mKeys[index];
- }
-
- public double valueAt(int index) {
- return mValues[index];
- }
-
- public void setValueAt(int index, double value) {
- mValues[index] = value;
- }
-
- public int indexOfKey(int key) {
- return Arrays.binarySearch(mKeys, 0, mSize, key);
- }
-
- public int indexOfValue(double value) {
- for (int i = 0; i < mSize; i++) {
- if (mValues[i] == value) {
- return i;
- }
- }
- return -1;
- }
-
- public void clear() {
- clear(true);
- }
-
- public void clear(boolean zeroFill) {
- mSize = 0;
- if (zeroFill) {
- Arrays.fill(mKeys, 0);
- Arrays.fill(mValues, 0.d);
- }
- }
-
- public void append(int key, double value) {
- if (mSize != 0 && key <= mKeys[mSize - 1]) {
- put(key, value);
- return;
- }
- mKeys = ArrayUtils.append(mKeys, mSize, key);
- mValues = ArrayUtils.append(mValues, mSize, value);
- mSize++;
- }
-
- @Override
- public double[] toArray() {
- return toArray(true);
- }
-
- @Override
- public double[] toArray(boolean copy) {
- if (mSize == 0) {
- return new double[0];
- }
-
- int last = mKeys[mSize - 1];
- final double[] array = new double[last + 1];
- for (int i = 0; i < mSize; i++) {
- int k = mKeys[i];
- double v = mValues[i];
- Preconditions.checkArgument(k >= 0, "Negative key is not allowed for toArray(): " + k);
- array[k] = v;
- }
- return array;
- }
-
- @Override
- public String toString() {
- if (size() <= 0) {
- return "{}";
- }
-
- StringBuilder buffer = new StringBuilder(mSize * 28);
- buffer.append('{');
- for (int i = 0; i < mSize; i++) {
- if (i > 0) {
- buffer.append(", ");
- }
- int key = keyAt(i);
- buffer.append(key);
- buffer.append('=');
- double value = valueAt(i);
- buffer.append(value);
- }
- buffer.append('}');
- return buffer.toString();
- }
-
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/SparseIntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/SparseIntArray.java b/core/src/main/java/hivemall/utils/collections/SparseIntArray.java
deleted file mode 100644
index 7a4ba69..0000000
--- a/core/src/main/java/hivemall/utils/collections/SparseIntArray.java
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.ArrayUtils;
-import hivemall.utils.lang.Preconditions;
-
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-
-public final class SparseIntArray implements IntArray {
- private static final long serialVersionUID = -2814248784231540118L;
-
- private int[] mKeys;
- private int[] mValues;
- private int mSize;
-
- public SparseIntArray() {
- this(10);
- }
-
- public SparseIntArray(int initialCapacity) {
- mKeys = new int[initialCapacity];
- mValues = new int[initialCapacity];
- mSize = 0;
- }
-
- private SparseIntArray(int[] mKeys, int[] mValues, int mSize) {
- this.mKeys = mKeys;
- this.mValues = mValues;
- this.mSize = mSize;
- }
-
- public IntArray deepCopy() {
- int[] newKeys = new int[mSize];
- int[] newValues = new int[mSize];
- System.arraycopy(mKeys, 0, newKeys, 0, mSize);
- System.arraycopy(mValues, 0, newValues, 0, mSize);
- return new SparseIntArray(newKeys, newValues, mSize);
- }
-
- @Override
- public int get(int key) {
- return get(key, 0);
- }
-
- @Override
- public int get(int key, int valueIfKeyNotFound) {
- int i = Arrays.binarySearch(mKeys, 0, mSize, key);
- if (i < 0) {
- return valueIfKeyNotFound;
- } else {
- return mValues[i];
- }
- }
-
- public void delete(int key) {
- int i = Arrays.binarySearch(mKeys, 0, mSize, key);
- if (i >= 0) {
- removeAt(i);
- }
- }
-
- public void removeAt(int index) {
- System.arraycopy(mKeys, index + 1, mKeys, index, mSize - (index + 1));
- System.arraycopy(mValues, index + 1, mValues, index, mSize - (index + 1));
- mSize--;
- }
-
- @Override
- public void put(int key, int value) {
- int i = Arrays.binarySearch(mKeys, 0, mSize, key);
- if (i >= 0) {
- mValues[i] = value;
- } else {
- i = ~i;
- mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
- mValues = ArrayUtils.insert(mValues, mSize, i, value);
- mSize++;
- }
- }
-
- public void increment(int key, int value) {
- int i = Arrays.binarySearch(mKeys, 0, mSize, key);
- if (i >= 0) {
- mValues[i] += value;
- } else {
- i = ~i;
- mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
- mValues = ArrayUtils.insert(mValues, mSize, i, value);
- mSize++;
- }
- }
-
- @Override
- public int size() {
- return mSize;
- }
-
- @Override
- public int keyAt(int index) {
- return mKeys[index];
- }
-
- public int valueAt(int index) {
- return mValues[index];
- }
-
- public void setValueAt(int index, int value) {
- mValues[index] = value;
- }
-
- public int indexOfKey(int key) {
- return Arrays.binarySearch(mKeys, 0, mSize, key);
- }
-
- public int indexOfValue(int value) {
- for (int i = 0; i < mSize; i++) {
- if (mValues[i] == value) {
- return i;
- }
- }
- return -1;
- }
-
- public void clear() {
- clear(true);
- }
-
- public void clear(boolean zeroFill) {
- mSize = 0;
- if (zeroFill) {
- Arrays.fill(mKeys, 0);
- Arrays.fill(mValues, 0);
- }
- }
-
- public void append(int key, int value) {
- if (mSize != 0 && key <= mKeys[mSize - 1]) {
- put(key, value);
- return;
- }
- mKeys = ArrayUtils.append(mKeys, mSize, key);
- mValues = ArrayUtils.append(mValues, mSize, value);
- mSize++;
- }
-
- @Nonnull
- public int[] toArray() {
- return toArray(true);
- }
-
- @Override
- public int[] toArray(boolean copy) {
- if (mSize == 0) {
- return new int[0];
- }
-
- int last = mKeys[mSize - 1];
- final int[] array = new int[last + 1];
- for (int i = 0; i < mSize; i++) {
- int k = mKeys[i];
- int v = mValues[i];
- Preconditions.checkArgument(k >= 0, "Negative key is not allowed for toArray(): " + k);
- array[k] = v;
- }
- return array;
- }
-
- @Override
- public String toString() {
- if (size() <= 0) {
- return "{}";
- }
-
- StringBuilder buffer = new StringBuilder(mSize * 28);
- buffer.append('{');
- for (int i = 0; i < mSize; i++) {
- if (i > 0) {
- buffer.append(", ");
- }
- int key = keyAt(i);
- buffer.append(key);
- buffer.append('=');
- int value = valueAt(i);
- buffer.append(value);
- }
- buffer.append('}');
- return buffer.toString();
- }
-
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/DenseDoubleArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/DenseDoubleArray.java b/core/src/main/java/hivemall/utils/collections/arrays/DenseDoubleArray.java
new file mode 100644
index 0000000..f79f039
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/DenseDoubleArray.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+/**
+ * A fixed-size double array whose keys are greater than or equal to 0.
+ */
+public final class DenseDoubleArray implements DoubleArray {
+ private static final long serialVersionUID = 4282904528662802088L;
+
+ @Nonnull
+ private final double[] array;
+ private final int size;
+
+ public DenseDoubleArray(int size) {
+ this.array = new double[size];
+ this.size = size;
+ }
+
+ public DenseDoubleArray(@Nonnull double[] array) {
+ this.array = array;
+ this.size = array.length;
+ }
+
+ @Override
+ public double get(int index) {
+ return array[index];
+ }
+
+ @Override
+ public double get(int index, double valueIfKeyNotFound) {
+ if (index >= size) {
+ return valueIfKeyNotFound;
+ }
+ return array[index];
+ }
+
+ @Override
+ public void put(int index, double value) {
+ array[index] = value;
+ }
+
+ @Override
+ public int size() {
+ return array.length;
+ }
+
+ @Override
+ public int keyAt(int index) {
+ return index;
+ }
+
+ @Override
+ public double[] toArray() {
+ return toArray(true);
+ }
+
+ @Override
+ public double[] toArray(boolean copy) {
+ if (copy) {
+ return Arrays.copyOf(array, size);
+ } else {
+ return array;
+ }
+ }
+
+ @Override
+ public void clear() {
+ Arrays.fill(array, 0.d);
+ }
+
+}
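DenseDoubleArray wraps a plain double[] in which the key is simply the array index; the two-argument get() returns the fallback value only when the index is beyond the fixed size. A minimal usage sketch (the values are illustrative, not part of the commit):

    DoubleArray arr = new DenseDoubleArray(4);    // fixed backing array of length 4
    arr.put(0, 1.5d);
    arr.put(3, 2.5d);
    double hit  = arr.get(3);         // 2.5
    double miss = arr.get(10, -1.d);  // -1.0: index >= size falls back to the default
    double[] copy = arr.toArray();    // defensive copy via Arrays.copyOf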
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/DenseIntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/DenseIntArray.java b/core/src/main/java/hivemall/utils/collections/arrays/DenseIntArray.java
new file mode 100644
index 0000000..0869ff2
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/DenseIntArray.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+/**
+ * A fixed-size int array whose keys are greater than or equal to 0.
+ */
+public final class DenseIntArray implements IntArray {
+ private static final long serialVersionUID = -1450212841013810240L;
+
+ @Nonnull
+ private final int[] array;
+ private final int size;
+
+ public DenseIntArray(int size) {
+ this.array = new int[size];
+ this.size = size;
+ }
+
+ public DenseIntArray(@Nonnull int[] array) {
+ this.array = array;
+ this.size = array.length;
+ }
+
+ @Override
+ public int get(int index) {
+ return array[index];
+ }
+
+ @Override
+ public int get(int index, int valueIfKeyNotFound) {
+ if (index >= size) {
+ return valueIfKeyNotFound;
+ }
+ return array[index];
+ }
+
+ @Override
+ public void put(int index, int value) {
+ array[index] = value;
+ }
+
+ @Override
+ public void increment(int index, int value) {
+ array[index] += value;
+ }
+
+ @Override
+ public int size() {
+ return array.length;
+ }
+
+ @Override
+ public int keyAt(int index) {
+ return index;
+ }
+
+ @Override
+ public int[] toArray() {
+ return toArray(true);
+ }
+
+ @Override
+ public int[] toArray(boolean copy) {
+ if (copy) {
+ return Arrays.copyOf(array, size);
+ } else {
+ return array;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/DoubleArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/DoubleArray.java b/core/src/main/java/hivemall/utils/collections/arrays/DoubleArray.java
new file mode 100644
index 0000000..c8f3e17
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/DoubleArray.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+
+public interface DoubleArray extends Serializable {
+
+ public double get(int key);
+
+ public double get(int key, double valueIfKeyNotFound);
+
+ public void put(int key, double value);
+
+ public int size();
+
+ public int keyAt(int index);
+
+ @Nonnull
+ public double[] toArray();
+
+ @Nonnull
+ public double[] toArray(boolean copy);
+
+ public void clear();
+
+}
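Because the dense and sparse implementations in this commit share this interface, callers can swap storage layouts without code changes. One contract difference worth noting from the code itself: size() is the capacity (array length) for DenseDoubleArray but the number of stored entries for SparseDoubleArray. A short sketch under that assumption:

    DoubleArray dense  = new DenseDoubleArray(8);
    DoubleArray sparse = new SparseDoubleArray(8);
    for (DoubleArray a : new DoubleArray[] {dense, sparse}) {
        a.put(2, 1.d);
        System.out.println(a.get(2) + ", size=" + a.size());  // size: 8 vs. 1
    }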
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/DoubleArray3D.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/DoubleArray3D.java b/core/src/main/java/hivemall/utils/collections/arrays/DoubleArray3D.java
new file mode 100644
index 0000000..35feff9
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/DoubleArray3D.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.lang.Primitives;
+
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+
+import javax.annotation.Nonnull;
+
+public final class DoubleArray3D {
+ private static final int DEFAULT_SIZE = 100 * 100 * 10; // feature * field * factor
+
+ private final boolean direct;
+
+ @Nonnull
+ private DoubleBuffer buffer;
+ private int capacity;
+
+ private int size;
+ // number of elements in each dimension
+ private int n1, n2, n3;
+ // strides for the 1st and 2nd dimensions
+ private int p1, p2;
+
+ private boolean sanityCheck;
+
+ public DoubleArray3D() {
+ this(DEFAULT_SIZE, true);
+ }
+
+ public DoubleArray3D(int initSize, boolean direct) {
+ this.direct = direct;
+ this.buffer = allocate(direct, initSize);
+ this.capacity = initSize;
+ this.size = -1;
+ this.sanityCheck = true;
+ }
+
+ public DoubleArray3D(int dim1, int dim2, int dim3) {
+ this.direct = true;
+ this.capacity = -1;
+ configure(dim1, dim2, dim3);
+ this.sanityCheck = true;
+ }
+
+ public void setSanityCheck(boolean enable) {
+ this.sanityCheck = enable;
+ }
+
+ public void configure(final int dim1, final int dim2, final int dim3) {
+ int requiredSize = cardinality(dim1, dim2, dim3);
+ if (requiredSize > capacity) {
+ this.buffer = allocate(direct, requiredSize);
+ this.capacity = requiredSize;
+ }
+ this.size = requiredSize;
+ this.n1 = dim1;
+ this.n2 = dim2;
+ this.n3 = dim3;
+ this.p1 = n2 * n3;
+ this.p2 = n3;
+ }
+
+ public void clear() {
+ buffer.clear();
+ this.size = -1;
+ }
+
+ public int getSize() {
+ return size;
+ }
+
+ int getCapacity() {
+ return capacity;
+ }
+
+ public double get(final int i, final int j, final int k) {
+ int idx = idx(i, j, k);
+ return buffer.get(idx);
+ }
+
+ public void set(final int i, final int j, final int k, final double val) {
+ int idx = idx(i, j, k);
+ buffer.put(idx, val);
+ }
+
+ private int idx(final int i, final int j, final int k) {
+ if (sanityCheck == false) {
+ return i * p1 + j * p2 + k;
+ }
+
+ if (size == -1) {
+ throw new IllegalStateException("Double3DArray#configure() is not called");
+ }
+ if (i >= n1 || i < 0) {
+ throw new ArrayIndexOutOfBoundsException("Index '" + i
+ + "' out of bounds for 1st dimension of size " + n1);
+ }
+ if (j >= n2 || j < 0) {
+ throw new ArrayIndexOutOfBoundsException("Index '" + j
+ + "' out of bounds for 2nd dimension of size " + n2);
+ }
+ if (k >= n3 || k < 0) {
+ throw new ArrayIndexOutOfBoundsException("Index '" + k
+ + "' out of bounds for 3rd dimension of size " + n3);
+ }
+ final int idx = i * p1 + j * p2 + k;
+ if (idx >= size) {
+ throw new IndexOutOfBoundsException("Computed internal index '" + idx
+ + "' exceeds buffer size '" + size + "' where i=" + i + ", j=" + j + ", k=" + k);
+ }
+ return idx;
+ }
+
+ private static int cardinality(final int dim1, final int dim2, final int dim3) {
+ if (dim1 <= 0 || dim2 <= 0 || dim3 <= 0) {
+ throw new IllegalArgumentException("Detected a non-positive dimension size. dim1=" + dim1
+ + ", dim2=" + dim2 + ", dim3=" + dim3);
+ }
+ return dim1 * dim2 * dim3;
+ }
+
+ @Nonnull
+ private static DoubleBuffer allocate(final boolean direct, final int size) {
+ int bytes = size * Primitives.DOUBLE_BYTES;
+ ByteBuffer buf = direct ? ByteBuffer.allocateDirect(bytes) : ByteBuffer.allocate(bytes);
+ return buf.asDoubleBuffer();
+ }
+}
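DoubleArray3D flattens a coordinate (i, j, k) into a single DoubleBuffer offset, i * dim2 * dim3 + j * dim3 + k (row-major), optionally backed by off-heap memory via ByteBuffer.allocateDirect. A minimal sketch (the dimensions are illustrative):

    DoubleArray3D cube = new DoubleArray3D(100, 50, 10);  // feature x field x factor
    cube.set(2, 3, 4, 0.25d);
    double v = cube.get(2, 3, 4);  // 0.25
    cube.configure(200, 50, 10);   // reallocates only if the new cardinality exceeds capacity
    cube.setSanityCheck(false);    // skip bounds checks on hot paths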
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/FloatArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/FloatArray.java b/core/src/main/java/hivemall/utils/collections/arrays/FloatArray.java
new file mode 100644
index 0000000..b72bdef
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/FloatArray.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+
+public interface FloatArray extends Serializable {
+
+ public float get(int key);
+
+ public float get(int key, float valueIfKeyNotFound);
+
+ public void put(int key, float value);
+
+ public int size();
+
+ public int keyAt(int index);
+
+ @Nonnull
+ public float[] toArray();
+
+ @Nonnull
+ public float[] toArray(boolean copy);
+
+ public void clear();
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java b/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java
new file mode 100644
index 0000000..8edb0d4
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/IntArray.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+
+public interface IntArray extends Serializable {
+
+ public int get(int key);
+
+ public int get(int key, int valueIfKeyNotFound);
+
+ public void put(int key, int value);
+
+ public void increment(int key, int value);
+
+ public int size();
+
+ public int keyAt(int index);
+
+ @Nonnull
+ public int[] toArray();
+
+ @Nonnull
+ public int[] toArray(boolean copy);
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/SparseDoubleArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/SparseDoubleArray.java b/core/src/main/java/hivemall/utils/collections/arrays/SparseDoubleArray.java
new file mode 100644
index 0000000..ac41951
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/SparseDoubleArray.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+public final class SparseDoubleArray implements DoubleArray {
+ private static final long serialVersionUID = -2814248784231540118L;
+
+ @Nonnull
+ private int[] mKeys;
+ @Nonnull
+ private double[] mValues;
+ private int mSize;
+
+ public SparseDoubleArray() {
+ this(10);
+ }
+
+ public SparseDoubleArray(int initialCapacity) {
+ mKeys = new int[initialCapacity];
+ mValues = new double[initialCapacity];
+ mSize = 0;
+ }
+
+ private SparseDoubleArray(@Nonnull int[] mKeys, @Nonnull double[] mValues, int mSize) {
+ this.mKeys = mKeys;
+ this.mValues = mValues;
+ this.mSize = mSize;
+ }
+
+ @Nonnull
+ public SparseDoubleArray deepCopy() {
+ int[] newKeys = new int[mSize];
+ double[] newValues = new double[mSize];
+ System.arraycopy(mKeys, 0, newKeys, 0, mSize);
+ System.arraycopy(mValues, 0, newValues, 0, mSize);
+ return new SparseDoubleArray(newKeys, newValues, mSize);
+ }
+
+ @Override
+ public double get(int key) {
+ return get(key, 0);
+ }
+
+ @Override
+ public double get(int key, double valueIfKeyNotFound) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i < 0) {
+ return valueIfKeyNotFound;
+ } else {
+ return mValues[i];
+ }
+ }
+
+ public void delete(int key) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ removeAt(i);
+ }
+ }
+
+ public void removeAt(int index) {
+ System.arraycopy(mKeys, index + 1, mKeys, index, mSize - (index + 1));
+ System.arraycopy(mValues, index + 1, mValues, index, mSize - (index + 1));
+ mSize--;
+ }
+
+ @Override
+ public void put(int key, double value) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ mValues[i] = value;
+ } else {
+ i = ~i;
+ mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+ mValues = ArrayUtils.insert(mValues, mSize, i, value);
+ mSize++;
+ }
+ }
+
+ public void increment(int key, double value) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ mValues[i] += value;
+ } else {
+ i = ~i;
+ mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+ mValues = ArrayUtils.insert(mValues, mSize, i, value);
+ mSize++;
+ }
+ }
+
+ @Override
+ public int size() {
+ return mSize;
+ }
+
+ @Override
+ public int keyAt(int index) {
+ return mKeys[index];
+ }
+
+ public double valueAt(int index) {
+ return mValues[index];
+ }
+
+ public void setValueAt(int index, double value) {
+ mValues[index] = value;
+ }
+
+ public int indexOfKey(int key) {
+ return Arrays.binarySearch(mKeys, 0, mSize, key);
+ }
+
+ public int indexOfValue(double value) {
+ for (int i = 0; i < mSize; i++) {
+ if (mValues[i] == value) {
+ return i;
+ }
+ }
+ return -1;
+ }
+
+ @Override
+ public void clear() {
+ clear(true);
+ }
+
+ public void clear(boolean zeroFill) {
+ mSize = 0;
+ if (zeroFill) {
+ Arrays.fill(mKeys, 0);
+ Arrays.fill(mValues, 0.d);
+ }
+ }
+
+ public void append(int key, double value) {
+ if (mSize != 0 && key <= mKeys[mSize - 1]) {
+ put(key, value);
+ return;
+ }
+ mKeys = ArrayUtils.append(mKeys, mSize, key);
+ mValues = ArrayUtils.append(mValues, mSize, value);
+ mSize++;
+ }
+
+ @Override
+ public double[] toArray() {
+ return toArray(true);
+ }
+
+ @Override
+ public double[] toArray(boolean copy) {
+ if (mSize == 0) {
+ return new double[0];
+ }
+
+ int last = mKeys[mSize - 1];
+ final double[] array = new double[last + 1];
+ for (int i = 0; i < mSize; i++) {
+ int k = mKeys[i];
+ double v = mValues[i];
+ Preconditions.checkArgument(k >= 0, "Negative key is not allowed for toArray(): " + k);
+ array[k] = v;
+ }
+ return array;
+ }
+
+ public void each(@Nonnull final VectorProcedure procedure) {
+ for (int i = 0; i < mSize; i++) {
+ int k = mKeys[i];
+ double v = mValues[i];
+ procedure.apply(k, v);
+ }
+ }
+
+ @Override
+ public String toString() {
+ if (size() <= 0) {
+ return "{}";
+ }
+
+ StringBuilder buffer = new StringBuilder(mSize * 28);
+ buffer.append('{');
+ for (int i = 0; i < mSize; i++) {
+ if (i > 0) {
+ buffer.append(", ");
+ }
+ int key = keyAt(i);
+ buffer.append(key);
+ buffer.append('=');
+ double value = valueAt(i);
+ buffer.append(value);
+ }
+ buffer.append('}');
+ return buffer.toString();
+ }
+
+
+}
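The sparse variant keeps its keys sorted so that lookups are binary searches; append() takes an O(1) fast path when keys arrive in ascending order and falls back to put() otherwise. A sketch (the anonymous VectorProcedure assumes an overridable apply(int, double), as the each() loop above implies):

    SparseDoubleArray vec = new SparseDoubleArray();
    vec.append(3, 1.d);             // fast path: ascending keys
    vec.append(7, 2.d);
    vec.put(5, 0.5d);               // out-of-order key: binary search + insert
    vec.increment(7, 1.d);          // key 7 now holds 3.0
    double miss = vec.get(4, 0.d);  // 0.0: an absent key falls back to the default
    vec.each(new VectorProcedure() {
        @Override
        public void apply(int i, double x) {
            System.out.println(i + "=" + x);
        }
    });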
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/SparseFloatArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/SparseFloatArray.java b/core/src/main/java/hivemall/utils/collections/arrays/SparseFloatArray.java
new file mode 100644
index 0000000..928de77
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/SparseFloatArray.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+public final class SparseFloatArray implements FloatArray {
+ private static final long serialVersionUID = -2814248784231540118L;
+
+ private int[] mKeys;
+ private float[] mValues;
+ private int mSize;
+
+ public SparseFloatArray() {
+ this(10);
+ }
+
+ public SparseFloatArray(int initialCapacity) {
+ mKeys = new int[initialCapacity];
+ mValues = new float[initialCapacity];
+ mSize = 0;
+ }
+
+ private SparseFloatArray(@Nonnull int[] mKeys, @Nonnull float[] mValues, int mSize) {
+ this.mKeys = mKeys;
+ this.mValues = mValues;
+ this.mSize = mSize;
+ }
+
+ public SparseFloatArray deepCopy() {
+ int[] newKeys = new int[mSize];
+ float[] newValues = new float[mSize];
+ System.arraycopy(mKeys, 0, newKeys, 0, mSize);
+ System.arraycopy(mValues, 0, newValues, 0, mSize);
+ return new SparseFloatArray(newKeys, newValues, mSize);
+ }
+
+ @Override
+ public float get(int key) {
+ return get(key, 0.f);
+ }
+
+ @Override
+ public float get(int key, float valueIfKeyNotFound) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i < 0) {
+ return valueIfKeyNotFound;
+ } else {
+ return mValues[i];
+ }
+ }
+
+ public void delete(int key) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ removeAt(i);
+ }
+ }
+
+ public void removeAt(int index) {
+ System.arraycopy(mKeys, index + 1, mKeys, index, mSize - (index + 1));
+ System.arraycopy(mValues, index + 1, mValues, index, mSize - (index + 1));
+ mSize--;
+ }
+
+ @Override
+ public void put(int key, float value) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ mValues[i] = value;
+ } else {
+ i = ~i;
+ mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+ mValues = ArrayUtils.insert(mValues, mSize, i, value);
+ mSize++;
+ }
+ }
+
+ public void increment(int key, float value) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ mValues[i] += value;
+ } else {
+ i = ~i;
+ mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+ mValues = ArrayUtils.insert(mValues, mSize, i, value);
+ mSize++;
+ }
+ }
+
+ @Override
+ public int size() {
+ return mSize;
+ }
+
+ @Override
+ public int keyAt(int index) {
+ return mKeys[index];
+ }
+
+ public float valueAt(int index) {
+ return mValues[index];
+ }
+
+ public void setValueAt(int index, float value) {
+ mValues[index] = value;
+ }
+
+ public int indexOfKey(int key) {
+ return Arrays.binarySearch(mKeys, 0, mSize, key);
+ }
+
+ public int indexOfValue(float value) {
+ for (int i = 0; i < mSize; i++) {
+ if (mValues[i] == value) {
+ return i;
+ }
+ }
+ return -1;
+ }
+
+ public void clear() {
+ clear(true);
+ }
+
+ public void clear(boolean zeroFill) {
+ mSize = 0;
+ if (zeroFill) {
+ Arrays.fill(mKeys, 0);
+ Arrays.fill(mValues, 0.f);
+ }
+ }
+
+ public void append(int key, float value) {
+ if (mSize != 0 && key <= mKeys[mSize - 1]) {
+ put(key, value);
+ return;
+ }
+ mKeys = ArrayUtils.append(mKeys, mSize, key);
+ mValues = ArrayUtils.append(mValues, mSize, value);
+ mSize++;
+ }
+
+ @Nonnull
+ public float[] toArray() {
+ return toArray(true);
+ }
+
+ @Override
+ public float[] toArray(boolean copy) {
+ if (mSize == 0) {
+ return new float[0];
+ }
+
+ int last = mKeys[mSize - 1];
+ final float[] array = new float[last + 1];
+ for (int i = 0; i < mSize; i++) {
+ int k = mKeys[i];
+ float v = mValues[i];
+ Preconditions.checkArgument(k >= 0, "Negative key is not allowed for toArray(): " + k);
+ array[k] = v;
+ }
+ return array;
+ }
+
+ @Override
+ public String toString() {
+ if (size() <= 0) {
+ return "{}";
+ }
+
+ StringBuilder buffer = new StringBuilder(mSize * 28);
+ buffer.append('{');
+ for (int i = 0; i < mSize; i++) {
+ if (i > 0) {
+ buffer.append(", ");
+ }
+ int key = keyAt(i);
+ buffer.append(key);
+ buffer.append('=');
+ float value = valueAt(i);
+ buffer.append(value);
+ }
+ buffer.append('}');
+ return buffer.toString();
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/arrays/SparseIntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/arrays/SparseIntArray.java b/core/src/main/java/hivemall/utils/collections/arrays/SparseIntArray.java
new file mode 100644
index 0000000..8de5476
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/arrays/SparseIntArray.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+public final class SparseIntArray implements IntArray {
+ private static final long serialVersionUID = -2814248784231540118L;
+
+ private int[] mKeys;
+ private int[] mValues;
+ private int mSize;
+
+ public SparseIntArray() {
+ this(10);
+ }
+
+ public SparseIntArray(int initialCapacity) {
+ mKeys = new int[initialCapacity];
+ mValues = new int[initialCapacity];
+ mSize = 0;
+ }
+
+ private SparseIntArray(int[] mKeys, int[] mValues, int mSize) {
+ this.mKeys = mKeys;
+ this.mValues = mValues;
+ this.mSize = mSize;
+ }
+
+ public IntArray deepCopy() {
+ int[] newKeys = new int[mSize];
+ int[] newValues = new int[mSize];
+ System.arraycopy(mKeys, 0, newKeys, 0, mSize);
+ System.arraycopy(mValues, 0, newValues, 0, mSize);
+ return new SparseIntArray(newKeys, newValues, mSize);
+ }
+
+ @Override
+ public int get(int key) {
+ return get(key, 0);
+ }
+
+ @Override
+ public int get(int key, int valueIfKeyNotFound) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i < 0) {
+ return valueIfKeyNotFound;
+ } else {
+ return mValues[i];
+ }
+ }
+
+ public void delete(int key) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ removeAt(i);
+ }
+ }
+
+ public void removeAt(int index) {
+ System.arraycopy(mKeys, index + 1, mKeys, index, mSize - (index + 1));
+ System.arraycopy(mValues, index + 1, mValues, index, mSize - (index + 1));
+ mSize--;
+ }
+
+ @Override
+ public void put(int key, int value) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ mValues[i] = value;
+ } else {
+ i = ~i;
+ mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+ mValues = ArrayUtils.insert(mValues, mSize, i, value);
+ mSize++;
+ }
+ }
+
+ @Override
+ public void increment(int key, int value) {
+ int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+ if (i >= 0) {
+ mValues[i] += value;
+ } else {
+ i = ~i;
+ mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+ mValues = ArrayUtils.insert(mValues, mSize, i, value);
+ mSize++;
+ }
+ }
+
+ @Override
+ public int size() {
+ return mSize;
+ }
+
+ @Override
+ public int keyAt(int index) {
+ return mKeys[index];
+ }
+
+ public int valueAt(int index) {
+ return mValues[index];
+ }
+
+ public void setValueAt(int index, int value) {
+ mValues[index] = value;
+ }
+
+ public int indexOfKey(int key) {
+ return Arrays.binarySearch(mKeys, 0, mSize, key);
+ }
+
+ public int indexOfValue(int value) {
+ for (int i = 0; i < mSize; i++) {
+ if (mValues[i] == value) {
+ return i;
+ }
+ }
+ return -1;
+ }
+
+ public void clear() {
+ clear(true);
+ }
+
+ public void clear(boolean zeroFill) {
+ mSize = 0;
+ if (zeroFill) {
+ Arrays.fill(mKeys, 0);
+ Arrays.fill(mValues, 0);
+ }
+ }
+
+ public void append(int key, int value) {
+ if (mSize != 0 && key <= mKeys[mSize - 1]) {
+ put(key, value);
+ return;
+ }
+ mKeys = ArrayUtils.append(mKeys, mSize, key);
+ mValues = ArrayUtils.append(mValues, mSize, value);
+ mSize++;
+ }
+
+ @Nonnull
+ public int[] toArray() {
+ return toArray(true);
+ }
+
+ @Override
+ public int[] toArray(boolean copy) {
+ if (mSize == 0) {
+ return new int[0];
+ }
+
+ int last = mKeys[mSize - 1];
+ final int[] array = new int[last + 1];
+ for (int i = 0; i < mSize; i++) {
+ int k = mKeys[i];
+ int v = mValues[i];
+ Preconditions.checkArgument(k >= 0, "Negative key is not allowed for toArray(): " + k);
+ array[k] = v;
+ }
+ return array;
+ }
+
+ @Override
+ public String toString() {
+ if (size() <= 0) {
+ return "{}";
+ }
+
+ StringBuilder buffer = new StringBuilder(mSize * 28);
+ buffer.append('{');
+ for (int i = 0; i < mSize; i++) {
+ if (i > 0) {
+ buffer.append(", ");
+ }
+ int key = keyAt(i);
+ buffer.append(key);
+ buffer.append('=');
+ int value = valueAt(i);
+ buffer.append(value);
+ }
+ buffer.append('}');
+ return buffer.toString();
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/lists/DoubleArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/lists/DoubleArrayList.java b/core/src/main/java/hivemall/utils/collections/lists/DoubleArrayList.java
new file mode 100644
index 0000000..feb614f
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/lists/DoubleArrayList.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.lists;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+
+public final class DoubleArrayList implements Serializable {
+ private static final long serialVersionUID = -8155789759545975413L;
+ public static final int DEFAULT_CAPACITY = 12;
+
+ /** backing array */
+ private double[] data;
+ private int used;
+
+ public DoubleArrayList() {
+ this(DEFAULT_CAPACITY);
+ }
+
+ public DoubleArrayList(int size) {
+ this.data = new double[size];
+ this.used = 0;
+ }
+
+ public DoubleArrayList(double[] initValues) {
+ this.data = initValues;
+ this.used = initValues.length;
+ }
+
+ @Nonnull
+ public DoubleArrayList add(double value) {
+ if (used >= data.length) {
+ expand(used + 1);
+ }
+ data[used++] = value;
+ return this;
+ }
+
+ @Nonnull
+ public DoubleArrayList add(@Nonnull double[] values) {
+ final int needs = used + values.length;
+ if (needs >= data.length) {
+ expand(needs);
+ }
+ System.arraycopy(values, 0, data, used, values.length);
+ this.used = needs;
+ return this;
+ }
+
+ /**
+ * dynamic expansion.
+ */
+ private void expand(int max) {
+ while (data.length < max) {
+ final int len = data.length;
+ double[] newArray = new double[len * 2];
+ System.arraycopy(data, 0, newArray, 0, len);
+ this.data = newArray;
+ }
+ }
+
+ public double remove() {
+ return data[--used];
+ }
+
+ public double remove(int index) {
+ if (index >= used) {
+ throw new IndexOutOfBoundsException();
+ }
+
+ final double ret;
+ if (index == used) {
+ ret = data[index];
+ --used;
+ } else { // index < used
+ ret = data[index];
+ System.arraycopy(data, index + 1, data, index, used - index - 1);
+ --used;
+ }
+ return ret;
+ }
+
+ public void set(int index, double value) {
+ if (index > used) {
+ throw new IllegalArgumentException("Index MUST be less than \"size()\".");
+ } else if (index == used) {
+ ++used;
+ }
+ data[index] = value;
+ }
+
+ public double get(int index) {
+ if (index >= used)
+ throw new IndexOutOfBoundsException();
+ return data[index];
+ }
+
+ public double fastGet(int index) {
+ return data[index];
+ }
+
+ public int size() {
+ return used;
+ }
+
+ public boolean isEmpty() {
+ return used == 0;
+ }
+
+ public void clear() {
+ used = 0;
+ }
+
+ @Nonnull
+ public double[] toArray() {
+ return toArray(false);
+ }
+
+ @Nonnull
+ public double[] toArray(boolean close) {
+ final double[] newArray = new double[used];
+ System.arraycopy(data, 0, newArray, 0, used);
+ if (close) {
+ this.data = null;
+ }
+ return newArray;
+ }
+
+ public double[] array() {
+ return data;
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder buf = new StringBuilder();
+ buf.append('[');
+ for (int i = 0; i < used; i++) {
+ if (i != 0) {
+ buf.append(", ");
+ }
+ buf.append(data[i]);
+ }
+ buf.append(']');
+ return buf.toString();
+ }
+
+}
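DoubleArrayList grows by doubling the backing array, so a long run of add() calls costs amortized O(1) per element; toArray(true) also drops the backing array, which the code calls closing. A usage sketch:

    DoubleArrayList list = new DoubleArrayList();          // default capacity 12
    list.add(1.d).add(2.d).add(new double[] {3.d, 4.d});   // add() chains
    double tail = list.remove();         // 4.0: O(1) removal from the end
    double[] snapshot = list.toArray();  // copies only the used prefix
    // list.toArray(true) would additionally null out the internal array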
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/lists/FloatArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/lists/FloatArrayList.java b/core/src/main/java/hivemall/utils/collections/lists/FloatArrayList.java
new file mode 100644
index 0000000..54b214d
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/lists/FloatArrayList.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.lists;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+
+public final class FloatArrayList implements Serializable {
+ private static final long serialVersionUID = 8764828070342317585L;
+
+ public static final int DEFAULT_CAPACITY = 12;
+
+ /** backing array */
+ private float[] data;
+ private int used;
+
+ public FloatArrayList() {
+ this(DEFAULT_CAPACITY);
+ }
+
+ public FloatArrayList(int size) {
+ this.data = new float[size];
+ this.used = 0;
+ }
+
+ public FloatArrayList(float[] initValues) {
+ this.data = initValues;
+ this.used = initValues.length;
+ }
+
+ @Nonnull
+ public FloatArrayList add(float value) {
+ if (used >= data.length) {
+ expand(used + 1);
+ }
+ data[used++] = value;
+ return this;
+ }
+
+ @Nonnull
+ public FloatArrayList add(@Nonnull float[] values) {
+ final int needs = used + values.length;
+ if (needs >= data.length) {
+ expand(needs);
+ }
+ System.arraycopy(values, 0, data, used, values.length);
+ this.used = needs;
+ return this;
+ }
+
+ /**
+ * dynamic expansion.
+ */
+ private void expand(int max) {
+ while (data.length < max) {
+ final int len = data.length;
+ float[] newArray = new float[len * 2];
+ System.arraycopy(data, 0, newArray, 0, len);
+ this.data = newArray;
+ }
+ }
+
+ public float remove() {
+ return data[--used];
+ }
+
+ public float remove(int index) {
+ if (index >= used) {
+ throw new IndexOutOfBoundsException();
+ }
+
+ final float ret;
+ if (index == used) {
+ ret = data[index];
+ --used;
+ } else { // index < used
+ ret = data[index];
+ System.arraycopy(data, index + 1, data, index, used - index - 1);
+ --used;
+ }
+ return ret;
+ }
+
+ public void set(int index, float value) {
+ if (index > used) {
+ throw new IllegalArgumentException("Index MUST be less than \"size()\".");
+ } else if (index == used) {
+ ++used;
+ }
+ data[index] = value;
+ }
+
+ public float get(int index) {
+ if (index >= used)
+ throw new IndexOutOfBoundsException();
+ return data[index];
+ }
+
+ public float fastGet(int index) {
+ return data[index];
+ }
+
+ public int size() {
+ return used;
+ }
+
+ public boolean isEmpty() {
+ return used == 0;
+ }
+
+ public void clear() {
+ this.used = 0;
+ }
+
+ public float[] toArray() {
+ return toArray(false);
+ }
+
+ public float[] toArray(boolean close) {
+ final float[] newArray = new float[used];
+ System.arraycopy(data, 0, newArray, 0, used);
+ if (close) {
+ this.data = null;
+ }
+ return newArray;
+ }
+
+ public float[] array() {
+ return data;
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder buf = new StringBuilder();
+ buf.append('[');
+ for (int i = 0; i < used; i++) {
+ if (i != 0) {
+ buf.append(", ");
+ }
+ buf.append(data[i]);
+ }
+ buf.append(']');
+ return buf.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/lists/IntArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/lists/IntArrayList.java b/core/src/main/java/hivemall/utils/collections/lists/IntArrayList.java
new file mode 100644
index 0000000..ea17c5f
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/lists/IntArrayList.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.lists;
+
+import hivemall.utils.lang.ArrayUtils;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+
+public final class IntArrayList implements Serializable {
+ private static final long serialVersionUID = -2147675120406747488L;
+ public static final int DEFAULT_CAPACITY = 12;
+
+ /** backing array */
+ private int[] data;
+ private int used;
+
+ public IntArrayList() {
+ this(DEFAULT_CAPACITY);
+ }
+
+ public IntArrayList(int size) {
+ this.data = new int[size];
+ this.used = 0;
+ }
+
+ public IntArrayList(int[] initValues) {
+ this.data = initValues;
+ this.used = initValues.length;
+ }
+
+ @Nonnull
+ public IntArrayList add(final int value) {
+ if (used >= data.length) {
+ expand(used + 1);
+ }
+ data[used++] = value;
+ return this;
+ }
+
+ @Nonnull
+ public IntArrayList add(@Nonnull final int[] values) {
+ final int needs = used + values.length;
+ if (needs >= data.length) {
+ expand(needs);
+ }
+ System.arraycopy(values, 0, data, used, values.length);
+ this.used = needs;
+ return this;
+ }
+
+ /**
+ * dynamic expansion.
+ */
+ private void expand(final int max) {
+ while (data.length < max) {
+ final int len = data.length;
+ int[] newArray = new int[len * 2];
+ System.arraycopy(data, 0, newArray, 0, len);
+ this.data = newArray;
+ }
+ }
+
+ public int remove() {
+ return data[--used];
+ }
+
+ public int remove(final int index) {
+ if (index >= used) {
+ throw new IndexOutOfBoundsException();
+ }
+
+ final int ret;
+ if (index == used) {
+ ret = data[index];
+ --used;
+ } else { // index < used
+ ret = data[index];
+ System.arraycopy(data, index + 1, data, index, used - index - 1);
+ --used;
+ }
+ return ret;
+ }
+
+ public void set(final int index, final int value) {
+ if (index > used) {
+ throw new IllegalArgumentException("Index " + index + " MUST be less than size() "
+ + used);
+ } else if (index == used) {
+ ++used;
+ }
+ data[index] = value;
+ }
+
+ public int get(final int index) {
+ if (index >= used) {
+ throw new IndexOutOfBoundsException("Index " + index + " out of bounds " + used);
+ }
+ return data[index];
+ }
+
+ public int fastGet(final int index) {
+ return data[index];
+ }
+
+ /**
+ * @return -1 if not found.
+ */
+ public int indexOf(final int key) {
+ return ArrayUtils.indexOf(data, key, 0, used);
+ }
+
+ public boolean contains(final int key) {
+ return ArrayUtils.indexOf(data, key, 0, used) != -1;
+ }
+
+ public int size() {
+ return used;
+ }
+
+ public boolean isEmpty() {
+ return used == 0;
+ }
+
+ public void clear() {
+ used = 0;
+ }
+
+ @Nonnull
+ public int[] toArray() {
+ return toArray(false);
+ }
+
+ @Nonnull
+ public int[] toArray(boolean close) {
+ final int[] newArray = new int[used];
+ System.arraycopy(data, 0, newArray, 0, used);
+ if (close) {
+ this.data = null;
+ }
+ return newArray;
+ }
+
+ public int[] array() {
+ return data;
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder buf = new StringBuilder();
+ buf.append('[');
+ for (int i = 0; i < used; i++) {
+ if (i != 0) {
+ buf.append(", ");
+ }
+ buf.append(data[i]);
+ }
+ buf.append(']');
+ return buf.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/lists/LongArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/lists/LongArrayList.java b/core/src/main/java/hivemall/utils/collections/lists/LongArrayList.java
new file mode 100644
index 0000000..0786872
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/lists/LongArrayList.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.lists;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+
+public final class LongArrayList implements Serializable {
+ private static final long serialVersionUID = 6928415231676568533L;
+
+ public static final int DEFAULT_CAPACITY = 12;
+
+ /** backing array */
+ private long[] data;
+ private int used;
+
+ public LongArrayList() {
+ this(DEFAULT_CAPACITY);
+ }
+
+ public LongArrayList(int size) {
+ this.data = new long[size];
+ this.used = 0;
+ }
+
+ public LongArrayList(@Nonnull long[] initValues) {
+ this.data = initValues;
+ this.used = initValues.length;
+ }
+
+ @Nonnull
+ public LongArrayList add(final long value) {
+ if (used >= data.length) {
+ expand(used + 1);
+ }
+ data[used++] = value;
+ return this;
+ }
+
+ @Nonnull
+ public LongArrayList add(@Nonnull final long[] values) {
+ final int needs = used + values.length;
+ if (needs >= data.length) {
+ expand(needs);
+ }
+ System.arraycopy(values, 0, data, used, values.length);
+ this.used = needs;
+ return this;
+ }
+
+ /**
+ * dynamic expansion.
+ */
+ private void expand(final int max) {
+ while (data.length < max) {
+ final int len = data.length;
+ long[] newArray = new long[len * 2];
+ System.arraycopy(data, 0, newArray, 0, len);
+ this.data = newArray;
+ }
+ }
+
+ public long remove() {
+ return data[--used];
+ }
+
+ public long remove(final int index) {
+ if (index >= used) {
+ throw new IndexOutOfBoundsException();
+ }
+
+ final long ret;
+ if (index == used) {
+ ret = data[index];
+ --used;
+ } else { // index < used
+ ret = data[index];
+ System.arraycopy(data, index + 1, data, index, used - index - 1);
+ --used;
+ }
+ return ret;
+ }
+
+ public void set(final int index, final long value) {
+ if (index > used) {
+ throw new IllegalArgumentException("Index MUST be less than \"size()\".");
+ } else if (index == used) {
+ ++used;
+ }
+ data[index] = value;
+ }
+
+ public long get(final int index) {
+ if (index >= used) {
+ throw new IndexOutOfBoundsException();
+ }
+ return data[index];
+ }
+
+ public long fastGet(final int index) {
+ return data[index];
+ }
+
+ public int size() {
+ return used;
+ }
+
+ public boolean isEmpty() {
+ return used == 0;
+ }
+
+ public void clear() {
+ this.used = 0;
+ }
+
+ @Nonnull
+ public long[] toArray() {
+ return toArray(false);
+ }
+
+ @Nonnull
+ public long[] toArray(boolean close) {
+ final long[] newArray = new long[used];
+ System.arraycopy(data, 0, newArray, 0, used);
+ if (close) {
+ this.data = null;
+ }
+ return newArray;
+ }
+
+ @Nonnull
+ public long[] array() {
+ return data;
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder buf = new StringBuilder();
+ buf.append('[');
+ for (int i = 0; i < used; i++) {
+ if (i != 0) {
+ buf.append(", ");
+ }
+ buf.append(data[i]);
+ }
+ buf.append(']');
+ return buf.toString();
+ }
+}
[12/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75]
Support Sparse Vector Format as the input of RandomForest
Posted by my...@apache.org.
Close #51: [HIVEMALL-75] Support Sparse Vector Format as the input of RandomForest
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8dc3a024
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8dc3a024
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8dc3a024
Branch: refs/heads/master
Commit: 8dc3a024d9b2708f297a886a3256e7107bc276f9
Parents: 7956b5f
Author: myui <my...@apache.org>
Authored: Mon Apr 10 06:31:49 2017 +0900
Committer: myui <yu...@gmail.com>
Committed: Mon Apr 10 06:31:49 2017 +0900
----------------------------------------------------------------------
core/pom.xml | 10 +-
.../java/hivemall/annotations/Immutable.java | 34 +
.../main/java/hivemall/annotations/Mutable.java | 36 ++
.../KernelExpansionPassiveAggressiveUDTF.java | 4 +-
.../java/hivemall/common/ReservoirSampler.java | 100 ---
.../java/hivemall/fm/FFMPredictionModel.java | 4 +-
.../hivemall/fm/FFMStringFeatureMapModel.java | 2 +-
.../java/hivemall/fm/FMIntFeatureMapModel.java | 4 +-
.../hivemall/fm/FMStringFeatureMapModel.java | 2 +-
.../fm/FieldAwareFactorizationMachineModel.java | 4 +-
.../fm/FieldAwareFactorizationMachineUDTF.java | 4 +-
.../hivemall/ftvec/ranking/BprSamplingUDTF.java | 2 +-
.../ranking/PerEventPositiveOnlyFeedback.java | 2 +-
.../ftvec/ranking/PositiveOnlyFeedback.java | 6 +-
.../hivemall/math/matrix/AbstractMatrix.java | 105 +++
.../hivemall/math/matrix/ColumnMajorMatrix.java | 59 ++
.../main/java/hivemall/math/matrix/Matrix.java | 127 ++++
.../java/hivemall/math/matrix/MatrixUtils.java | 73 +++
.../hivemall/math/matrix/RowMajorMatrix.java | 69 ++
.../math/matrix/builders/CSCMatrixBuilder.java | 121 ++++
.../math/matrix/builders/CSRMatrixBuilder.java | 77 +++
.../builders/ColumnMajorDenseMatrixBuilder.java | 81 +++
.../math/matrix/builders/DoKMatrixBuilder.java | 56 ++
.../math/matrix/builders/MatrixBuilder.java | 91 +++
.../builders/RowMajorDenseMatrixBuilder.java | 79 +++
.../matrix/dense/ColumnMajorDenseMatrix2d.java | 300 +++++++++
.../matrix/dense/RowMajorDenseMatrix2d.java | 349 ++++++++++
.../math/matrix/ints/AbstractIntMatrix.java | 112 ++++
.../ints/ColumnMajorDenseIntMatrix2d.java | 172 +++++
.../math/matrix/ints/ColumnMajorIntMatrix.java | 39 ++
.../hivemall/math/matrix/ints/DoKIntMatrix.java | 277 ++++++++
.../hivemall/math/matrix/ints/IntMatrix.java | 104 +++
.../hivemall/math/matrix/sparse/CSCMatrix.java | 289 +++++++++
.../hivemall/math/matrix/sparse/CSRMatrix.java | 282 ++++++++
.../hivemall/math/matrix/sparse/DoKMatrix.java | 332 ++++++++++
.../hivemall/math/random/CommonsMathRandom.java | 63 ++
.../java/hivemall/math/random/JavaRandom.java | 61 ++
.../main/java/hivemall/math/random/PRNG.java | 39 ++
.../random/RandomNumberGeneratorFactory.java | 103 +++
.../java/hivemall/math/random/SmileRandom.java | 63 ++
.../hivemall/math/vector/AbstractVector.java | 44 ++
.../java/hivemall/math/vector/DenseVector.java | 90 +++
.../java/hivemall/math/vector/SparseVector.java | 76 +++
.../main/java/hivemall/math/vector/Vector.java | 46 ++
.../hivemall/math/vector/VectorProcedure.java | 33 +
.../java/hivemall/matrix/CSRMatrixBuilder.java | 83 ---
.../hivemall/matrix/DenseMatrixBuilder.java | 79 ---
core/src/main/java/hivemall/matrix/Matrix.java | 92 ---
.../java/hivemall/matrix/MatrixBuilder.java | 89 ---
.../java/hivemall/matrix/ReadOnlyCSRMatrix.java | 135 ----
.../hivemall/matrix/ReadOnlyDenseMatrix2d.java | 102 ---
.../main/java/hivemall/mf/FactorizedModel.java | 2 +-
.../hivemall/model/AbstractPredictionModel.java | 4 +-
.../main/java/hivemall/model/SparseModel.java | 2 +-
.../src/main/java/hivemall/smile/ModelType.java | 85 ---
.../smile/classification/DecisionTree.java | 495 +++++++-------
.../GradientTreeBoostingClassifierUDTF.java | 228 +++----
.../smile/classification/PredictionHandler.java | 27 +
.../RandomForestClassifierUDTF.java | 366 +++++++----
.../java/hivemall/smile/data/Attribute.java | 44 +-
.../regression/RandomForestRegressionUDTF.java | 211 +++---
.../smile/regression/RegressionTree.java | 377 ++++++-----
.../smile/tools/RandomForestEnsembleUDAF.java | 328 +++++++---
.../hivemall/smile/tools/TreePredictUDF.java | 407 +++++-------
.../hivemall/smile/utils/SmileExtUtils.java | 215 +++++--
.../main/java/hivemall/smile/vm/Operation.java | 52 --
.../java/hivemall/smile/vm/StackMachine.java | 300 ---------
.../hivemall/smile/vm/VMRuntimeException.java | 32 -
.../tools/mapred/DistributedCacheLookupUDF.java | 2 +-
.../hivemall/utils/collections/DoubleArray.java | 43 --
.../utils/collections/DoubleArray3D.java | 147 -----
.../utils/collections/DoubleArrayList.java | 168 -----
.../utils/collections/FixedIntArray.java | 87 ---
.../utils/collections/FloatArrayList.java | 152 -----
.../collections/Int2FloatOpenHashTable.java | 418 ------------
.../utils/collections/Int2IntOpenHashTable.java | 414 ------------
.../collections/Int2LongOpenHashTable.java | 500 --------------
.../hivemall/utils/collections/IntArray.java | 43 --
.../utils/collections/IntArrayList.java | 183 ------
.../utils/collections/IntOpenHashMap.java | 467 --------------
.../utils/collections/IntOpenHashTable.java | 338 ----------
.../java/hivemall/utils/collections/LRUMap.java | 41 --
.../hivemall/utils/collections/OpenHashMap.java | 350 ----------
.../utils/collections/OpenHashTable.java | 412 ------------
.../utils/collections/SparseDoubleArray.java | 213 ------
.../utils/collections/SparseIntArray.java | 210 ------
.../collections/arrays/DenseDoubleArray.java | 92 +++
.../utils/collections/arrays/DenseIntArray.java | 92 +++
.../utils/collections/arrays/DoubleArray.java | 45 ++
.../utils/collections/arrays/DoubleArray3D.java | 147 +++++
.../utils/collections/arrays/FloatArray.java | 45 ++
.../utils/collections/arrays/IntArray.java | 45 ++
.../collections/arrays/SparseDoubleArray.java | 223 +++++++
.../collections/arrays/SparseFloatArray.java | 210 ++++++
.../collections/arrays/SparseIntArray.java | 211 ++++++
.../collections/lists/DoubleArrayList.java | 164 +++++
.../utils/collections/lists/FloatArrayList.java | 162 +++++
.../utils/collections/lists/IntArrayList.java | 179 ++++++
.../utils/collections/lists/LongArrayList.java | 166 +++++
.../maps/Int2FloatOpenHashTable.java | 418 ++++++++++++
.../collections/maps/Int2IntOpenHashTable.java | 414 ++++++++++++
.../collections/maps/Int2LongOpenHashTable.java | 500 ++++++++++++++
.../utils/collections/maps/IntOpenHashMap.java | 467 ++++++++++++++
.../collections/maps/IntOpenHashTable.java | 404 ++++++++++++
.../hivemall/utils/collections/maps/LRUMap.java | 41 ++
.../maps/Long2DoubleOpenHashTable.java | 445 +++++++++++++
.../maps/Long2FloatOpenHashTable.java | 429 ++++++++++++
.../collections/maps/Long2IntOpenHashTable.java | 473 ++++++++++++++
.../utils/collections/maps/OpenHashMap.java | 351 ++++++++++
.../utils/collections/maps/OpenHashTable.java | 413 ++++++++++++
.../utils/collections/sets/IntArraySet.java | 88 +++
.../hivemall/utils/collections/sets/IntSet.java | 38 ++
.../java/hivemall/utils/hadoop/HiveUtils.java | 69 +-
.../java/hivemall/utils/lang/ArrayUtils.java | 407 +++++++++++-
.../java/hivemall/utils/lang/Primitives.java | 28 +
.../java/hivemall/utils/math/MathUtils.java | 22 +-
.../java/hivemall/utils/math/MatrixUtils.java | 2 +-
.../utils/sampling/IntReservoirSampler.java | 99 +++
.../utils/sampling/ReservoirSampler.java | 100 +++
.../java/hivemall/utils/stream/IntIterator.java | 27 +
.../java/hivemall/utils/stream/IntStream.java | 28 +
.../java/hivemall/utils/stream/StreamUtils.java | 180 ++++++
.../hivemall/fm/FFMPredictionModelTest.java | 2 +-
.../hivemall/math/matrix/MatrixBuilderTest.java | 644 +++++++++++++++++++
.../math/matrix/ints/IntMatrixTest.java | 361 +++++++++++
.../java/hivemall/matrix/MatrixBuilderTest.java | 329 ----------
.../smile/classification/DecisionTreeTest.java | 249 +++----
.../RandomForestClassifierUDTFTest.java | 286 +++++++-
.../smile/regression/RegressionTreeTest.java | 81 ++-
.../smile/tools/TreePredictUDFTest.java | 61 +-
.../hivemall/smile/vm/StackMachineTest.java | 88 ---
.../utils/collections/DoubleArray3DTest.java | 147 -----
.../utils/collections/DoubleArrayTest.java | 60 --
.../collections/Int2FloatOpenHashMapTest.java | 96 ---
.../collections/Int2LongOpenHashMapTest.java | 105 ---
.../utils/collections/IntArrayTest.java | 76 ---
.../utils/collections/IntOpenHashMapTest.java | 73 ---
.../utils/collections/IntOpenHashTableTest.java | 50 --
.../utils/collections/OpenHashMapTest.java | 91 ---
.../utils/collections/OpenHashTableTest.java | 138 ----
.../utils/collections/SparseIntArrayTest.java | 61 --
.../collections/arrays/DoubleArray3DTest.java | 149 +++++
.../collections/arrays/DoubleArrayTest.java | 62 ++
.../utils/collections/arrays/IntArrayTest.java | 79 +++
.../collections/arrays/SparseIntArrayTest.java | 64 ++
.../collections/lists/LongArrayListTest.java | 43 ++
.../maps/Int2FloatOpenHashMapTest.java | 98 +++
.../maps/Int2LongOpenHashMapTest.java | 106 +++
.../collections/maps/IntOpenHashMapTest.java | 75 +++
.../collections/maps/IntOpenHashTableTest.java | 52 ++
.../maps/Long2IntOpenHashMapTest.java | 115 ++++
.../utils/collections/maps/OpenHashMapTest.java | 93 +++
.../collections/maps/OpenHashTableTest.java | 140 ++++
.../hivemall/utils/stream/StreamUtilsTest.java | 86 +++
.../hivemall/classifier/news20-multiclass.gz | Bin 0 -> 396138 bytes
.../apache/spark/sql/hive/GroupedDataEx.scala | 2 +-
.../spark/sql/hive/HivemallOpsSuite.scala | 21 +-
.../spark/sql/hive/HivemallGroupedDataset.scala | 2 +-
.../spark/sql/hive/HivemallOpsSuite.scala | 10 +-
.../spark/sql/hive/HivemallGroupedDataset.scala | 2 +-
.../spark/sql/hive/HivemallOpsSuite.scala | 10 +-
161 files changed, 15287 insertions(+), 8113 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/pom.xml
----------------------------------------------------------------------
diff --git a/core/pom.xml b/core/pom.xml
index bf931ac..d7655f4 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -109,7 +109,7 @@
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
- <version>1.0.3</version>
+ <version>1.0.4</version>
<scope>compile</scope>
<exclusions>
<exclusion>
@@ -130,6 +130,12 @@
<version>3.6.1</version>
<scope>compile</scope>
</dependency>
+ <dependency>
+ <groupId>org.roaringbitmap</groupId>
+ <artifactId>RoaringBitmap</artifactId>
+ <version>[0.6,)</version>
+ <scope>compile</scope>
+ </dependency>
<!-- test scope -->
<dependency>
@@ -198,6 +204,8 @@
<include>com.github.haifengl:smile-math</include>
<include>com.github.haifengl:smile-data</include>
<include>org.tukaani:xz</include>
+ <include>org.apache.commons:commons-math3</include>
+ <include>org.roaringbitmap:RoaringBitmap</include>
</includes>
</artifactSet>
<transformers>
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/annotations/Immutable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/annotations/Immutable.java b/core/src/main/java/hivemall/annotations/Immutable.java
new file mode 100644
index 0000000..941fa5d
--- /dev/null
+++ b/core/src/main/java/hivemall/annotations/Immutable.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.annotations;
+
+import java.lang.annotation.Documented;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * Indicates that the class to which this annotation is applied is immutable.
+ */
+@Documented
+@Target(ElementType.TYPE)
+@Retention(RetentionPolicy.CLASS)
+public @interface Immutable {
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/annotations/Mutable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/annotations/Mutable.java b/core/src/main/java/hivemall/annotations/Mutable.java
new file mode 100644
index 0000000..bdac5d9
--- /dev/null
+++ b/core/src/main/java/hivemall/annotations/Mutable.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.annotations;
+
+import java.lang.annotation.Documented;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * Indicates that the class to which this annotation is applied is mutable.
+ *
+ * @see javax.annotation.concurrent.Immutable
+ */
+@Documented
+@Target(ElementType.TYPE)
+@Retention(RetentionPolicy.CLASS)
+public @interface Mutable {
+}
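A trivial sketch of how the two marker annotations above would be applied
(the Point and Counter classes are hypothetical):

    import hivemall.annotations.Immutable;
    import hivemall.annotations.Mutable;

    @Immutable
    final class Point {
        final int x, y;

        Point(int x, int y) {
            this.x = x;
            this.y = y;
        }
    }

    @Mutable
    class Counter {
        private int count;

        void increment() {
            count++;
        }
    }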
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java
index 7cb7a58..8534231 100644
--- a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java
+++ b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java
@@ -24,8 +24,8 @@ import hivemall.common.LossFunctions;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
-import hivemall.utils.collections.Int2FloatOpenHashTable;
-import hivemall.utils.collections.Int2FloatOpenHashTable.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable.IMapIterator;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.lang.Preconditions;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/common/ReservoirSampler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/common/ReservoirSampler.java b/core/src/main/java/hivemall/common/ReservoirSampler.java
deleted file mode 100644
index 8846ac1..0000000
--- a/core/src/main/java/hivemall/common/ReservoirSampler.java
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.common;
-
-import java.util.Arrays;
-import java.util.List;
-import java.util.Random;
-
-/**
- * Vitter's reservoir sampling implementation that randomly chooses k items from a list containing n
- * items.
- *
- * @link http://en.wikipedia.org/wiki/Reservoir_sampling
- * @link http://portal.acm.org/citation.cfm?id=3165
- */
-public final class ReservoirSampler<T> {
-
- private final T[] samples;
- private final int numSamples;
- private int position;
-
- private final Random rand;
-
- @SuppressWarnings("unchecked")
- public ReservoirSampler(int sampleSize) {
- if (sampleSize <= 0) {
- throw new IllegalArgumentException("sampleSize must be greater than 1: " + sampleSize);
- }
- this.samples = (T[]) new Object[sampleSize];
- this.numSamples = sampleSize;
- this.position = 0;
- this.rand = new Random();
- }
-
- @SuppressWarnings("unchecked")
- public ReservoirSampler(int sampleSize, long seed) {
- this.samples = (T[]) new Object[sampleSize];
- this.numSamples = sampleSize;
- this.position = 0;
- this.rand = new Random(seed);
- }
-
- public ReservoirSampler(T[] samples) {
- this.samples = samples;
- this.numSamples = samples.length;
- this.position = 0;
- this.rand = new Random();
- }
-
- public ReservoirSampler(T[] samples, long seed) {
- this.samples = samples;
- this.numSamples = samples.length;
- this.position = 0;
- this.rand = new Random(seed);
- }
-
- public T[] getSample() {
- return samples;
- }
-
- public List<T> getSamplesAsList() {
- return Arrays.asList(samples);
- }
-
- public void add(T item) {
- if (item == null) {
- return;
- }
- if (position < numSamples) {// reservoir not yet full, just append
- samples[position] = item;
- } else {// find a item to replace
- int replaceIndex = rand.nextInt(position + 1);
- if (replaceIndex < numSamples) {
- samples[replaceIndex] = item;
- }
- }
- position++;
- }
-
- public void clear() {
- Arrays.fill(samples, null);
- this.position = 0;
- }
-}
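Per the diffstat above, the sampler moves to
hivemall.utils.sampling.ReservoirSampler rather than disappearing. For
reference, a self-contained sketch of the same Vitter scheme (Algorithm R),
showing why each of the first i+1 streamed items survives with probability
k/(i+1); the class name ReservoirSketch is illustrative only:

    import java.util.Random;

    public class ReservoirSketch {
        // Keep k items from a stream of unknown length.
        public static int[] sample(int[] stream, int k, long seed) {
            Random rand = new Random(seed);
            int[] reservoir = new int[k];
            for (int i = 0; i < stream.length; i++) {
                if (i < k) {
                    reservoir[i] = stream[i]; // fill the reservoir first
                } else {
                    int j = rand.nextInt(i + 1); // uniform in [0, i]
                    if (j < k) {
                        reservoir[j] = stream[i]; // replace with probability k/(i+1)
                    }
                }
            }
            return reservoir;
        }
    }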
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FFMPredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictionModel.java b/core/src/main/java/hivemall/fm/FFMPredictionModel.java
index 6969d05..befbec9 100644
--- a/core/src/main/java/hivemall/fm/FFMPredictionModel.java
+++ b/core/src/main/java/hivemall/fm/FFMPredictionModel.java
@@ -21,8 +21,8 @@ package hivemall.fm;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.codec.VariableByteCodec;
import hivemall.utils.codec.ZigZagLEB128Codec;
-import hivemall.utils.collections.Int2LongOpenHashTable;
-import hivemall.utils.collections.IntOpenHashTable;
+import hivemall.utils.collections.maps.Int2LongOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable;
import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.ArrayUtils;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
index 4009326..4f445fa 100644
--- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
@@ -22,7 +22,7 @@ import hivemall.fm.Entry.AdaGradEntry;
import hivemall.fm.Entry.FTRLEntry;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.collections.Int2LongOpenHashTable;
+import hivemall.utils.collections.maps.Int2LongOpenHashTable;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
index d2a5ed6..19ac287 100644
--- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
@@ -18,8 +18,8 @@
*/
package hivemall.fm;
-import hivemall.utils.collections.Int2FloatOpenHashTable;
-import hivemall.utils.collections.IntOpenHashMap;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashMap;
import java.util.Arrays;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
index 10ffaae..cd99046 100644
--- a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
@@ -19,7 +19,7 @@
package hivemall.fm;
import hivemall.utils.collections.IMapIterator;
-import hivemall.utils.collections.OpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
index e63797c..76bead8 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
@@ -19,8 +19,8 @@
package hivemall.fm;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
-import hivemall.utils.collections.DoubleArray3D;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.arrays.DoubleArray3D;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.NumberUtils;
import java.util.Arrays;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
index fe27269..67dbf87 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
@@ -19,8 +19,8 @@
package hivemall.fm;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
-import hivemall.utils.collections.DoubleArray3D;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.arrays.DoubleArray3D;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.Text3;
import hivemall.utils.lang.NumberUtils;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java b/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java
index 8e84bd8..ab418ed 100644
--- a/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java
+++ b/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java
@@ -19,7 +19,7 @@
package hivemall.ftvec.ranking;
import hivemall.UDTFWithOptions;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.BitUtils;
import hivemall.utils.lang.Primitives;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java b/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java
index b5afb99..94bb697 100644
--- a/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java
+++ b/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java
@@ -18,7 +18,7 @@
*/
package hivemall.ftvec.ranking;
-import hivemall.utils.collections.IntArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.ArrayUtils;
import java.util.Random;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java
index 908a0b7..5e9f797 100644
--- a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java
+++ b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java
@@ -18,9 +18,9 @@
*/
package hivemall.ftvec.ranking;
-import hivemall.utils.collections.IntArrayList;
-import hivemall.utils.collections.IntOpenHashMap;
-import hivemall.utils.collections.IntOpenHashMap.IMapIterator;
+import hivemall.utils.collections.lists.IntArrayList;
+import hivemall.utils.collections.maps.IntOpenHashMap;
+import hivemall.utils.collections.maps.IntOpenHashMap.IMapIterator;
import java.util.BitSet;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
new file mode 100644
index 0000000..2ee27f7
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+import hivemall.math.vector.SparseVector;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public abstract class AbstractMatrix implements Matrix {
+
+ public AbstractMatrix() {}
+
+ @Override
+ public double[] row() {
+ int cols = numColumns();
+ return new double[cols];
+ }
+
+ @Override
+ public Vector rowVector() {
+ return new SparseVector();
+ }
+
+ @Override
+ public final double get(@Nonnegative final int row, @Nonnegative final int col) {
+ return get(row, col, 0.d);
+ }
+
+ protected static final void checkRowIndex(final int row, final int numRows) {
+ if (row < 0 || row >= numRows) {
+ throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
+ }
+ }
+
+ protected static final void checkColIndex(final int col, final int numColumns) {
+ if (col < 0 || col >= numColumns) {
+ throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
+ }
+ }
+
+ protected static final void checkIndex(final int index) {
+ if (index < 0) {
+ throw new IndexOutOfBoundsException("Invalid index " + index);
+ }
+ }
+
+ protected static final void checkIndex(final int row, final int col) {
+ if (row < 0) {
+ throw new IndexOutOfBoundsException("Invalid row index " + row);
+ }
+ if (col < 0) {
+ throw new IndexOutOfBoundsException("Invalid col index " + col);
+ }
+ }
+
+ protected static final void checkIndex(final int row, final int col, final int numRows,
+ final int numColumns) {
+ if (row < 0 || row >= numRows) {
+ throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
+ }
+ if (col < 0 || col >= numColumns) {
+ throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
+ }
+ }
+
+ @Override
+ public void eachInRow(final int row, @Nonnull final VectorProcedure procedure) {
+ eachInRow(row, procedure, true);
+ }
+
+ @Override
+ public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+ eachInColumn(col, procedure, true);
+ }
+
+ @Override
+ public void eachNonNullInRow(final int row, @Nonnull final VectorProcedure procedure) {
+ eachInRow(row, procedure, false);
+ }
+
+ @Override
+ public void eachNonNullInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+ eachInColumn(col, procedure, false);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ColumnMajorMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ColumnMajorMatrix.java b/core/src/main/java/hivemall/math/matrix/ColumnMajorMatrix.java
new file mode 100644
index 0000000..51c80aa
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ColumnMajorMatrix.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+import hivemall.math.vector.VectorProcedure;
+
+public abstract class ColumnMajorMatrix extends AbstractMatrix {
+
+ public ColumnMajorMatrix() {
+ super();
+ }
+
+ @Override
+ public boolean isRowMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean isColumnMajorMatrix() {
+ return true;
+ }
+
+ @Override
+ public void eachInRow(int row, VectorProcedure procedure, boolean nullOutput) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void eachColumnIndexInRow(int row, VectorProcedure procedure) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void eachNonZeroInRow(int row, VectorProcedure procedure) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ColumnMajorMatrix toColumnMajorMatrix() {
+ return this;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/Matrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/Matrix.java b/core/src/main/java/hivemall/math/matrix/Matrix.java
new file mode 100644
index 0000000..8a4782a
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/Matrix.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.NotThreadSafe;
+
+/**
+ * Double matrix.
+ */
+@NotThreadSafe
+public interface Matrix {
+
+ public boolean isSparse();
+
+ public boolean isRowMajorMatrix();
+
+ public boolean isColumnMajorMatrix();
+
+ public boolean readOnly();
+
+ public boolean swappable();
+
+ /** Returns the number of non-zero entries. */
+ public int nnz();
+
+ @Nonnegative
+ public int numRows();
+
+ @Nonnegative
+ public int numColumns();
+
+ @Nonnegative
+ public int numColumns(@Nonnegative int row);
+
+ @Nonnull
+ public double[] row();
+
+ @Nonnull
+ public Vector rowVector();
+
+ @Nonnull
+ public double[] getRow(@Nonnegative int index);
+
+ /**
+ * @return returns dst
+ */
+ @Nonnull
+ public double[] getRow(@Nonnegative int index, @Nonnull double[] dst);
+
+ public void getRow(@Nonnegative int index, @Nonnull Vector row);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ */
+ public double get(@Nonnegative int row, @Nonnegative int col);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ */
+ public double get(@Nonnegative int row, @Nonnegative int col, double defaultValue);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ * @throws UnsupportedOperationException
+ */
+ public void set(@Nonnegative int row, @Nonnegative int col, double value);
+
+ /**
+ * @throws IndexOutOfBoundsException
+ * @throws UnsupportedOperationException
+ */
+ public double getAndSet(@Nonnegative int row, @Nonnegative int col, double value);
+
+ public void swap(@Nonnegative int row1, @Nonnegative int row2);
+
+ public void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+ public void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure,
+ boolean nullOutput);
+
+ public void eachNonNullInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+ public void eachNonZeroInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+ public void eachColumnIndexInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure);
+
+ public void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+ public void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure,
+ boolean nullOutput);
+
+ public void eachNonNullInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+ public void eachNonZeroInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure);
+
+ @Nonnull
+ public RowMajorMatrix toRowMajorMatrix();
+
+ @Nonnull
+ public ColumnMajorMatrix toColumnMajorMatrix();
+
+ @Nonnull
+ public MatrixBuilder builder();
+
+}
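To make the callback-style traversal of this interface concrete, a sketch
that sums the stored non-zeros of one row, assuming the CSRMatrixBuilder and
VectorProcedure added elsewhere in this commit (RowSumExample is a
hypothetical name):

    import hivemall.math.matrix.Matrix;
    import hivemall.math.matrix.builders.CSRMatrixBuilder;
    import hivemall.math.vector.VectorProcedure;

    public class RowSumExample {
        public static void main(String[] args) {
            Matrix m = new CSRMatrixBuilder(8)
                .nextColumn(0, 1.0).nextColumn(2, 3.0).nextRow() // row 0
                .nextColumn(1, 5.0).nextRow()                    // row 1
                .buildMatrix();
            final double[] sum = new double[1];
            m.eachNonZeroInRow(0, new VectorProcedure() {
                @Override
                public void apply(int col, double value) {
                    sum[0] += value; // visits only stored non-zeros
                }
            });
            System.out.println(sum[0]); // 4.0
        }
    }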
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/MatrixUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/MatrixUtils.java b/core/src/main/java/hivemall/math/matrix/MatrixUtils.java
new file mode 100644
index 0000000..90ce78f
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/MatrixUtils.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+import hivemall.math.matrix.builders.MatrixBuilder;
+import hivemall.math.matrix.ints.IntMatrix;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class MatrixUtils {
+
+ private MatrixUtils() {}
+
+ @Nonnull
+ public static Matrix shuffle(@Nonnull final Matrix m, @Nonnull final int[] indices) {
+ Preconditions.checkArgument(m.numRows() <= indices.length, "m.numRows() `" + m.numRows()
+ + "` MUST be less than or equal to |indices| `" + indices.length + "`");
+
+ final MatrixBuilder builder = m.builder();
+ final VectorProcedure proc = new VectorProcedure() {
+ public void apply(int col, double value) {
+ builder.nextColumn(col, value);
+ }
+ };
+ for (int i = 0; i < indices.length; i++) {
+ int idx = indices[i];
+ m.eachNonNullInRow(idx, proc);
+ builder.nextRow();
+ }
+ return builder.buildMatrix();
+ }
+
+ /**
+ * Returns the column index of the maximum value in the given row of the matrix.
+ *
+ * @return -1 if there are no columns
+ */
+ public static int whichMax(@Nonnull final IntMatrix matrix, @Nonnegative final int row) {
+ final MutableInt m = new MutableInt(Integer.MIN_VALUE);
+ final MutableInt which = new MutableInt(-1);
+ matrix.eachInRow(row, new VectorProcedure() {
+ @Override
+ public void apply(int i, int value) {
+ if (value > m.getValue()) {
+ m.setValue(value);
+ which.setValue(i);
+ }
+ }
+ }, false);
+ return which.getValue();
+ }
+
+}
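A usage sketch for shuffle(), rebuilding a small CSR matrix with its rows
visited in a caller-supplied order (ShuffleExample is hypothetical; the
precondition requires numRows() <= indices.length):

    import hivemall.math.matrix.Matrix;
    import hivemall.math.matrix.MatrixUtils;
    import hivemall.math.matrix.builders.CSRMatrixBuilder;

    public class ShuffleExample {
        public static void main(String[] args) {
            Matrix m = new CSRMatrixBuilder(8)
                .nextColumn(0, 1.0).nextRow()
                .nextColumn(0, 2.0).nextRow()
                .nextColumn(0, 3.0).nextRow()
                .buildMatrix();
            Matrix shuffled = MatrixUtils.shuffle(m, new int[] {2, 0, 1});
            System.out.println(shuffled.get(0, 0)); // 3.0 (old row 2 comes first)
        }
    }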
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/RowMajorMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/RowMajorMatrix.java b/core/src/main/java/hivemall/math/matrix/RowMajorMatrix.java
new file mode 100644
index 0000000..2c611bd
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/RowMajorMatrix.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix;
+
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public abstract class RowMajorMatrix extends AbstractMatrix {
+
+ public RowMajorMatrix() {
+ super();
+ }
+
+ @Override
+ public boolean isRowMajorMatrix() {
+ return true;
+ }
+
+ @Override
+ public boolean isColumnMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public void getRow(@Nonnegative final int index, @Nonnull final Vector row) {
+ row.clear();
+ eachNonNullInRow(index, new VectorProcedure() {
+ @Override
+ public void apply(final int i, final double value) {
+ row.set(i, value);
+ }
+ });
+ }
+
+ @Override
+ public void eachInColumn(int col, VectorProcedure procedure, boolean nullOutput) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void eachNonZeroInColumn(int col, VectorProcedure procedure) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public RowMajorMatrix toRowMajorMatrix() {
+ return this;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
new file mode 100644
index 0000000..df2bff7
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.builders;
+
+import hivemall.math.matrix.sparse.CSCMatrix;
+import hivemall.utils.collections.lists.DoubleArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class CSCMatrixBuilder extends MatrixBuilder {
+
+ @Nonnull
+ private final IntArrayList rows;
+ @Nonnull
+ private final IntArrayList cols;
+ @Nonnull
+ private final DoubleArrayList values;
+
+ private int row;
+ private int maxNumColumns;
+
+ public CSCMatrixBuilder(int initSize) {
+ super();
+ this.rows = new IntArrayList(initSize);
+ this.cols = new IntArrayList(initSize);
+ this.values = new DoubleArrayList(initSize);
+ this.row = 0;
+ this.maxNumColumns = 0;
+ }
+
+ @Override
+ public CSCMatrixBuilder nextRow() {
+ row++;
+ return this;
+ }
+
+ @Override
+ public CSCMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+ rows.add(row);
+ cols.add(col);
+ values.add(value); // store as-is; narrowing to float would silently lose precision
+ this.maxNumColumns = Math.max(col + 1, maxNumColumns);
+ return this;
+ }
+
+ @Override
+ public CSCMatrix buildMatrix() {
+ if (rows.isEmpty() || cols.isEmpty()) {
+ throw new IllegalStateException("No element in the matrix");
+ }
+
+ final int[] columnIndices = cols.toArray(true);
+ final int[] rowIndices = rows.toArray(true);
+ final double[] valuesArray = values.toArray(true);
+
+ // convert the (row, col, value) triplets to column-major order
+ final int nnz = valuesArray.length;
+ SortObj[] sortObjs = new SortObj[nnz];
+ for (int i = 0; i < nnz; i++) {
+ sortObjs[i] = new SortObj(columnIndices[i], rowIndices[i], valuesArray[i]);
+ }
+ Arrays.sort(sortObjs);
+ for (int i = 0; i < nnz; i++) {
+ columnIndices[i] = sortObjs[i].columnIndex;
+ rowIndices[i] = sortObjs[i].rowIndex;
+ valuesArray[i] = sortObjs[i].value;
+ }
+ sortObjs = null;
+
+ // build column pointers by counting entries per column and taking a
+ // prefix sum, which stays correct even when some columns have no entries
+ final int[] columnPointers = new int[maxNumColumns + 1];
+ for (int i = 0; i < nnz; i++) {
+ columnPointers[columnIndices[i] + 1]++;
+ }
+ for (int c = 0; c < maxNumColumns; c++) {
+ columnPointers[c + 1] += columnPointers[c];
+ }
+
+ return new CSCMatrix(columnPointers, rowIndices, valuesArray, row, maxNumColumns);
+ }
+
+ private static final class SortObj implements Comparable<SortObj> {
+ final int columnIndex;
+ final int rowIndex;
+ final double value;
+
+ SortObj(int columnIndex, int rowIndex, double value) {
+ this.columnIndex = columnIndex;
+ this.rowIndex = rowIndex;
+ this.value = value;
+ }
+
+ @Override
+ public int compareTo(SortObj o) {
+ final int cmp = Integer.compare(columnIndex, o.columnIndex);
+ if (cmp != 0) {
+ return cmp;
+ }
+ return Integer.compare(rowIndex, o.rowIndex); // keep rows ordered within a column
+ }
+ }
+
+}
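For a concrete picture of the layout this builder emits, take the 2x3 matrix
[[1, 0, 2], [0, 3, 0]]. Sorted column-major, its non-zeros give values
{1, 3, 2}, row indices {0, 1, 0}, and column pointers {0, 1, 2, 3}, where
column c spans [ptr[c], ptr[c+1]). A sketch (CSCExample is a hypothetical
name):

    import hivemall.math.matrix.builders.CSCMatrixBuilder;
    import hivemall.math.matrix.sparse.CSCMatrix;

    public class CSCExample {
        public static void main(String[] args) {
            CSCMatrix m = new CSCMatrixBuilder(8)
                .nextColumn(0, 1.0).nextColumn(2, 2.0).nextRow() // row 0: [1, 0, 2]
                .nextColumn(1, 3.0).nextRow()                    // row 1: [0, 3, 0]
                .buildMatrix();
            System.out.println(m.get(0, 2)); // 2.0
        }
    }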
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
new file mode 100644
index 0000000..2467056
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.builders;
+
+import hivemall.math.matrix.sparse.CSRMatrix;
+import hivemall.utils.collections.lists.DoubleArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Compressed Sparse Row Matrix builder.
+ */
+public final class CSRMatrixBuilder extends MatrixBuilder {
+
+ @Nonnull
+ private final IntArrayList rowPointers;
+ @Nonnull
+ private final IntArrayList columnIndices;
+ @Nonnull
+ private final DoubleArrayList values;
+
+ private int maxNumColumns;
+
+ public CSRMatrixBuilder(@Nonnegative int initSize) {
+ super();
+ this.rowPointers = new IntArrayList(initSize + 1);
+ rowPointers.add(0);
+ this.columnIndices = new IntArrayList(initSize);
+ this.values = new DoubleArrayList(initSize);
+ this.maxNumColumns = 0;
+ }
+
+ @Override
+ public CSRMatrixBuilder nextRow() {
+ int ptr = values.size();
+ rowPointers.add(ptr);
+ return this;
+ }
+
+ @Override
+ public CSRMatrixBuilder nextColumn(@Nonnegative int col, double value) {
+ if (value == 0.d) {
+ return this;
+ }
+
+ columnIndices.add(col);
+ values.add(value);
+ this.maxNumColumns = Math.max(col + 1, maxNumColumns);
+ return this;
+ }
+
+ @Override
+ public CSRMatrix buildMatrix() {
+ CSRMatrix matrix = new CSRMatrix(rowPointers.toArray(true), columnIndices.toArray(true),
+ values.toArray(true), maxNumColumns);
+ return matrix;
+ }
+
+}
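The mirror-image CSR layout for the same 2x3 matrix [[1, 0, 2], [0, 3, 0]]:
values {1, 2, 3}, column indices {0, 2, 1}, row pointers {0, 2, 3}, where
row r spans [ptr[r], ptr[r+1]). A sketch (CSRExample is hypothetical):

    import hivemall.math.matrix.builders.CSRMatrixBuilder;
    import hivemall.math.matrix.sparse.CSRMatrix;

    public class CSRExample {
        public static void main(String[] args) {
            CSRMatrix m = new CSRMatrixBuilder(8)
                .nextColumn(0, 1.0).nextColumn(2, 2.0).nextRow()
                .nextColumn(1, 3.0).nextRow()
                .buildMatrix();
            System.out.println(m.get(1, 1)); // 3.0
        }
    }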
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
new file mode 100644
index 0000000..9cae1c7
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.builders;
+
+import hivemall.math.matrix.dense.ColumnMajorDenseMatrix2d;
+import hivemall.utils.collections.arrays.SparseDoubleArray;
+import hivemall.utils.collections.maps.IntOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class ColumnMajorDenseMatrixBuilder extends MatrixBuilder {
+
+ @Nonnull
+ private final IntOpenHashTable<SparseDoubleArray> col2rows;
+ private int row;
+ private int maxNumColumns;
+ private int nnz;
+
+ public ColumnMajorDenseMatrixBuilder(int initSize) {
+ this.col2rows = new IntOpenHashTable<SparseDoubleArray>(initSize);
+ this.row = 0;
+ this.maxNumColumns = 0;
+ this.nnz = 0;
+ }
+
+ @Override
+ public ColumnMajorDenseMatrixBuilder nextRow() {
+ row++;
+ return this;
+ }
+
+ @Override
+ public ColumnMajorDenseMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+ if (value == 0.d) {
+ return this;
+ }
+
+ SparseDoubleArray rows = col2rows.get(col);
+ if (rows == null) {
+ rows = new SparseDoubleArray(4);
+ col2rows.put(col, rows);
+ }
+ rows.put(row, value);
+ this.maxNumColumns = Math.max(col + 1, maxNumColumns);
+ nnz++;
+ return this;
+ }
+
+ @Override
+ public ColumnMajorDenseMatrix2d buildMatrix() {
+ final double[][] data = new double[maxNumColumns][];
+
+ final IMapIterator<SparseDoubleArray> itor = col2rows.entries();
+ while (itor.next() != -1) {
+ int col = itor.getKey();
+ SparseDoubleArray rows = itor.getValue();
+ data[col] = rows.toArray();
+ }
+
+ return new ColumnMajorDenseMatrix2d(data, row, nnz);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
new file mode 100644
index 0000000..556a8d8
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.builders;
+
+import hivemall.math.matrix.sparse.DoKMatrix;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class DoKMatrixBuilder extends MatrixBuilder {
+
+ @Nonnull
+ private final DoKMatrix matrix;
+
+ private int row;
+
+ public DoKMatrixBuilder(@Nonnegative int initSize) {
+ super();
+ this.row = 0;
+ this.matrix = new DoKMatrix(initSize);
+ }
+
+ @Override
+ public DoKMatrixBuilder nextRow() {
+ row++;
+ return this;
+ }
+
+ @Override
+ public DoKMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+ matrix.set(row, col, value);
+ return this;
+ }
+
+ @Override
+ public DoKMatrix buildMatrix() {
+ return matrix;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
new file mode 100644
index 0000000..66bd1e2
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.builders;
+
+import hivemall.math.matrix.Matrix;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public abstract class MatrixBuilder {
+
+ public MatrixBuilder() {}
+
+ public void nextRow(@Nonnull final double[] row) {
+ for (int col = 0; col < row.length; col++) {
+ nextColumn(col, row[col]);
+ }
+ nextRow();
+ }
+
+ public void nextRow(@Nonnull final String[] row) {
+ for (String col : row) {
+ if (col == null) {
+ continue;
+ }
+ nextColumn(col);
+ }
+ nextRow();
+ }
+
+ @Nonnull
+ public abstract MatrixBuilder nextRow();
+
+ @Nonnull
+ public abstract MatrixBuilder nextColumn(@Nonnegative int col, double value);
+
+ /**
+ * @throws IllegalArgumentException
+ * @throws NumberFormatException
+ */
+ @Nonnull
+ public MatrixBuilder nextColumn(@Nonnull final String col) {
+ final int pos = col.indexOf(':');
+ if (pos == 0) {
+ throw new IllegalArgumentException("Invalid feature value representation: " + col);
+ }
+
+ final String feature;
+ final double value;
+ if (pos > 0) {
+ feature = col.substring(0, pos);
+ String s2 = col.substring(pos + 1);
+ value = Double.parseDouble(s2);
+ } else {
+ feature = col;
+ value = 1.d;
+ }
+
+ if (feature.indexOf(':') != -1) {
+ throw new IllegalArgumentException("Invaliad feature format `<index>:<value>`: " + col);
+ }
+
+ int colIndex = Integer.parseInt(feature);
+ if (colIndex < 0) {
+ throw new IllegalArgumentException("Col index MUST be greather than or equals to 0: "
+ + colIndex);
+ }
+
+ return nextColumn(colIndex, value);
+ }
+
+ @Nonnull
+ public abstract Matrix buildMatrix();
+
+}
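nextColumn(String) accepts either a bare index ("3", implying value 1.0) or
an "<index>:<value>" pair ("3:1.5"), so sparse string features can be fed
straight into any builder. A sketch under those assumptions
(FeatureParseExample is hypothetical):

    import hivemall.math.matrix.Matrix;
    import hivemall.math.matrix.builders.CSRMatrixBuilder;
    import hivemall.math.matrix.builders.MatrixBuilder;

    public class FeatureParseExample {
        public static void main(String[] args) {
            MatrixBuilder builder = new CSRMatrixBuilder(8);
            builder.nextRow(new String[] {"0:1.5", "3"}); // "3" alone means value 1.0
            builder.nextRow(new String[] {"2:0.5"});
            Matrix m = builder.buildMatrix();
            System.out.println(m.get(0, 3)); // 1.0
            System.out.println(m.get(1, 2)); // 0.5
        }
    }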
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
new file mode 100644
index 0000000..b6d0588
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.builders;
+
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
+import hivemall.utils.collections.arrays.SparseDoubleArray;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class RowMajorDenseMatrixBuilder extends MatrixBuilder {
+
+ @Nonnull
+ private final List<double[]> rows;
+ private int maxNumColumns;
+ private int nnz;
+
+ @Nonnull
+ private final SparseDoubleArray rowProbe;
+
+ public RowMajorDenseMatrixBuilder(@Nonnegative int initSize) {
+ super();
+ this.rows = new ArrayList<double[]>(initSize);
+ this.maxNumColumns = 0;
+ this.nnz = 0;
+ this.rowProbe = new SparseDoubleArray(32);
+ }
+
+ @Override
+ public RowMajorDenseMatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+ if (value == 0.d) {
+ return this;
+ }
+ rowProbe.put(col, value);
+ nnz++;
+ return this;
+ }
+
+ @Override
+ public RowMajorDenseMatrixBuilder nextRow() {
+ double[] row = rowProbe.toArray();
+ rowProbe.clear();
+ nextRow(row);
+ return this;
+ }
+
+ @Override
+ public void nextRow(@Nonnull double[] row) {
+ rows.add(row);
+ this.maxNumColumns = Math.max(row.length, maxNumColumns);
+ }
+
+ @Override
+ public RowMajorDenseMatrix2d buildMatrix() {
+ int numRows = rows.size();
+ double[][] data = rows.toArray(new double[numRows][]);
+ return new RowMajorDenseMatrix2d(data, maxNumColumns, nnz);
+ }
+
+}
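For dense numeric input the builder can also be driven column by column; zero
values are skipped, and rows are stored jagged, with get() falling back to the
default value past a row's physical end. A short sketch (illustrative only):

    RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(8);
    builder.nextColumn(0, 1.d).nextColumn(3, 2.d).nextRow(); // row 0 -> [1, 0, 0, 2]
    builder.nextColumn(1, 5.d).nextRow();                    // row 1 -> [0, 5]
    RowMajorDenseMatrix2d m = builder.buildMatrix();
    assert m.numColumns() == 4;
    assert m.get(1, 3, 0.d) == 0.d; // beyond row 1's stored length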
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/dense/ColumnMajorDenseMatrix2d.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/dense/ColumnMajorDenseMatrix2d.java b/core/src/main/java/hivemall/math/matrix/dense/ColumnMajorDenseMatrix2d.java
new file mode 100644
index 0000000..2c5fd45
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/dense/ColumnMajorDenseMatrix2d.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.dense;
+
+import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.builders.ColumnMajorDenseMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.Preconditions;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Fixed-size Dense 2-d double Matrix.
+ */
+public final class ColumnMajorDenseMatrix2d extends ColumnMajorMatrix {
+
+ @Nonnull
+ private final double[][] data; // col-row
+
+ @Nonnegative
+ private final int numRows;
+ @Nonnegative
+ private final int numColumns;
+ @Nonnegative
+ private int nnz;
+
+ public ColumnMajorDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numRows) {
+ this(data, numRows, nnz(data));
+ }
+
+ public ColumnMajorDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numRows,
+ @Nonnegative int nnz) {
+ super();
+ this.data = data;
+ this.numRows = numRows;
+ this.numColumns = data.length;
+ this.nnz = nnz;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return false;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return true;
+ }
+
+ @Override
+ public boolean swappable() {
+ return false;
+ }
+
+ @Override
+ public int nnz() {
+ return nnz;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int numColumns(final int row) {
+ checkRowIndex(row, numRows);
+
+ int numColumns = 0;
+ for (int j = 0; j < data.length; j++) {
+ final double[] col = data[j];
+ if (col == null) {
+ continue;
+ }
+ if (row < col.length && col[row] != 0.d) {
+ numColumns++;
+ }
+ }
+ return numColumns;
+ }
+
+ @Override
+ public double[] getRow(final int index) {
+ checkRowIndex(index, numRows);
+
+ double[] row = new double[numColumns];
+ return getRow(index, row);
+ }
+
+ @Override
+ public double[] getRow(final int index, @Nonnull final double[] dst) {
+ checkRowIndex(index, numRows);
+
+ for (int j = 0; j < data.length; j++) {
+ final double[] col = data[j];
+ if (col == null) {
+ continue;
+ }
+ if (index < col.length) {
+ dst[j] = col[index];
+ }
+ }
+ return dst;
+ }
+
+ @Override
+ public void getRow(final int index, @Nonnull final Vector row) {
+ checkRowIndex(index, numRows);
+ row.clear();
+
+ for (int j = 0; j < data.length; j++) {
+ final double[] col = data[j];
+ if (col == null) {
+ continue;
+ }
+ if (index < col.length) {
+ double v = col[index];
+ row.set(j, v);
+ }
+ }
+ }
+
+ @Override
+ public double get(final int row, final int col, final double defaultValue) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final double[] colData = data[col];
+ if (colData == null || row >= colData.length) {
+ return defaultValue;
+ }
+ return colData[row];
+ }
+
+ @Override
+ public double getAndSet(final int row, final int col, final double value) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final double[] colData = data[col];
+ Preconditions.checkNotNull(colData, "col does not exists: " + col);
+ checkRowIndex(row, colData.length);
+
+ final double old = colData[row];
+ colData[row] = value;
+ if (old == 0.d && value != 0.d) {
+ ++nnz;
+ }
+ return old;
+ }
+
+ @Override
+ public void set(final int row, final int col, final double value) {
+ checkIndex(row, col, numRows, numColumns);
+ if (value == 0.d) {
+ return;
+ }
+
+ final double[] colData = data[col];
+ Preconditions.checkNotNull(colData, "col does not exists: " + col);
+ checkRowIndex(row, colData.length);
+
+ if (colData[row] == 0.d) {
+ ++nnz;
+ }
+ colData[row] = value;
+ }
+
+ @Override
+ public void swap(int row1, int row2) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ final double[] colData = data[col];
+ if (colData == null) {
+ if (nullOutput) {
+ for (int i = 0; i < numRows; i++) {
+ procedure.apply(i, 0.d);
+ }
+ }
+ return;
+ }
+
+ int row = 0;
+ for (int len = colData.length; row < len; row++) {
+ procedure.apply(row, colData[row]);
+ }
+ if (nullOutput) {
+ for (; row < numRows; row++) {
+ procedure.apply(row, 0.d);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ final double[] colData = data[col];
+ if (colData == null) {
+ return;
+ }
+ int row = 0;
+ for (int len = colData.length; row < len; row++) {
+ final double v = colData[row];
+ if (v != 0.d) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public RowMajorDenseMatrix2d toRowMajorMatrix() {
+ final double[][] rowcol = new double[numRows][numColumns];
+ int nnz = 0;
+ for (int j = 0; j < data.length; j++) {
+ final double[] colData = data[j];
+ if (colData == null) {
+ continue;
+ }
+ for (int i = 0; i < colData.length; i++) {
+ final double v = colData[i];
+ if (v == 0.d) {
+ continue;
+ }
+ rowcol[i][j] = v;
+ nnz++;
+ }
+ }
+ for (int i = 0; i < rowcol.length; i++) {
+ final double[] row = rowcol[i];
+ final int last = numColumns - 1;
+ int maxj = last;
+ for (; maxj >= 0; maxj--) {
+ if (row[maxj] != 0.d) {
+ break;
+ }
+ }
+ if (maxj == last) {
+ continue;
+ } else if (maxj < 0) {
+ rowcol[i] = null;
+ continue;
+ }
+ final double[] dstRow = new double[maxj + 1];
+ System.arraycopy(row, 0, dstRow, 0, dstRow.length);
+ rowcol[i] = dstRow;
+ }
+
+ return new RowMajorDenseMatrix2d(rowcol, numColumns, nnz);
+ }
+
+ @Override
+ public ColumnMajorDenseMatrixBuilder builder() {
+ return new ColumnMajorDenseMatrixBuilder(numColumns);
+ }
+
+ private static int nnz(@Nonnull final double[][] data) {
+ int count = 0;
+ for (int j = 0; j < data.length; j++) {
+ final double[] col = data[j];
+ if (col == null) {
+ continue;
+ }
+ for (int i = 0; i < col.length; i++) {
+ if (col[i] != 0.d) {
+ ++count;
+ }
+ }
+ }
+ return count;
+ }
+
+}
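Columns may be null (entirely zero) or shorter than numRows; the traversal
methods above skip the missing parts unless nullOutput is requested. A sketch
of iterating the non-zeros of one column (illustrative only):

    double[][] colrow = { {1.d, 0.d, 3.d}, null, {0.d, 2.d} }; // col-major, 3 columns
    ColumnMajorDenseMatrix2d m = new ColumnMajorDenseMatrix2d(colrow, 3);
    m.eachNonZeroInColumn(0, new VectorProcedure() {
        @Override
        public void apply(int row, double value) {
            System.out.println("row " + row + " -> " + value); // rows 0 and 2
        }
    });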
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/dense/RowMajorDenseMatrix2d.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/dense/RowMajorDenseMatrix2d.java b/core/src/main/java/hivemall/math/matrix/dense/RowMajorDenseMatrix2d.java
new file mode 100644
index 0000000..54302e1
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/dense/RowMajorDenseMatrix2d.java
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.dense;
+
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Fixed-size Dense 2-d double Matrix.
+ */
+public final class RowMajorDenseMatrix2d extends RowMajorMatrix {
+
+ @Nonnull
+ private final double[][] data;
+
+ @Nonnegative
+ private final int numRows;
+ @Nonnegative
+ private final int numColumns;
+ @Nonnegative
+ private int nnz;
+
+ public RowMajorDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numColumns) {
+ this(data, numColumns, nnz(data));
+ }
+
+ public RowMajorDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numColumns,
+ @Nonnegative int nnz) {
+ super();
+ this.data = data;
+ this.numRows = data.length;
+ this.numColumns = numColumns;
+ this.nnz = nnz;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return false;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return true;
+ }
+
+ @Override
+ public boolean swappable() {
+ return true;
+ }
+
+ @Override
+ public int nnz() {
+ return nnz;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int numColumns(@Nonnegative final int row) {
+ checkRowIndex(row, numRows);
+
+ final double[] r = data[row];
+ if (r == null) {
+ return 0;
+ }
+ return r.length;
+ }
+
+ @Override
+ public DenseVector rowVector() {
+ return new DenseVector(numColumns);
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int index) {
+ checkRowIndex(index, numRows);
+
+ final double[] row = data[index];
+ if (row == null) {
+ return new double[0];
+ } else if (row.length == numColumns) {
+ return row;
+ }
+
+ final double[] result = new double[numColumns];
+ System.arraycopy(row, 0, result, 0, row.length);
+ return result;
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int index, @Nonnull final double[] dst) {
+ checkRowIndex(index, numRows);
+
+ final double[] row = data[index];
+ if (row == null) {
+ Arrays.fill(dst, 0.d);
+ return dst;
+ }
+
+ System.arraycopy(row, 0, dst, 0, row.length);
+ if (dst.length > row.length) {// zerofill
+ Arrays.fill(dst, row.length, dst.length, 0.d);
+ }
+ return dst;
+ }
+
+ @Override
+ public double get(@Nonnegative final int row, @Nonnegative final int col,
+ final double defaultValue) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final double[] rowData = data[row];
+ if (rowData == null || col >= rowData.length) {
+ return defaultValue;
+ }
+ return rowData[col];
+ }
+
+ @Override
+ public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+ final double value) {
+ checkIndex(row, col, numRows, numColumns);
+
+ final double[] rowData = data[row];
+ Preconditions.checkNotNull(rowData, "row does not exists: " + row);
+ checkColIndex(col, rowData.length);
+
+ double old = rowData[col];
+ rowData[col] = value;
+ if (old == 0.d && value != 0.d) {
+ ++nnz;
+ }
+ return old;
+ }
+
+ @Override
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+ checkIndex(row, col, numRows, numColumns);
+ if (value == 0.d) {
+ return;
+ }
+
+ final double[] rowData = data[row];
+ Preconditions.checkNotNull(rowData, "row does not exists: " + row);
+ checkColIndex(col, rowData.length);
+
+ if (rowData[col] == 0.d) {
+ ++nnz;
+ }
+ rowData[col] = value;
+ }
+
+ @Override
+ public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
+ checkRowIndex(row1, numRows);
+ checkRowIndex(row2, numRows);
+
+ double[] oldRow1 = data[row1];
+ data[row1] = data[row2];
+ data[row2] = oldRow1;
+ }
+
+ @Override
+ public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkRowIndex(row, numRows);
+
+ final double[] rowData = data[row];
+ if (rowData == null) {
+ if (nullOutput) {
+ for (int j = 0; j < numColumns; j++) {
+ procedure.apply(j, 0.d);
+ }
+ }
+ return;
+ }
+
+ int col = 0;
+ for (int len = rowData.length; col < len; col++) {
+ procedure.apply(col, rowData[col]);
+ }
+ if (nullOutput) {
+ for (; col < numColumns; col++) {
+ procedure.apply(col, 0.d);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ final double[] rowData = data[row];
+ if (rowData == null) {
+ return;
+ }
+ for (int col = 0, len = rowData.length; col < len; col++) {
+ final double v = rowData[col];
+ if (v != 0.d) {
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachColumnIndexInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ final double[] rowData = data[row];
+ if (rowData == null) {
+ return;
+ }
+ for (int col = 0, len = rowData.length; col < len; col++) {
+ procedure.apply(col);
+ }
+ }
+
+ @Override
+ public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ final double[] rowData = data[row];
+ if (rowData != null && col < rowData.length) {
+ procedure.apply(row, rowData[col]);
+ } else {
+ if (nullOutput) {
+ procedure.apply(row, 0.d);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(@Nonnegative final int col,
+ @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ final double[] rowData = data[row];
+ if (rowData == null) {
+ continue;
+ }
+ if (col < rowData.length) {
+ final double v = rowData[col];
+ if (v != 0.d) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+ }
+
+ @Override
+ public ColumnMajorDenseMatrix2d toColumnMajorMatrix() {
+ final double[][] colrow = new double[numColumns][numRows];
+ int nnz = 0;
+ for (int i = 0; i < data.length; i++) {
+ final double[] rowData = data[i];
+ if (rowData == null) {
+ continue;
+ }
+ for (int j = 0; j < rowData.length; j++) {
+ final double v = rowData[j];
+ if (v == 0.d) {
+ continue;
+ }
+ colrow[j][i] = v;
+ nnz++;
+ }
+ }
+ for (int j = 0; j < colrow.length; j++) {
+ final double[] col = colrow[j];
+ final int last = numRows - 1;
+ int maxi = last;
+ for (; maxi >= 0; maxi--) {
+ if (col[maxi] != 0.d) {
+ break;
+ }
+ }
+ if (maxi == last) {
+ continue;
+ } else if (maxi < 0) {
+ colrow[j] = null;
+ continue;
+ }
+ final double[] dstCol = new double[maxi + 1];
+ System.arraycopy(col, 0, dstCol, 0, dstCol.length);
+ colrow[j] = dstCol;
+ }
+
+ return new ColumnMajorDenseMatrix2d(colrow, numRows, nnz);
+ }
+
+ @Override
+ public RowMajorDenseMatrixBuilder builder() {
+ return new RowMajorDenseMatrixBuilder(numRows);
+ }
+
+ private static int nnz(@Nonnull final double[][] data) {
+ int count = 0;
+ for (int i = 0; i < data.length; i++) {
+ final double[] row = data[i];
+ if (row == null) {
+ continue;
+ }
+ for (int j = 0; j < row.length; j++) {
+ if (row[j] != 0.d) {
+ ++count;
+ }
+ }
+ }
+ return count;
+ }
+
+}
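toColumnMajorMatrix() above trims trailing zeros from each column and stores an
all-zero column as null, so the conversion can shrink memory while the get()
semantics stay intact. A small sketch (illustrative only):

    double[][] rows = { {0.d, 7.d}, {0.d, 0.d} };
    RowMajorDenseMatrix2d m = new RowMajorDenseMatrix2d(rows, 2);
    ColumnMajorDenseMatrix2d cm = m.toColumnMajorMatrix();
    assert cm.nnz() == 1;
    assert cm.get(1, 0, 0.d) == 0.d; // column 0 is all zeros -> stored as null
    assert cm.get(0, 1, 0.d) == 7.d; // column 1 is trimmed to length 1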
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java
new file mode 100644
index 0000000..0431310
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.ints;
+
+import hivemall.math.vector.VectorProcedure;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public abstract class AbstractIntMatrix implements IntMatrix {
+
+ protected int defaultValue;
+
+ public AbstractIntMatrix() {
+ this.defaultValue = 0;
+ }
+
+ @Override
+ public void setDefaultValue(int value) {
+ this.defaultValue = value;
+ }
+
+ @Override
+ public int[] row() {
+ int size = numRows();
+ return new int[size];
+ }
+
+ @Override
+ public final int get(@Nonnegative final int row, @Nonnegative final int col) {
+ return get(row, col, defaultValue);
+ }
+
+ @Override
+ public void incr(@Nonnegative final int row, @Nonnegative final int col) {
+ incr(row, col, 1);
+ }
+
+ protected static final void checkRowIndex(final int row, final int numRows) {
+ if (row < 0 || row >= numRows) {
+ throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
+ }
+ }
+
+ protected static final void checkColIndex(final int col, final int numColumns) {
+ if (col < 0 || col >= numColumns) {
+ throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
+ }
+ }
+
+ protected static final void checkIndex(final int index) {
+ if (index < 0) {
+ throw new IllegalArgumentException("Invalid index: " + index);
+ }
+ }
+
+ protected static final void checkIndex(final int row, final int col) {
+ if (row < 0) {
+ throw new IllegalArgumentException("Invalid row index: " + row);
+ }
+ if (col < 0) {
+ throw new IllegalArgumentException("Invalid col index: " + col);
+ }
+ }
+
+ protected static final void checkIndex(final int row, final int col, final int numRows,
+ final int numColumns) {
+ if (row < 0 || row >= numRows) {
+ throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
+ }
+ if (col < 0 || col >= numColumns) {
+ throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
+ }
+ }
+
+ @Override
+ public void eachInRow(final int row, @Nonnull final VectorProcedure procedure) {
+ eachInRow(row, procedure, true);
+ }
+
+ @Override
+ public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+ eachInColumn(col, procedure, true);
+ }
+
+ @Override
+ public void eachNonNullInRow(final int row, @Nonnull final VectorProcedure procedure) {
+ eachInRow(row, procedure, false);
+ }
+
+ @Override
+ public void eachNonNullInColumn(final int col, @Nonnull final VectorProcedure procedure) {
+ eachInColumn(col, procedure, false);
+ }
+
+}
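The base class wires get(row, col) to get(row, col, defaultValue), so callers
can pick a sentinel for absent cells. A sketch, assuming the concrete
ColumnMajorDenseIntMatrix2d used elsewhere in this patch returns the default
for cells backed by a missing (null) column:

    int[][] colrow = { {5, 9}, null }; // 2 columns, column 1 absent
    ColumnMajorDenseIntMatrix2d m = new ColumnMajorDenseIntMatrix2d(colrow, 2);
    m.setDefaultValue(-1);
    assert m.get(1, 0) == 9;
    assert m.get(0, 1) == -1; // falls back to the default value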
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
index 7fd841a..40957cb 100644
--- a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
+++ b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
@@ -18,127 +18,289 @@
*/
package hivemall.smile.tools;
-import hivemall.utils.collections.IntArrayList;
-import hivemall.utils.lang.Counter;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.SizeOf;
-import java.util.Arrays;
+import java.util.ArrayList;
import java.util.List;
-import java.util.Map;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDAF;
-import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.IntWritable;
+
+@Description(
+ name = "rf_ensemble",
+ value = "_FUNC_(int yhat, array<double> proba [, double model_weight=1.0])"
+ + " - Returns emsebled prediction results in <int label, double probability, array<double> probabilities>")
+public final class RandomForestEnsembleUDAF extends AbstractGenericUDAFResolver {
+
+ public RandomForestEnsembleUDAF() {
+ super();
+ }
+
+ @Override
+ public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
+ if (typeInfo.length != 2 && typeInfo.length != 3) {
+ throw new UDFArgumentLengthException("Expected 2 or 3 arguments but got "
+ + typeInfo.length);
+ }
+ if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) {
+ throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]);
+ }
+ if (!HiveUtils.isFloatingPointListTypeInfo(typeInfo[1])) {
+ throw new UDFArgumentTypeException(1, "ARRAY<double> is expected for posteriori: "
+ + typeInfo[1]);
+ }
+ if (typeInfo.length == 3) {
+ if (!HiveUtils.isFloatingPointTypeInfo(typeInfo[2])) {
+ throw new UDFArgumentTypeException(2, "Expected DOUBLE or FLOAT for model_weight: "
+ + typeInfo[2]);
+ }
+ }
+ return new RfEvaluator();
+ }
-@SuppressWarnings("deprecation")
-@Description(name = "rf_ensemble",
- value = "_FUNC_(int y) - Returns emsebled prediction results of Random Forest classifiers")
-public final class RandomForestEnsembleUDAF extends UDAF {
- public static class RandomForestPredictUDAFEvaluator implements UDAFEvaluator {
+ @SuppressWarnings("deprecation")
+ public static final class RfEvaluator extends GenericUDAFEvaluator {
- private Counter<Integer> partial;
+ private PrimitiveObjectInspector yhatOI;
+ private ListObjectInspector posterioriOI;
+ private PrimitiveObjectInspector posterioriElemOI;
+ @Nullable
+ private PrimitiveObjectInspector weightOI;
- @Override
- public void init() {
- this.partial = null;
+ private StructObjectInspector internalMergeOI;
+ private StructField sizeField, posterioriField;
+ private IntObjectInspector sizeFieldOI;
+ private StandardListObjectInspector posterioriFieldOI;
+
+ public RfEvaluator() {
+ super();
}
- public boolean iterate(Integer k) {
- if (k == null) {
- return true;
+ @Override
+ public ObjectInspector init(@Nonnull Mode mode, @Nonnull ObjectInspector[] parameters)
+ throws HiveException {
+ super.init(mode, parameters);
+ // initialize input
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+ this.yhatOI = HiveUtils.asIntegerOI(parameters[0]);
+ this.posterioriOI = HiveUtils.asListOI(parameters[1]);
+ this.posterioriElemOI = HiveUtils.asDoubleCompatibleOI(posterioriOI.getListElementObjectInspector());
+ if (parameters.length == 3) {
+ this.weightOI = HiveUtils.asDoubleCompatibleOI(parameters[2]);
+ }
+ } else {// from partial aggregation
+ StructObjectInspector soi = (StructObjectInspector) parameters[0];
+ this.internalMergeOI = soi;
+ this.sizeField = soi.getStructFieldRef("size");
+ this.posterioriField = soi.getStructFieldRef("posteriori");
+ this.sizeFieldOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
+ this.posterioriFieldOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
}
- if (partial == null) {
- this.partial = new Counter<Integer>();
+
+ // initialize output
+ final ObjectInspector outputOI;
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+ List<String> fieldNames = new ArrayList<>(3);
+ List<ObjectInspector> fieldOIs = new ArrayList<>(3);
+ fieldNames.add("size");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("posteriori");
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,
+ fieldOIs);
+ } else {// terminate
+ List<String> fieldNames = new ArrayList<>(3);
+ List<ObjectInspector> fieldOIs = new ArrayList<>(3);
+ fieldNames.add("label");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("probability");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("probabilities");
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,
+ fieldOIs);
}
- partial.increment(k);
- return true;
+ return outputOI;
}
- /*
- * https://cwiki.apache.org/confluence/display/Hive/GenericUDAFCaseStudy#GenericUDAFCaseStudy-terminatePartial
- */
- public Map<Integer, Integer> terminatePartial() {
- if (partial == null) {
- return null;
+ @Override
+ public RfAggregationBuffer getNewAggregationBuffer() throws HiveException {
+ RfAggregationBuffer buf = new RfAggregationBuffer();
+ reset(buf);
+ return buf;
+ }
+
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+ RfAggregationBuffer buf = (RfAggregationBuffer) agg;
+ buf.reset();
+ }
+
+ @Override
+ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+ RfAggregationBuffer buf = (RfAggregationBuffer) agg;
+
+ Preconditions.checkNotNull(parameters[0]);
+ int yhat = PrimitiveObjectInspectorUtils.getInt(parameters[0], yhatOI);
+ Preconditions.checkNotNull(parameters[1]);
+ double[] posteriori = HiveUtils.asDoubleArray(parameters[1], posterioriOI,
+ posterioriElemOI);
+
+ double weight = 1.0d;
+ if (parameters.length == 3) {
+ Preconditions.checkNotNull(parameters[2]);
+ weight = PrimitiveObjectInspectorUtils.getDouble(parameters[2], weightOI);
}
- if (partial.size() == 0) {
+ buf.iterate(yhat, weight, posteriori);
+ }
+
+ @Override
+ public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+ RfAggregationBuffer buf = (RfAggregationBuffer) agg;
+ if (buf._k == -1) {
return null;
- } else {
- return partial.getMap(); // CAN NOT return Counter here
}
+
+ Object[] partial = new Object[2];
+ partial[0] = new IntWritable(buf._k);
+ partial[1] = WritableUtils.toWritableList(buf._posteriori);
+ return partial;
}
- public boolean merge(Map<Integer, Integer> o) {
- if (o == null) {
- return true;
+ @Override
+ public void merge(AggregationBuffer agg, Object partial) throws HiveException {
+ if (partial == null) {
+ return;
}
+ RfAggregationBuffer buf = (RfAggregationBuffer) agg;
- if (partial == null) {
- this.partial = new Counter<Integer>();
+ Object o1 = internalMergeOI.getStructFieldData(partial, sizeField);
+ int size = sizeFieldOI.get(o1);
+ Object posteriori = internalMergeOI.getStructFieldData(partial, posterioriField);
+
+ // --------------------------------------------------------------
+ // [workaround]
+ // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray
+ // cannot be cast to [Ljava.lang.Object;
+ if (posteriori instanceof LazyBinaryArray) {
+ posteriori = ((LazyBinaryArray) posteriori).getList();
}
- partial.addAll(o);
- return true;
+
+ buf.merge(size, posteriori, posterioriFieldOI);
}
- public Result terminate() {
- if (partial == null) {
- return null;
- }
- if (partial.size() == 0) {
+ @Override
+ public Object terminate(AggregationBuffer agg) throws HiveException {
+ RfAggregationBuffer buf = (RfAggregationBuffer) agg;
+ if (buf._k == -1) {
return null;
}
- return new Result(partial);
+ double[] posteriori = buf._posteriori;
+ int label = smile.math.Math.whichMax(posteriori);
+ smile.math.Math.unitize1(posteriori);
+ double proba = posteriori[label];
+
+ Object[] result = new Object[3];
+ result[0] = new IntWritable(label);
+ result[1] = new DoubleWritable(proba);
+ result[2] = WritableUtils.toWritableList(posteriori);
+ return result;
}
+
}
- public static final class Result {
- @SuppressWarnings("unused")
- private Integer label;
- @SuppressWarnings("unused")
- private Double probability;
- @SuppressWarnings("unused")
- private List<Double> probabilities;
-
- Result(Counter<Integer> partial) {
- final Map<Integer, Integer> counts = partial.getMap();
- int size = counts.size();
- assert (size > 0) : size;
- IntArrayList keyList = new IntArrayList(size);
-
- long totalCnt = 0L;
- Integer maxKey = null;
- int maxCnt = Integer.MIN_VALUE;
- for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
- Integer key = e.getKey();
- keyList.add(key);
- int cnt = e.getValue().intValue();
- totalCnt += cnt;
- if (cnt >= maxCnt) {
- maxCnt = cnt;
- maxKey = key;
- }
+ public static final class RfAggregationBuffer extends AbstractAggregationBuffer {
+
+ @Nullable
+ private double[] _posteriori;
+ private int _k;
+
+ public RfAggregationBuffer() {
+ super();
+ reset();
+ }
+
+ void reset() {
+ this._posteriori = null;
+ this._k = -1;
+ }
+
+ void iterate(final int yhat, final double weight, @Nonnull final double[] posteriori)
+ throws HiveException {
+ if (_posteriori == null) {
+ this._k = posteriori.length;
+ this._posteriori = new double[_k];
}
+ if (yhat >= _k) {
+ throw new HiveException("Predicted class " + yhat + " is out of bounds: " + _k);
+ }
+ if (posteriori.length != _k) {
+ throw new HiveException("Given |posteriori| " + posteriori.length
+ + " is differs from expected one: " + _k);
+ }
+
+ _posteriori[yhat] += (posteriori[yhat] * weight);
+ }
- int[] keyArray = keyList.toArray();
- Arrays.sort(keyArray);
- int last = keyArray[keyArray.length - 1];
+ void merge(int size, @Nonnull Object posterioriObj,
+ @Nonnull StandardListObjectInspector posterioriOI) throws HiveException {
- double totalCnt_d = (double) totalCnt;
- final Double[] probabilities = new Double[Math.max(2, last + 1)];
- for (int i = 0, len = probabilities.length; i < len; i++) {
- final Integer cnt = counts.get(Integer.valueOf(i));
- if (cnt == null) {
- probabilities[i] = Double.valueOf(0d);
+ if (size != _k) {
+ if (_k == -1) {
+ this._k = size;
+ this._posteriori = new double[size];
} else {
- probabilities[i] = Double.valueOf(cnt.intValue() / totalCnt_d);
+ throw new HiveException("Mismatch in the number of elements: _k=" + _k
+ + ", size=" + size);
}
}
- this.label = maxKey;
- this.probability = Double.valueOf(maxCnt / totalCnt_d);
- this.probabilities = Arrays.asList(probabilities);
+
+ final double[] posteriori = _posteriori;
+ final DoubleObjectInspector doubleOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+ for (int i = 0, len = _k; i < len; i++) {
+ Object o2 = posterioriOI.getListElement(posterioriObj, i);
+ posteriori[i] += doubleOI.get(o2);
+ }
}
+
+ @Override
+ public int estimate() {
+ if (_k == -1) {
+ return 0;
+ }
+ return SizeOf.INT + _k * SizeOf.DOUBLE;
+ }
+
}
}
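Stripped of the Hive plumbing, the aggregation above is a weighted vote: each
tree adds weight * posteriori[yhat] to its predicted class, and terminate()
L1-normalizes the accumulated vector. A standalone sketch of that arithmetic
(illustrative only; assumes the smile-math dependency on the classpath):

    public class RfEnsembleSketch {
        public static void main(String[] args) {
            // three trees voting over k=3 classes
            int[] yhat = {2, 2, 0};
            double[] weight = {1.d, 1.d, 0.5d};
            double[][] posteriori = { {0.1, 0.2, 0.7}, {0.2, 0.2, 0.6}, {0.5, 0.3, 0.2} };

            double[] votes = new double[3];
            for (int t = 0; t < yhat.length; t++) {
                votes[yhat[t]] += posteriori[t][yhat[t]] * weight[t]; // as in RfAggregationBuffer#iterate
            }
            int label = smile.math.Math.whichMax(votes); // argmax, as in terminate()
            smile.math.Math.unitize1(votes);             // normalize to sum to 1
            double probability = votes[label];           // here 1.3 / 1.55 for class 2
            System.out.println(label + " " + probability);
        }
    }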
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index b5a81d4..dc544ae 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -18,31 +18,26 @@
*/
package hivemall.smile.tools;
-import hivemall.smile.ModelType;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
+import hivemall.math.vector.Vector;
import hivemall.smile.classification.DecisionTree;
+import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.regression.RegressionTree;
-import hivemall.smile.vm.StackMachine;
-import hivemall.smile.vm.VMRuntimeException;
import hivemall.utils.codec.Base91;
-import hivemall.utils.codec.DeflateCodec;
import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.io.IOUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
-import java.io.Closeable;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
-import javax.script.Bindings;
-import javax.script.Compilable;
-import javax.script.CompiledScript;
-import javax.script.ScriptEngine;
-import javax.script.ScriptEngineManager;
-import javax.script.ScriptException;
import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
@@ -50,73 +45,75 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
-import org.apache.hadoop.mapred.JobConf;
@Description(
name = "tree_predict",
- value = "_FUNC_(string modelId, int modelType, string script, array<double> features [, const boolean classification])"
+ value = "_FUNC_(string modelId, string model, array<double|string> features [, const boolean classification])"
+ " - Returns a prediction result of a random forest")
@UDFType(deterministic = true, stateful = false)
public final class TreePredictUDF extends GenericUDF {
private boolean classification;
- private PrimitiveObjectInspector modelTypeOI;
- private StringObjectInspector stringOI;
+ private StringObjectInspector modelOI;
private ListObjectInspector featureListOI;
private PrimitiveObjectInspector featureElemOI;
+ private boolean denseInput;
+ @Nullable
+ private Vector featuresProbe;
@Nullable
private transient Evaluator evaluator;
- private boolean support_javascript_eval = true;
-
- @Override
- public void configure(MapredContext context) {
- super.configure(context);
-
- if (context != null) {
- JobConf conf = context.getJobConf();
- String tdJarVersion = conf.get("td.jar.version");
- if (tdJarVersion != null) {
- this.support_javascript_eval = false;
- }
- }
- }
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- if (argOIs.length != 4 && argOIs.length != 5) {
- throw new UDFArgumentException("_FUNC_ takes 4 or 5 arguments");
+ if (argOIs.length != 3 && argOIs.length != 4) {
+ throw new UDFArgumentException("_FUNC_ takes 3 or 4 arguments");
}
- this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]);
- this.stringOI = HiveUtils.asStringOI(argOIs[2]);
- ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]);
+ this.modelOI = HiveUtils.asStringOI(argOIs[1]);
+ ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]);
this.featureListOI = listOI;
ObjectInspector elemOI = listOI.getListElementObjectInspector();
- this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ if (HiveUtils.isNumberOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ this.denseInput = true;
+ } else if (HiveUtils.isStringOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asStringOI(elemOI);
+ this.denseInput = false;
+ } else {
+ throw new UDFArgumentException(
+ "_FUNC_ takes array<double> or array<string> for the second argument: "
+ + listOI.getTypeName());
+ }
boolean classification = false;
- if (argOIs.length == 5) {
- classification = HiveUtils.getConstBoolean(argOIs[4]);
+ if (argOIs.length == 4) {
+ classification = HiveUtils.getConstBoolean(argOIs[3]);
}
this.classification = classification;
if (classification) {
- return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
+ List<String> fieldNames = new ArrayList<String>(2);
+ List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
+ fieldNames.add("value");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("posteriori");
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
} else {
return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
}
}
@Override
- public Writable evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
+ public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
Object arg0 = arguments[0].get();
if (arg0 == null) {
throw new HiveException("ModelId was null");
@@ -125,67 +122,96 @@ public final class TreePredictUDF extends GenericUDF {
String modelId = arg0.toString();
Object arg1 = arguments[1].get();
- int modelTypeId = PrimitiveObjectInspectorUtils.getInt(arg1, modelTypeOI);
- ModelType modelType = ModelType.resolve(modelTypeId);
-
- Object arg2 = arguments[2].get();
- if (arg2 == null) {
+ if (arg1 == null) {
return null;
}
- Text script = stringOI.getPrimitiveWritableObject(arg2);
+ Text model = modelOI.getPrimitiveWritableObject(arg1);
- Object arg3 = arguments[3].get();
- if (arg3 == null) {
+ Object arg2 = arguments[2].get();
+ if (arg2 == null) {
throw new HiveException("array<double> features was null");
}
- double[] features = HiveUtils.asDoubleArray(arg3, featureListOI, featureElemOI);
+ this.featuresProbe = parseFeatures(arg2, featuresProbe);
if (evaluator == null) {
- this.evaluator = getEvaluator(modelType, support_javascript_eval);
+ this.evaluator = classification ? new ClassificationEvaluator()
+ : new RegressionEvaluator();
}
-
- Writable result = evaluator.evaluate(modelId, modelType.isCompressed(), script, features,
- classification);
- return result;
+ return evaluator.evaluate(modelId, model, featuresProbe);
}
@Nonnull
- private static Evaluator getEvaluator(@Nonnull ModelType type, boolean supportJavascriptEval)
+ private Vector parseFeatures(@Nonnull final Object argObj, @Nullable Vector probe)
throws UDFArgumentException {
- final Evaluator evaluator;
- switch (type) {
- case serialization:
- case serialization_compressed: {
- evaluator = new JavaSerializationEvaluator();
- break;
+ if (denseInput) {
+ final int length = featureListOI.getListLength(argObj);
+ if (probe == null) {
+ probe = new DenseVector(length);
+ } else if (length != probe.size()) {
+ probe = new DenseVector(length);
}
- case opscode:
- case opscode_compressed: {
- evaluator = new StackmachineEvaluator();
- break;
+
+ for (int i = 0; i < length; i++) {
+ final Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ probe.set(i, 0.d);
+ } else {
+ double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+ probe.set(i, v);
+ }
}
- case javascript:
- case javascript_compressed: {
- if (!supportJavascriptEval) {
+ } else {
+ if (probe == null) {
+ probe = new SparseVector();
+ } else {
+ probe.clear();
+ }
+
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ String col = o.toString();
+
+ final int pos = col.indexOf(':');
+ if (pos == 0) {
+ throw new UDFArgumentException("Invalid feature value representation: " + col);
+ }
+
+ final String feature;
+ final double value;
+ if (pos > 0) {
+ feature = col.substring(0, pos);
+ String s2 = col.substring(pos + 1);
+ value = Double.parseDouble(s2);
+ } else {
+ feature = col;
+ value = 1.d;
+ }
+
+ if (feature.indexOf(':') != -1) {
+ throw new UDFArgumentException("Invaliad feature format `<index>:<value>`: "
+ + col);
+ }
+
+ final int colIndex = Integer.parseInt(feature);
+ if (colIndex < 0) {
throw new UDFArgumentException(
- "Javascript evaluation is not allowed in Treasure Data env");
+ "Col index MUST be greather than or equals to 0: " + colIndex);
}
- evaluator = new JavascriptEvaluator();
- break;
+ probe.set(colIndex, value);
}
- default:
- throw new UDFArgumentException("Unexpected model type was detected: " + type);
}
- return evaluator;
+ return probe;
}
@Override
public void close() throws IOException {
- this.modelTypeOI = null;
- this.stringOI = null;
+ this.modelOI = null;
this.featureElemOI = null;
this.featureListOI = null;
- IOUtils.closeQuietly(evaluator);
this.evaluator = null;
}
@@ -194,224 +220,81 @@ public final class TreePredictUDF extends GenericUDF {
return "tree_predict(" + Arrays.toString(children) + ")";
}
- public interface Evaluator extends Closeable {
+ interface Evaluator {
- @Nullable
- Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull final Text script,
- @Nonnull final double[] features, final boolean classification)
+ @Nonnull
+ Object evaluate(@Nonnull String modelId, @Nonnull Text model, @Nonnull Vector features)
throws HiveException;
}
- static final class JavaSerializationEvaluator implements Evaluator {
+ static final class ClassificationEvaluator implements Evaluator {
+
+ @Nonnull
+ private final Object[] result;
@Nullable
private String prevModelId = null;
private DecisionTree.Node cNode = null;
- private RegressionTree.Node rNode = null;
-
- JavaSerializationEvaluator() {}
-
- @Override
- public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script,
- double[] features, boolean classification) throws HiveException {
- if (classification) {
- return evaluateClassification(modelId, compressed, script, features);
- } else {
- return evaluteRegression(modelId, compressed, script, features);
- }
- }
- private IntWritable evaluateClassification(@Nonnull String modelId, boolean compressed,
- @Nonnull Text script, double[] features) throws HiveException {
- if (!modelId.equals(prevModelId)) {
- this.prevModelId = modelId;
- int length = script.getLength();
- byte[] b = script.getBytes();
- b = Base91.decode(b, 0, length);
- this.cNode = DecisionTree.deserializeNode(b, b.length, compressed);
- }
- assert (cNode != null);
- int result = cNode.predict(features);
- return new IntWritable(result);
+ ClassificationEvaluator() {
+ this.result = new Object[2];
}
- private DoubleWritable evaluteRegression(@Nonnull String modelId, boolean compressed,
- @Nonnull Text script, double[] features) throws HiveException {
+ @Nonnull
+ public Object[] evaluate(@Nonnull final String modelId, @Nonnull final Text script,
+ @Nonnull final Vector features) throws HiveException {
if (!modelId.equals(prevModelId)) {
this.prevModelId = modelId;
int length = script.getLength();
byte[] b = script.getBytes();
b = Base91.decode(b, 0, length);
- this.rNode = RegressionTree.deserializeNode(b, b.length, compressed);
- }
- assert (rNode != null);
- double result = rNode.predict(features);
- return new DoubleWritable(result);
- }
-
- @Override
- public void close() throws IOException {}
-
- }
-
- static final class StackmachineEvaluator implements Evaluator {
-
- private String prevModelId = null;
- private StackMachine prevVM = null;
- private DeflateCodec codec = null;
-
- StackmachineEvaluator() {}
-
- @Override
- public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script,
- double[] features, boolean classification) throws HiveException {
- final String scriptStr;
- if (compressed) {
- if (codec == null) {
- this.codec = new DeflateCodec(false, true);
- }
- byte[] b = script.getBytes();
- int len = script.getLength();
- b = Base91.decode(b, 0, len);
- try {
- b = codec.decompress(b);
- } catch (IOException e) {
- throw new HiveException("decompression failed", e);
- }
- scriptStr = new String(b);
- } else {
- scriptStr = script.toString();
+ this.cNode = DecisionTree.deserializeNode(b, b.length, true);
}
- final StackMachine vm;
- if (modelId.equals(prevModelId)) {
- vm = prevVM;
- } else {
- vm = new StackMachine();
- try {
- vm.compile(scriptStr);
- } catch (VMRuntimeException e) {
- throw new HiveException("failed to compile StackMachine", e);
+ Arrays.fill(result, null);
+ Preconditions.checkNotNull(cNode);
+ cNode.predict(features, new PredictionHandler() {
+ public void handle(int output, double[] posteriori) {
+ result[0] = new IntWritable(output);
+ result[1] = WritableUtils.toWritableList(posteriori);
}
- this.prevModelId = modelId;
- this.prevVM = vm;
- }
-
- try {
- vm.eval(features);
- } catch (VMRuntimeException vme) {
- throw new HiveException("failed to eval StackMachine", vme);
- } catch (Throwable e) {
- throw new HiveException("failed to eval StackMachine", e);
- }
+ });
- Double result = vm.getResult();
- if (result == null) {
- return null;
- }
- if (classification) {
- return new IntWritable(result.intValue());
- } else {
- return new DoubleWritable(result.doubleValue());
- }
- }
-
- @Override
- public void close() throws IOException {
- IOUtils.closeQuietly(codec);
+ return result;
}
}
- static final class JavascriptEvaluator implements Evaluator {
+ static final class RegressionEvaluator implements Evaluator {
- private final ScriptEngine scriptEngine;
- private final Compilable compilableEngine;
+ @Nonnull
+ private final DoubleWritable result;
+ @Nullable
private String prevModelId = null;
- private CompiledScript prevCompiled;
-
- private DeflateCodec codec = null;
+ private RegressionTree.Node rNode = null;
- JavascriptEvaluator() throws UDFArgumentException {
- ScriptEngineManager manager = new ScriptEngineManager();
- ScriptEngine engine = manager.getEngineByExtension("js");
- if (!(engine instanceof Compilable)) {
- throw new UDFArgumentException("ScriptEngine was not compilable: "
- + engine.getFactory().getEngineName() + " version "
- + engine.getFactory().getEngineVersion());
- }
- this.scriptEngine = engine;
- this.compilableEngine = (Compilable) engine;
+ RegressionEvaluator() {
+ this.result = new DoubleWritable();
}
- @Override
- public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script,
- double[] features, boolean classification) throws HiveException {
- final String scriptStr;
- if (compressed) {
- if (codec == null) {
- this.codec = new DeflateCodec(false, true);
- }
+ @Nonnull
+ public DoubleWritable evaluate(@Nonnull final String modelId, @Nonnull final Text script,
+ @Nonnull final Vector features) throws HiveException {
+ if (!modelId.equals(prevModelId)) {
+ this.prevModelId = modelId;
+ int length = script.getLength();
byte[] b = script.getBytes();
- int len = script.getLength();
- b = Base91.decode(b, 0, len);
- try {
- b = codec.decompress(b);
- } catch (IOException e) {
- throw new HiveException("decompression failed", e);
- }
- scriptStr = new String(b);
- } else {
- scriptStr = script.toString();
- }
-
- final CompiledScript compiled;
- if (modelId.equals(prevModelId)) {
- compiled = prevCompiled;
- } else {
- try {
- compiled = compilableEngine.compile(scriptStr);
- } catch (ScriptException e) {
- throw new HiveException("failed to compile: \n" + script, e);
- }
- this.prevCompiled = compiled;
- }
-
- final Bindings bindings = scriptEngine.createBindings();
- final Object result;
- try {
- bindings.put("x", features);
- result = compiled.eval(bindings);
- } catch (ScriptException se) {
- throw new HiveException("failed to evaluate: \n" + script, se);
- } catch (Throwable e) {
- throw new HiveException("failed to evaluate: \n" + script, e);
- } finally {
- bindings.clear();
- }
-
- if (result == null) {
- return null;
- }
- if (!(result instanceof Number)) {
- throw new HiveException("Got an unexpected non-number result: " + result);
- }
- if (classification) {
- Number casted = (Number) result;
- return new IntWritable(casted.intValue());
- } else {
- Number casted = (Number) result;
- return new DoubleWritable(casted.doubleValue());
+ b = Base91.decode(b, 0, length);
+ this.rNode = RegressionTree.deserializeNode(b, b.length, true);
}
- }
+ Preconditions.checkNotNull(rNode);
- @Override
- public void close() throws IOException {
- IOUtils.closeQuietly(codec);
+ double value = rNode.predict(features);
+ result.set(value);
+ return result;
}
-
}
}
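tree_predict now accepts the same "<index>:<value>" feature strings as the
trainer, with a bare "<index>" implying a value of 1.0. A condensed sketch of
the sparse parsing path above (illustrative only; error handling omitted):

    import hivemall.math.vector.SparseVector;
    import hivemall.math.vector.Vector;

    public class SparseFeatureSketch {
        static Vector parseFeatures(String[] features) {
            Vector probe = new SparseVector();
            for (String col : features) {
                int pos = col.indexOf(':');
                String feature = (pos > 0) ? col.substring(0, pos) : col;
                double value = (pos > 0) ? Double.parseDouble(col.substring(pos + 1)) : 1.d;
                probe.set(Integer.parseInt(feature), value);
            }
            return probe;
        }

        public static void main(String[] args) {
            Vector v = parseFeatures(new String[] {"0:2.5", "7"});
            // v now holds {0: 2.5, 7: 1.0}
        }
    }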
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
index c0dfc1c..74a3032 100644
--- a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
+++ b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
@@ -18,11 +18,23 @@
*/
package hivemall.smile.utils;
+import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.MatrixUtils;
+import hivemall.math.matrix.ints.ColumnMajorDenseIntMatrix2d;
+import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.VectorProcedure;
import hivemall.smile.classification.DecisionTree.SplitRule;
import hivemall.smile.data.Attribute;
import hivemall.smile.data.Attribute.AttributeType;
import hivemall.smile.data.Attribute.NominalAttribute;
import hivemall.smile.data.Attribute.NumericAttribute;
+import hivemall.utils.collections.lists.DoubleArrayList;
+import hivemall.utils.collections.lists.IntArrayList;
+import hivemall.utils.lang.mutable.MutableInt;
+import hivemall.utils.math.MathUtils;
import java.util.Arrays;
@@ -49,13 +61,14 @@ public final class SmileExtUtils {
}
final String[] opts = opt.split(",");
final int size = opts.length;
+ final NumericAttribute immutableNumAttr = new NumericAttribute();
final Attribute[] attr = new Attribute[size];
for (int i = 0; i < size; i++) {
final String type = opts[i];
if ("Q".equals(type)) {
- attr[i] = new NumericAttribute(i);
+ attr[i] = immutableNumAttr;
} else if ("C".equals(type)) {
- attr[i] = new NominalAttribute(i);
+ attr[i] = new NominalAttribute();
} else {
throw new UDFArgumentException("Unexpected type: " + type);
}
@@ -64,13 +77,55 @@ public final class SmileExtUtils {
}
@Nonnull
- public static Attribute[] attributeTypes(@Nullable Attribute[] attributes,
- @Nonnull final double[][] x) {
+ public static Attribute[] attributeTypes(@Nullable final Attribute[] attributes,
+ @Nonnull final Matrix x) {
if (attributes == null) {
- int p = x[0].length;
- attributes = new Attribute[p];
- for (int i = 0; i < p; i++) {
- attributes[i] = new NumericAttribute(i);
+ int p = x.numColumns();
+ Attribute[] newAttributes = new Attribute[p];
+ Arrays.fill(newAttributes, new NumericAttribute());
+ return newAttributes;
+ }
+
+ if (x.isRowMajorMatrix()) {
+ final VectorProcedure proc = new VectorProcedure() {
+ @Override
+ public void apply(final int j, final double value) {
+ final Attribute attr = attributes[j];
+ if (attr.type == AttributeType.NOMINAL) {
+ final int x_ij = ((int) value) + 1;
+ final int prevSize = attr.getSize();
+ if (x_ij > prevSize) {
+ attr.setSize(x_ij);
+ }
+ }
+ }
+ };
+ for (int i = 0, rows = x.numRows(); i < rows; i++) {
+ x.eachNonNullInRow(i, proc);
+ }
+ } else if (x.isColumnMajorMatrix()) {
+ final MutableInt max_x = new MutableInt(0);
+ final VectorProcedure proc = new VectorProcedure() {
+ @Override
+ public void apply(final int i, final double value) {
+ final int x_ij = (int) value;
+ if (x_ij > max_x.getValue()) {
+ max_x.setValue(x_ij);
+ }
+ }
+ };
+
+ final int size = attributes.length;
+ for (int j = 0; j < size; j++) {
+ final Attribute attr = attributes[j];
+ if (attr.type == AttributeType.NOMINAL) {
+ if (attr.getSize() != -1) {
+ continue;
+ }
+ max_x.setValue(0);
+ x.eachNonNullInColumn(j, proc);
+ attr.setSize(max_x.getValue() + 1);
+ }
}
} else {
int size = attributes.length;
@@ -81,8 +136,12 @@ public final class SmileExtUtils {
continue;
}
int max_x = 0;
- for (int i = 0; i < x.length; i++) {
- int x_ij = (int) x[i][j];
+ for (int i = 0, rows = x.numRows(); i < rows; i++) {
+ final double v = x.get(i, j, Double.NaN);
+ if (Double.isNaN(v)) {
+ continue;
+ }
+ int x_ij = (int) v;
if (x_ij > max_x) {
max_x = x_ij;
}
@@ -97,16 +156,17 @@ public final class SmileExtUtils {
@Nonnull
public static Attribute[] convertAttributeTypes(@Nonnull final smile.data.Attribute[] original) {
final int size = original.length;
+ final NumericAttribute immutableNumAttr = new NumericAttribute();
final Attribute[] dst = new Attribute[size];
for (int i = 0; i < size; i++) {
smile.data.Attribute o = original[i];
switch (o.type) {
case NOMINAL: {
- dst[i] = new NominalAttribute(i);
+ dst[i] = new NominalAttribute();
break;
}
case NUMERIC: {
- dst[i] = new NumericAttribute(i);
+ dst[i] = immutableNumAttr;
break;
}
default:
@@ -117,23 +177,52 @@ public final class SmileExtUtils {
}
@Nonnull
- public static int[][] sort(@Nonnull final Attribute[] attributes, @Nonnull final double[][] x) {
- final int n = x.length;
- final int p = x[0].length;
+ public static ColumnMajorIntMatrix sort(@Nonnull final Attribute[] attributes,
+ @Nonnull final Matrix x) {
+ final int n = x.numRows();
+ final int p = x.numColumns();
- final double[] a = new double[n];
final int[][] index = new int[p][];
+ if (x.isSparse()) {
+ int initSize = n / 10;
+ final DoubleArrayList dlist = new DoubleArrayList(initSize);
+ final IntArrayList ilist = new IntArrayList(initSize);
+ final VectorProcedure proc = new VectorProcedure() {
+ @Override
+ public void apply(final int i, final double v) {
+ dlist.add(v);
+ ilist.add(i);
+ }
+ };
- for (int j = 0; j < p; j++) {
- if (attributes[j].type == AttributeType.NUMERIC) {
- for (int i = 0; i < n; i++) {
- a[i] = x[i][j];
+ final ColumnMajorMatrix x2 = x.toColumnMajorMatrix();
+ for (int j = 0; j < p; j++) {
+ if (attributes[j].type != AttributeType.NUMERIC) {
+ continue;
+ }
+ x2.eachNonNullInColumn(j, proc);
+ if (ilist.isEmpty()) {
+ continue;
+ }
+ int[] indexJ = ilist.toArray();
+ QuickSort.sort(dlist.array(), indexJ, indexJ.length);
+ index[j] = indexJ;
+ dlist.clear();
+ ilist.clear();
+ }
+ } else {
+ final double[] a = new double[n];
+ for (int j = 0; j < p; j++) {
+ if (attributes[j].type == AttributeType.NUMERIC) {
+ for (int i = 0; i < n; i++) {
+ a[i] = x.get(i, j);
+ }
+ index[j] = QuickSort.sort(a);
}
- index[j] = QuickSort.sort(a);
}
}
- return index;
+ return new ColumnMajorDenseIntMatrix2d(index, n);
}
@Nonnull
@@ -169,13 +258,13 @@ public final class SmileExtUtils {
}
}
- public static int computeNumInputVars(final float numVars, final double[][] x) {
+ public static int computeNumInputVars(final float numVars, @Nonnull final Matrix x) {
final int numInputVars;
if (numVars <= 0.f) {
- int dims = x[0].length;
+ int dims = x.numColumns();
numInputVars = (int) Math.ceil(Math.sqrt(dims));
} else if (numVars > 0.f && numVars <= 1.f) {
- numInputVars = (int) (numVars * x[0].length);
+ numInputVars = (int) (numVars * x.numColumns());
} else {
numInputVars = (int) numVars;
}
@@ -186,42 +275,75 @@ public final class SmileExtUtils {
return Thread.currentThread().getId() * System.nanoTime();
}
- public static void shuffle(@Nonnull final int[] x, @Nonnull final smile.math.Random rnd) {
+ public static void shuffle(@Nonnull final int[] x, @Nonnull final PRNG rnd) {
for (int i = x.length; i > 1; i--) {
int j = rnd.nextInt(i);
swap(x, i - 1, j);
}
}
- public static void shuffle(@Nonnull final double[][] x, final int[] y, @Nonnull long seed) {
- if (x.length != y.length) {
- throw new IllegalArgumentException("x.length (" + x.length + ") != y.length ("
+ @Nonnull
+ public static Matrix shuffle(@Nonnull final Matrix x, @Nonnull final int[] y, long seed) {
+ final int numRows = x.numRows();
+ if (numRows != y.length) {
+ throw new IllegalArgumentException("x.length (" + numRows + ") != y.length ("
+ y.length + ')');
}
if (seed == -1L) {
seed = generateSeed();
}
- final smile.math.Random rnd = new smile.math.Random(seed);
- for (int i = x.length; i > 1; i--) {
- int j = rnd.nextInt(i);
- swap(x, i - 1, j);
- swap(y, i - 1, j);
+
+ final PRNG rnd = RandomNumberGeneratorFactory.createPRNG(seed);
+ if (x.swappable()) {
+ for (int i = numRows; i > 1; i--) {
+ int j = rnd.nextInt(i);
+ int k = i - 1;
+ x.swap(k, j);
+ swap(y, k, j);
+ }
+ return x;
+ } else {
+ final int[] indices = MathUtils.permutation(numRows);
+ for (int i = numRows; i > 1; i--) {
+ int j = rnd.nextInt(i);
+ int k = i - 1;
+ swap(indices, k, j);
+ swap(y, k, j);
+ }
+ return MatrixUtils.shuffle(x, indices);
}
}
- public static void shuffle(@Nonnull final double[][] x, final double[] y, @Nonnull long seed) {
- if (x.length != y.length) {
- throw new IllegalArgumentException("x.length (" + x.length + ") != y.length ("
+ @Nonnull
+ public static Matrix shuffle(@Nonnull final Matrix x, @Nonnull final double[] y,
+ long seed) {
+ final int numRows = x.numRows();
+ if (numRows != y.length) {
+ throw new IllegalArgumentException("x.length (" + numRows + ") != y.length ("
+ y.length + ')');
}
if (seed == -1L) {
seed = generateSeed();
}
- final smile.math.Random rnd = new smile.math.Random(seed);
- for (int i = x.length; i > 1; i--) {
- int j = rnd.nextInt(i);
- swap(x, i - 1, j);
- swap(y, i - 1, j);
+
+ final PRNG rnd = RandomNumberGeneratorFactory.createPRNG(seed);
+ if (x.swappable()) {
+ for (int i = numRows; i > 1; i--) {
+ int j = rnd.nextInt(i);
+ int k = i - 1;
+ x.swap(k, j);
+ swap(y, k, j);
+ }
+ return x;
+ } else {
+ final int[] indices = MathUtils.permutation(numRows);
+ for (int i = numRows; i > 1; i--) {
+ int j = rnd.nextInt(i);
+ int k = i - 1;
+ swap(indices, k, j);
+ swap(y, k, j);
+ }
+ return MatrixUtils.shuffle(x, indices);
}
}
@@ -243,15 +365,6 @@ public final class SmileExtUtils {
x[j] = s;
}
- /**
- * Swap two elements of an array.
- */
- private static void swap(final double[][] x, final int i, final int j) {
- double[] s = x[i];
- x[i] = x[j];
- x[j] = s;
- }
-
@Nonnull
public static int[] bagsToSamples(@Nonnull final int[] bags) {
int maxIndex = -1;
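The shuffle rewrite above takes one of two paths: when the Matrix reports swappable(), rows are permuted in place alongside the labels; otherwise the code shuffles an index permutation and asks MatrixUtils.shuffle() to materialize the reordered matrix. The sketch below illustrates that second path under simplifying assumptions: the matrix is modeled as a plain double[][], java.util.Random stands in for Hivemall's PRNG, MathUtils.permutation() and MatrixUtils.shuffle() are internals not reproduced here, and the class name ShuffleSketch is illustrative only.

import java.util.Random;

public final class ShuffleSketch {

    // Fisher-Yates over an index permutation, mirroring the non-swappable branch:
    // rather than moving rows, permute their indices and materialize once at the end.
    public static double[][] shuffle(final double[][] x, final int[] y, final long seed) {
        final int numRows = x.length;
        if (numRows != y.length) {
            throw new IllegalArgumentException(
                "x.length (" + numRows + ") != y.length (" + y.length + ')');
        }
        final Random rnd = new Random(seed);
        final int[] indices = new int[numRows];
        for (int i = 0; i < numRows; i++) {
            indices[i] = i; // identity permutation
        }
        for (int i = numRows; i > 1; i--) {
            final int j = rnd.nextInt(i);
            final int k = i - 1;
            swap(indices, k, j); // permute the row order
            swap(y, k, j);       // keep label i aligned with row i of the result
        }
        final double[][] shuffled = new double[numRows][];
        for (int i = 0; i < numRows; i++) {
            shuffled[i] = x[indices[i]]; // materialize, analogous to MatrixUtils.shuffle()
        }
        return shuffled;
    }

    private static void swap(final int[] a, final int i, final int j) {
        final int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }
}

Either path leaves y[i] paired with row i of the returned matrix, which is the invariant the tree builders rely on.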
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/vm/Operation.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/vm/Operation.java b/core/src/main/java/hivemall/smile/vm/Operation.java
deleted file mode 100644
index fff617f..0000000
--- a/core/src/main/java/hivemall/smile/vm/Operation.java
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile.vm;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-public final class Operation {
-
- final OperationEnum op;
- final String operand;
-
- public Operation(@Nonnull OperationEnum op) {
- this(op, null);
- }
-
- public Operation(@Nonnull OperationEnum op, @Nullable String operand) {
- this.op = op;
- this.operand = operand;
- }
-
- public enum OperationEnum {
- ADD, SUB, DIV, MUL, DUP, // reserved
- PUSH, POP, GOTO, IFEQ, IFEQ2, IFGE, IFGT, IFLE, IFLT, CALL; // used
-
- static OperationEnum valueOfLowerCase(String op) {
- return OperationEnum.valueOf(op.toUpperCase());
- }
- }
-
- @Override
- public String toString() {
- return op.toString() + (operand != null ? (" " + operand) : "");
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/vm/StackMachine.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/vm/StackMachine.java b/core/src/main/java/hivemall/smile/vm/StackMachine.java
deleted file mode 100644
index 3bf8b46..0000000
--- a/core/src/main/java/hivemall/smile/vm/StackMachine.java
+++ /dev/null
@@ -1,300 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile.vm;
-
-import hivemall.utils.lang.StringUtils;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Stack;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-public final class StackMachine {
- public static final String SEP = "; ";
-
- @Nonnull
- private final List<Operation> code;
- @Nonnull
- private final Map<String, Double> valuesMap;
- @Nonnull
- private final Map<String, Integer> jumpMap;
- @Nonnull
- private final Stack<Double> programStack;
-
- /**
- * Instruction pointer
- */
- private int IP;
-
- /**
- * Stack pointer
- */
- @SuppressWarnings("unused")
- private int SP;
-
- private int codeLength;
- private boolean[] done;
- private Double result;
-
- public StackMachine() {
- this.code = new ArrayList<Operation>();
- this.valuesMap = new HashMap<String, Double>();
- this.jumpMap = new HashMap<String, Integer>();
- this.programStack = new Stack<Double>();
- this.SP = 0;
- this.result = null;
- }
-
- public void run(@Nonnull String scripts, @Nonnull double[] features) throws VMRuntimeException {
- compile(scripts);
- eval(features);
- }
-
- public void run(@Nonnull List<String> opslist, @Nonnull double[] features)
- throws VMRuntimeException {
- compile(opslist);
- eval(features);
- }
-
- public void compile(@Nonnull String scripts) throws VMRuntimeException {
- List<String> opslist = Arrays.asList(scripts.split(SEP));
- compile(opslist);
- }
-
- public void compile(@Nonnull List<String> opslist) throws VMRuntimeException {
- for (String line : opslist) {
- String[] ops = line.split(" ", -1);
- if (ops.length == 2) {
- Operation.OperationEnum o = Operation.OperationEnum.valueOfLowerCase(ops[0]);
- code.add(new Operation(o, ops[1]));
- } else {
- Operation.OperationEnum o = Operation.OperationEnum.valueOfLowerCase(ops[0]);
- code.add(new Operation(o));
- }
- }
-
- int size = opslist.size();
- this.codeLength = size - 1;
- this.done = new boolean[size];
- }
-
- public void eval(final double[] features) throws VMRuntimeException {
- init();
- bind(features);
- execute(0);
- }
-
- private void init() {
- valuesMap.clear();
- jumpMap.clear();
- programStack.clear();
- this.SP = 0;
- this.result = null;
- Arrays.fill(done, false);
- }
-
- private void bind(final double[] features) {
- final StringBuilder buf = new StringBuilder();
- for (int i = 0; i < features.length; i++) {
- String bindKey = buf.append("x[").append(i).append("]").toString();
- valuesMap.put(bindKey, features[i]);
- StringUtils.clear(buf);
- }
- }
-
- private void execute(int entryPoint) throws VMRuntimeException {
- valuesMap.put("end", -1.0);
- jumpMap.put("last", codeLength);
-
- IP = entryPoint;
-
- while (IP < code.size()) {
- if (done[IP]) {
- throw new VMRuntimeException("There is a infinite loop in the Machine code.");
- }
- done[IP] = true;
- Operation currentOperation = code.get(IP);
- if (!executeOperation(currentOperation)) {
- return;
- }
- }
- }
-
- @Nullable
- public Double getResult() {
- return result;
- }
-
- private Double pop() {
- SP--;
- return programStack.pop();
- }
-
- private Double push(Double val) {
- programStack.push(val);
- SP++;
- return val;
- }
-
- private boolean executeOperation(Operation currentOperation) throws VMRuntimeException {
- if (IP < 0) {
- return false;
- }
- switch (currentOperation.op) {
- case GOTO: {
- if (StringUtils.isInt(currentOperation.operand)) {
- IP = Integer.parseInt(currentOperation.operand);
- } else {
- IP = jumpMap.get(currentOperation.operand);
- }
- break;
- }
- case CALL: {
- double candidateIP = valuesMap.get(currentOperation.operand);
- if (candidateIP < 0) {
- evaluateBuiltinByName(currentOperation.operand);
- IP++;
- }
- break;
- }
- case IFEQ: {
- double a = pop();
- double b = pop();
- if (a == b) {
- IP++;
- } else {
- if (StringUtils.isInt(currentOperation.operand)) {
- IP = Integer.parseInt(currentOperation.operand);
- } else {
- IP = jumpMap.get(currentOperation.operand);
- }
- }
- break;
- }
- case IFEQ2: {// follow the rule of smile's Math class.
- double a = pop();
- double b = pop();
- if (smile.math.Math.equals(a, b)) {
- IP++;
- } else {
- if (StringUtils.isInt(currentOperation.operand)) {
- IP = Integer.parseInt(currentOperation.operand);
- } else {
- IP = jumpMap.get(currentOperation.operand);
- }
- }
- break;
- }
- case IFGE: {
- double lower = pop();
- double upper = pop();
- if (upper >= lower) {
- IP++;
- } else {
- if (StringUtils.isInt(currentOperation.operand)) {
- IP = Integer.parseInt(currentOperation.operand);
- } else {
- IP = jumpMap.get(currentOperation.operand);
- }
- }
- break;
- }
- case IFGT: {
- double lower = pop();
- double upper = pop();
- if (upper > lower) {
- IP++;
- } else {
- if (StringUtils.isInt(currentOperation.operand)) {
- IP = Integer.parseInt(currentOperation.operand);
- } else {
- IP = jumpMap.get(currentOperation.operand);
- }
- }
- break;
- }
- case IFLE: {
- double lower = pop();
- double upper = pop();
- if (upper <= lower) {
- IP++;
- } else {
- if (StringUtils.isInt(currentOperation.operand)) {
- IP = Integer.parseInt(currentOperation.operand);
- } else {
- IP = jumpMap.get(currentOperation.operand);
- }
- }
- break;
- }
- case IFLT: {
- double lower = pop();
- double upper = pop();
- if (upper < lower) {
- IP++;
- } else {
- if (StringUtils.isInt(currentOperation.operand)) {
- IP = Integer.parseInt(currentOperation.operand);
- } else {
- IP = jumpMap.get(currentOperation.operand);
- }
- }
- break;
- }
- case POP: {
- valuesMap.put(currentOperation.operand, pop());
- IP++;
- break;
- }
- case PUSH: {
- if (StringUtils.isDouble(currentOperation.operand)) {
- push(Double.parseDouble(currentOperation.operand));
- } else {
- Double v = valuesMap.get(currentOperation.operand);
- if (v == null) {
- throw new VMRuntimeException("value is not binded: "
- + currentOperation.operand);
- }
- push(v);
- }
- IP++;
- break;
- }
- default:
- throw new VMRuntimeException("Machine code has wrong opcode :"
- + currentOperation.op);
- }
- return true;
-
- }
-
- private void evaluateBuiltinByName(String name) throws VMRuntimeException {
- if (name.equals("end")) {
- this.result = pop();
- } else {
- throw new VMRuntimeException("Machine code has wrong builin function :" + name);
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java b/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java
deleted file mode 100644
index 7fc89c8..0000000
--- a/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile.vm;
-
-public class VMRuntimeException extends Exception {
- private static final long serialVersionUID = -7378149197872357802L;
-
- public VMRuntimeException(String message) {
- super(message);
- }
-
- public VMRuntimeException(String message, Throwable cause) {
- super(message, cause);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java b/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java
index 1f6c324..366b74b 100644
--- a/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java
+++ b/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java
@@ -19,7 +19,7 @@
package hivemall.tools.mapred;
import hivemall.ftvec.ExtractFeatureUDF;
-import hivemall.utils.collections.OpenHashMap;
+import hivemall.utils.collections.maps.OpenHashMap;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/DoubleArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/DoubleArray.java b/core/src/main/java/hivemall/utils/collections/DoubleArray.java
deleted file mode 100644
index a7dfa81..0000000
--- a/core/src/main/java/hivemall/utils/collections/DoubleArray.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.io.Serializable;
-
-import javax.annotation.Nonnull;
-
-public interface DoubleArray extends Serializable {
-
- public double get(int key);
-
- public double get(int key, double valueIfKeyNotFound);
-
- public void put(int key, double value);
-
- public int size();
-
- public int keyAt(int index);
-
- @Nonnull
- public double[] toArray();
-
- @Nonnull
- public double[] toArray(boolean copy);
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/DoubleArray3D.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/DoubleArray3D.java b/core/src/main/java/hivemall/utils/collections/DoubleArray3D.java
deleted file mode 100644
index 5716212..0000000
--- a/core/src/main/java/hivemall/utils/collections/DoubleArray3D.java
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.Primitives;
-
-import java.nio.ByteBuffer;
-import java.nio.DoubleBuffer;
-
-import javax.annotation.Nonnull;
-
-public final class DoubleArray3D {
- private static final int DEFAULT_SIZE = 100 * 100 * 10; // feature * field * factor
-
- private final boolean direct;
-
- @Nonnull
- private DoubleBuffer buffer;
- private int capacity;
-
- private int size;
- // number of array in each dimension
- private int n1, n2, n3;
- // pointer to each dimension
- private int p1, p2;
-
- private boolean sanityCheck;
-
- public DoubleArray3D() {
- this(DEFAULT_SIZE, true);
- }
-
- public DoubleArray3D(int initSize, boolean direct) {
- this.direct = direct;
- this.buffer = allocate(direct, initSize);
- this.capacity = initSize;
- this.size = -1;
- this.sanityCheck = true;
- }
-
- public DoubleArray3D(int dim1, int dim2, int dim3) {
- this.direct = true;
- this.capacity = -1;
- configure(dim1, dim2, dim3);
- this.sanityCheck = true;
- }
-
- public void setSanityCheck(boolean enable) {
- this.sanityCheck = enable;
- }
-
- public void configure(final int dim1, final int dim2, final int dim3) {
- int requiredSize = cardinarity(dim1, dim2, dim3);
- if (requiredSize > capacity) {
- this.buffer = allocate(direct, requiredSize);
- this.capacity = requiredSize;
- }
- this.size = requiredSize;
- this.n1 = dim1;
- this.n2 = dim2;
- this.n3 = dim3;
- this.p1 = n2 * n3;
- this.p2 = n3;
- }
-
- public void clear() {
- buffer.clear();
- this.size = -1;
- }
-
- public int getSize() {
- return size;
- }
-
- int getCapacity() {
- return capacity;
- }
-
- public double get(final int i, final int j, final int k) {
- int idx = idx(i, j, k);
- return buffer.get(idx);
- }
-
- public void set(final int i, final int j, final int k, final double val) {
- int idx = idx(i, j, k);
- buffer.put(idx, val);
- }
-
- private int idx(final int i, final int j, final int k) {
- if (sanityCheck == false) {
- return i * p1 + j * p2 + k;
- }
-
- if (size == -1) {
- throw new IllegalStateException("Double3DArray#configure() is not called");
- }
- if (i >= n1 || i < 0) {
- throw new ArrayIndexOutOfBoundsException("Index '" + i
- + "' out of bounds for 1st dimension of size " + n1);
- }
- if (j >= n2 || j < 0) {
- throw new ArrayIndexOutOfBoundsException("Index '" + j
- + "' out of bounds for 2nd dimension of size " + n2);
- }
- if (k >= n3 || k < 0) {
- throw new ArrayIndexOutOfBoundsException("Index '" + k
- + "' out of bounds for 3rd dimension of size " + n3);
- }
- final int idx = i * p1 + j * p2 + k;
- if (idx >= size) {
- throw new IndexOutOfBoundsException("Computed internal index '" + idx
- + "' exceeds buffer size '" + size + "' where i=" + i + ", j=" + j + ", k=" + k);
- }
- return idx;
- }
-
- private static int cardinarity(final int dim1, final int dim2, final int dim3) {
- if (dim1 <= 0 || dim2 <= 0 || dim3 <= 0) {
- throw new IllegalArgumentException("Detected negative dimension size. dim1=" + dim1
- + ", dim2=" + dim2 + ", dim3=" + dim3);
- }
- return dim1 * dim2 * dim3;
- }
-
- @Nonnull
- private static DoubleBuffer allocate(final boolean direct, final int size) {
- int bytes = size * Primitives.DOUBLE_BYTES;
- ByteBuffer buf = direct ? ByteBuffer.allocateDirect(bytes) : ByteBuffer.allocate(bytes);
- return buf.asDoubleBuffer();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java b/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java
deleted file mode 100644
index afdc251..0000000
--- a/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java
+++ /dev/null
@@ -1,168 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.io.Closeable;
-import java.io.Serializable;
-
-import javax.annotation.Nonnull;
-
-public final class DoubleArrayList implements Serializable, Closeable {
- private static final long serialVersionUID = -8155789759545975413L;
- public static final int DEFAULT_CAPACITY = 12;
-
- /** array entity */
- private double[] data;
- private int used;
-
- public DoubleArrayList() {
- this(DEFAULT_CAPACITY);
- }
-
- public DoubleArrayList(int size) {
- this.data = new double[size];
- this.used = 0;
- }
-
- public DoubleArrayList(double[] initValues) {
- this.data = initValues;
- this.used = initValues.length;
- }
-
- public void add(double value) {
- if (used >= data.length) {
- expand(used + 1);
- }
- data[used++] = value;
- }
-
- public void add(double[] values) {
- final int needs = used + values.length;
- if (needs >= data.length) {
- expand(needs);
- }
- System.arraycopy(values, 0, data, used, values.length);
- this.used = needs;
- }
-
- /**
- * dynamic expansion.
- */
- private void expand(int max) {
- while (data.length < max) {
- final int len = data.length;
- double[] newArray = new double[len * 2];
- System.arraycopy(data, 0, newArray, 0, len);
- this.data = newArray;
- }
- }
-
- public double remove() {
- return data[--used];
- }
-
- public double remove(int index) {
- final double ret;
- if (index > used) {
- throw new IndexOutOfBoundsException();
- } else if (index == used) {
- ret = data[--used];
- } else { // index < used
- // removed value
- ret = data[index];
- final double[] newarray = new double[--used];
- // prefix
- System.arraycopy(data, 0, newarray, 0, index - 1);
- // appendix
- System.arraycopy(data, index + 1, newarray, index, used - index);
- // set fields.
- this.data = newarray;
- }
- return ret;
- }
-
- public void set(int index, double value) {
- if (index > used) {
- throw new IllegalArgumentException("Index MUST be less than \"size()\".");
- } else if (index == used) {
- ++used;
- }
- data[index] = value;
- }
-
- public double get(int index) {
- if (index >= used)
- throw new IndexOutOfBoundsException();
- return data[index];
- }
-
- public double fastGet(int index) {
- return data[index];
- }
-
- public int size() {
- return used;
- }
-
- public boolean isEmpty() {
- return used == 0;
- }
-
- public void clear() {
- used = 0;
- }
-
- @Nonnull
- public double[] toArray() {
- return toArray(false);
- }
-
- @Nonnull
- public double[] toArray(boolean close) {
- final double[] newArray = new double[used];
- System.arraycopy(data, 0, newArray, 0, used);
- if (close) {
- close();
- }
- return newArray;
- }
-
- public double[] array() {
- return data;
- }
-
- @Override
- public String toString() {
- final StringBuilder buf = new StringBuilder();
- buf.append('[');
- for (int i = 0; i < used; i++) {
- if (i != 0) {
- buf.append(", ");
- }
- buf.append(data[i]);
- }
- buf.append(']');
- return buf.toString();
- }
-
- @Override
- public void close() {
- this.data = null;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/FixedIntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/FixedIntArray.java b/core/src/main/java/hivemall/utils/collections/FixedIntArray.java
deleted file mode 100644
index 927ee83..0000000
--- a/core/src/main/java/hivemall/utils/collections/FixedIntArray.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-
-/**
- * A fixed INT array that has keys greater than or equals to 0.
- */
-public final class FixedIntArray implements IntArray {
- private static final long serialVersionUID = -1450212841013810240L;
-
- @Nonnull
- private final int[] array;
- private final int size;
-
- public FixedIntArray(@Nonnull int size) {
- this.array = new int[size];
- this.size = size;
- }
-
- public FixedIntArray(@Nonnull int[] array) {
- this.array = array;
- this.size = array.length;
- }
-
- @Override
- public int get(int index) {
- return array[index];
- }
-
- @Override
- public int get(int index, int valueIfKeyNotFound) {
- if (index >= size) {
- return valueIfKeyNotFound;
- }
- return array[index];
- }
-
- @Override
- public void put(int index, int value) {
- array[index] = value;
- }
-
- @Override
- public int size() {
- return array.length;
- }
-
- @Override
- public int keyAt(int index) {
- return index;
- }
-
- @Override
- public int[] toArray() {
- return toArray(true);
- }
-
- @Override
- public int[] toArray(boolean copy) {
- if (copy) {
- return Arrays.copyOf(array, size);
- } else {
- return array;
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/FloatArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/FloatArrayList.java b/core/src/main/java/hivemall/utils/collections/FloatArrayList.java
deleted file mode 100644
index cfdf504..0000000
--- a/core/src/main/java/hivemall/utils/collections/FloatArrayList.java
+++ /dev/null
@@ -1,152 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.io.Serializable;
-
-public final class FloatArrayList implements Serializable {
- private static final long serialVersionUID = 8764828070342317585L;
-
- public static final int DEFAULT_CAPACITY = 12;
-
- /** array entity */
- private float[] data;
- private int used;
-
- public FloatArrayList() {
- this(DEFAULT_CAPACITY);
- }
-
- public FloatArrayList(int size) {
- this.data = new float[size];
- this.used = 0;
- }
-
- public FloatArrayList(float[] initValues) {
- this.data = initValues;
- this.used = initValues.length;
- }
-
- public void add(float value) {
- if (used >= data.length) {
- expand(used + 1);
- }
- data[used++] = value;
- }
-
- public void add(float[] values) {
- final int needs = used + values.length;
- if (needs >= data.length) {
- expand(needs);
- }
- System.arraycopy(values, 0, data, used, values.length);
- this.used = needs;
- }
-
- /**
- * dynamic expansion.
- */
- private void expand(int max) {
- while (data.length < max) {
- final int len = data.length;
- float[] newArray = new float[len * 2];
- System.arraycopy(data, 0, newArray, 0, len);
- this.data = newArray;
- }
- }
-
- public float remove() {
- return data[--used];
- }
-
- public float remove(int index) {
- final float ret;
- if (index > used) {
- throw new IndexOutOfBoundsException();
- } else if (index == used) {
- ret = data[--used];
- } else { // index < used
- // removed value
- ret = data[index];
- final float[] newarray = new float[--used];
- // prefix
- System.arraycopy(data, 0, newarray, 0, index - 1);
- // appendix
- System.arraycopy(data, index + 1, newarray, index, used - index);
- // set fields.
- this.data = newarray;
- }
- return ret;
- }
-
- public void set(int index, float value) {
- if (index > used) {
- throw new IllegalArgumentException("Index MUST be less than \"size()\".");
- } else if (index == used) {
- ++used;
- }
- data[index] = value;
- }
-
- public float get(int index) {
- if (index >= used)
- throw new IndexOutOfBoundsException();
- return data[index];
- }
-
- public float fastGet(int index) {
- return data[index];
- }
-
- public int size() {
- return used;
- }
-
- public boolean isEmpty() {
- return used == 0;
- }
-
- public void clear() {
- this.used = 0;
- }
-
- public float[] toArray() {
- final float[] newArray = new float[used];
- System.arraycopy(data, 0, newArray, 0, used);
- return newArray;
- }
-
- public float[] array() {
- return data;
- }
-
- @Override
- public String toString() {
- final StringBuilder buf = new StringBuilder();
- buf.append('[');
- for (int i = 0; i < used; i++) {
- if (i != 0) {
- buf.append(", ");
- }
- buf.append(data[i]);
- }
- buf.append(']');
- return buf.toString();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
new file mode 100644
index 0000000..f847b15
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
+ */
+public class Int2FloatOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+ protected float defaultReturnValue = -1.f;
+
+ protected int[] _keys;
+ protected float[] _values;
+ protected byte[] _states;
+
+ protected Int2FloatOpenHashTable(int size, float loadFactor, float growFactor,
+ boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new int[actualSize];
+ this._values = new float[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Int2FloatOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public Int2FloatOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ /**
+ * Only for {@link Externalizable}
+ */
+ public Int2FloatOpenHashTable() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public void defaultReturnValue(float v) {
+ this.defaultReturnValue = v;
+ }
+
+ public boolean containsKey(int key) {
+ return findKey(key) >= 0;
+ }
+
+ /**
+ * @return the default return value (-1.f by default) if the key is not found
+ */
+ public float get(int key) {
+ int i = findKey(key);
+ if (i < 0) {
+ return defaultReturnValue;
+ }
+ return _values[i];
+ }
+
+ public float put(int key, float value) {
+ int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ int[] keys = _keys;
+ float[] values = _values;
+ byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Returns whether the required slot is free for a new entry */
+ protected boolean isFree(int index, int key) {
+ byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(int index) {
+ if ((_used + 1) >= _threshold) {// too filled
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ protected int findKey(int key) {
+ int[] keys = _keys;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public float remove(int key) {
+ int[] keys = _keys;
+ float[] values = _values;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ int[] newkeys = new int[newCapacity];
+ float[] newValues = new float[newCapacity];
+ byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ int k = _keys[i];
+ float v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(int key) {
+ return key & 0x7fffffff;
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ out.writeInt(i.getKey());
+ out.writeFloat(i.getValue());
+ }
+ }
+
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ int keylen = in.readInt();
+ int[] keys = new int[keylen];
+ float[] values = new float[keylen];
+ byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ int k = in.readInt();
+ float v = in.readFloat();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return -1 if not found
+ */
+ public int next();
+
+ public int getKey();
+
+ public float getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public int getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public float getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
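To make the API of the table added above concrete, here is a small usage sketch; it exercises only the methods that appear in this diff, and the surrounding class and printed values are illustrative. Note that remove() marks the slot REMOVED rather than FREE, so a later probe sequence for the same key can reclaim it.

import hivemall.utils.collections.maps.Int2FloatOpenHashTable;

public class Int2FloatExample {

    public static void main(String[] args) {
        Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(64);
        map.defaultReturnValue(Float.NaN); // returned for absent keys instead of -1.f

        // put() returns the previous value, or the default return value if absent.
        map.put(7, 0.5f);
        map.put(7, 1.5f);                  // overwrite; returns 0.5f

        System.out.println(map.get(7));    // 1.5
        System.out.println(map.get(8));    // NaN (absent key)

        map.remove(7);                     // tombstones the slot
        System.out.println(map.size());    // 0
    }
}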
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
new file mode 100644
index 0000000..5e9e812
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
+ */
+public final class Int2IntOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+ protected int defaultReturnValue = -1;
+
+ protected int[] _keys;
+ protected int[] _values;
+ protected byte[] _states;
+
+ protected Int2IntOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new int[actualSize];
+ this._values = new int[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Int2IntOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public Int2IntOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ public Int2IntOpenHashTable() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public void defaultReturnValue(int v) {
+ this.defaultReturnValue = v;
+ }
+
+ public boolean containsKey(final int key) {
+ return findKey(key) >= 0;
+ }
+
+ /**
+ * @return the default return value (-1 by default) if the key is not found
+ */
+ public int get(final int key) {
+ final int i = findKey(key);
+ if (i < 0) {
+ return defaultReturnValue;
+ }
+ return _values[i];
+ }
+
+ public int put(final int key, final int value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final int[] keys = _keys;
+ final int[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Returns whether the required slot is free for a new entry */
+ protected boolean isFree(final int index, final int key) {
+ final byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// too filled
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ protected int findKey(final int key) {
+ final int[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public int remove(final int key) {
+ final int[] keys = _keys;
+ final int[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(final int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(final int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final int[] newkeys = new int[newCapacity];
+ final int[] newValues = new int[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ int k = _keys[i];
+ int v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final int key) {
+ return key & 0x7fffffff;
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ out.writeInt(i.getKey());
+ out.writeInt(i.getValue());
+ }
+ }
+
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ final int keylen = in.readInt();
+ final int[] keys = new int[keylen];
+ final int[] values = new int[keylen];
+ final byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ int k = in.readInt();
+ int v = in.readInt();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return -1 if not found
+ */
+ public int next();
+
+ public int getKey();
+
+ public int getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public int getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public int getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
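The iterator added above follows a cursor protocol rather than java.util.Iterator: next() advances to the next FULL slot and returns its internal index, or -1 when the table is exhausted, and getKey()/getValue() read the entry the cursor last stopped at. A short sketch follows; the wrapper class is illustrative, while the map API is exactly as in the diff.

import hivemall.utils.collections.maps.Int2IntOpenHashTable;

public class Int2IntIterationExample {

    public static void main(String[] args) {
        Int2IntOpenHashTable map = new Int2IntOpenHashTable(16);
        for (int i = 0; i < 5; i++) {
            map.put(i, i * i);
        }
        // Cursor-style iteration: -1 signals exhaustion, not an entry.
        Int2IntOpenHashTable.IMapIterator itor = map.entries();
        while (itor.next() != -1) {
            System.out.println(itor.getKey() + " -> " + itor.getValue());
        }
    }
}

Iteration order follows slot order in the backing arrays, not insertion order.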
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java
new file mode 100644
index 0000000..68eb42f
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java
@@ -0,0 +1,500 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.codec.VariableByteCodec;
+import hivemall.utils.codec.ZigZagLEB128Codec;
+import hivemall.utils.math.Primes;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see <a href="http://en.wikipedia.org/wiki/Double_hashing">Double hashing</a>
+ */
+public class Int2LongOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ public static final int DEFAULT_SIZE = 65536;
+ public static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ public static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int[] _keys;
+ protected long[] _values;
+ protected byte[] _states;
+
+ protected int _used;
+ protected int _threshold;
+ protected long defaultReturnValue = -1L;
+
+ /**
+ * Constructor for Externalizable. Should not be called otherwise.
+ */
+ public Int2LongOpenHashTable() {// for Externalizable
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public Int2LongOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ public Int2LongOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ protected Int2LongOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new int[actualSize];
+ this._values = new long[actualSize];
+ this._states = new byte[actualSize];
+ this._used = 0;
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Int2LongOpenHashTable(@Nonnull int[] keys, @Nonnull long[] values,
+ @Nonnull byte[] states, int used) {
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ this._used = used;
+ this._threshold = keys.length;
+ }
+
+ @Nonnull
+ public static Int2LongOpenHashTable newInstance() {
+ return new Int2LongOpenHashTable(DEFAULT_SIZE);
+ }
+
+ public void defaultReturnValue(long v) {
+ this.defaultReturnValue = v;
+ }
+
+ @Nonnull
+ public int[] getKeys() {
+ return _keys;
+ }
+
+ @Nonnull
+ public long[] getValues() {
+ return _values;
+ }
+
+ @Nonnull
+ public byte[] getStates() {
+ return _states;
+ }
+
+ public boolean containsKey(int key) {
+ return findKey(key) >= 0;
+ }
+
+ /**
+ * @return the default return value (-1L unless changed via {@link #defaultReturnValue(long)}) if not found
+ */
+ public long get(int key) {
+ int i = findKey(key);
+ if (i < 0) {
+ return defaultReturnValue;
+ }
+ return _values[i];
+ }
+
+ public long put(int key, long value) {
+ int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ int[] keys = _keys;
+ long[] values = _values;
+ byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ long old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ long old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Returns whether the requested slot is free for a new entry */
+ protected boolean isFree(int index, int key) {
+ byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(int index) {
+ if ((_used + 1) >= _threshold) {// the load-factor threshold is reached; grow
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ protected int findKey(int key) {
+ int[] keys = _keys;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public long remove(int key) {
+ int[] keys = _keys;
+ long[] values = _values;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ long old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ long old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public int capacity() {
+ return _keys.length;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ int[] newkeys = new int[newCapacity];
+ long[] newValues = new long[newCapacity];
+ byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ int k = _keys[i];
+ long v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(int key) {
+ return key & 0x7fffffff;
+ }
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ final int[] keys = _keys;
+ final int size = keys.length;
+ out.writeInt(size);
+
+ final byte[] states = _states;
+ writeStates(states, out);
+
+ final long[] values = _values;
+ for (int i = 0; i < size; i++) {
+ if (states[i] != FULL) {
+ continue;
+ }
+ ZigZagLEB128Codec.writeSignedInt(keys[i], out);
+ ZigZagLEB128Codec.writeSignedLong(values[i], out);
+ }
+ }
+
+ private static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out)
+ throws IOException {
+ // write the indexes of non-FULL states differentially (delta-encoded)
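+ // e.g., non-FULL slots at indexes {3, 7, 20} are stored as gaps {3, 4, 13},
+ // which keeps the unsigned variable-byte codes short (illustrative values)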
+ final int size = status.length;
+ int cardinality = 0;
+ for (int i = 0; i < size; i++) {
+ if (status[i] != FULL) {
+ cardinality++;
+ }
+ }
+ out.writeInt(cardinality);
+ if (cardinality == 0) {
+ return;
+ }
+ int prev = 0;
+ for (int i = 0; i < size; i++) {
+ if (status[i] != FULL) {
+ int diff = i - prev;
+ assert (diff >= 0);
+ VariableByteCodec.encodeUnsignedInt(diff, out);
+ prev = i;
+ }
+ }
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ final int size = in.readInt();
+ final int[] keys = new int[size];
+ final long[] values = new long[size];
+ final byte[] states = new byte[size];
+ readStates(in, states);
+
+ for (int i = 0; i < size; i++) {
+ if (states[i] != FULL) {
+ continue;
+ }
+ keys[i] = ZigZagLEB128Codec.readSignedInt(in);
+ values[i] = ZigZagLEB128Codec.readSignedLong(in);
+ }
+
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ private static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status)
+ throws IOException {
+ // read the delta-encoded indexes of non-FULL (empty) states
+ final int cardinality = in.readInt();
+ Arrays.fill(status, IntOpenHashTable.FULL);
+ int prev = 0;
+ for (int j = 0; j < cardinality; j++) {
+ int i = VariableByteCodec.decodeUnsignedInt(in) + prev;
+ status[i] = IntOpenHashTable.FREE;
+ prev = i;
+ }
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return the index of the next entry, or -1 if there are no more entries
+ */
+ public int next();
+
+ public int getKey();
+
+ public long getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public int getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public long getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
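For orientation, a hypothetical usage sketch of the Int2LongOpenHashTable added above, relying only on the API visible in this diff:

    import hivemall.utils.collections.maps.Int2LongOpenHashTable;

    public class Int2LongOpenHashTableDemo {
        public static void main(String[] args) {
            Int2LongOpenHashTable map = Int2LongOpenHashTable.newInstance();
            map.defaultReturnValue(Long.MIN_VALUE);   // returned by get()/put() on a miss
            System.out.println(map.put(7, 100L));     // Long.MIN_VALUE (no previous entry)
            System.out.println(map.get(7));           // 100
            System.out.println(map.get(8));           // Long.MIN_VALUE (missing key)
            System.out.println(map.remove(7));        // 100
            System.out.println(map.size());           // 0
        }
    }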
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
new file mode 100644
index 0000000..d7ae8d6
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
@@ -0,0 +1,467 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see <a href="http://en.wikipedia.org/wiki/Double_hashing">Double hashing</a>
+ */
+public class IntOpenHashMap<V> implements Externalizable {
+ private static final long serialVersionUID = -8162355845665353513L;
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+
+ protected int[] _keys;
+ protected V[] _values;
+ protected byte[] _states;
+
+ @SuppressWarnings("unchecked")
+ protected IntOpenHashMap(int size, float loadFactor, float growFactor, boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new int[actualSize];
+ this._values = (V[]) new Object[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = Math.round(actualSize * _loadFactor);
+ }
+
+ public IntOpenHashMap(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public IntOpenHashMap(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ public IntOpenHashMap() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public boolean containsKey(int key) {
+ return findKey(key) >= 0;
+ }
+
+ public final V get(final int key) {
+ final int i = findKey(key);
+ if (i < 0) {
+ return null;
+ }
+ recordAccess(i);
+ return _values[i];
+ }
+
+ public V put(final int key, final V value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ final boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final int[] keys = _keys;
+ final V[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ V old = values[keyIdx];
+ values[keyIdx] = value;
+ recordAccess(keyIdx);
+ return old;
+ }
+ // try second hash
+ final int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ V old = values[keyIdx];
+ values[keyIdx] = value;
+ recordAccess(keyIdx);
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ postAddEntry(keyIdx);
+ return null;
+ }
+
+ public V putIfAbsent(final int key, final V value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ final boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final int[] keys = _keys;
+ final V[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// second hashing
+ if (keys[keyIdx] == key) {
+ return values[keyIdx];
+ }
+ // try second hash
+ final int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return values[keyIdx];
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ _used++;
+ postAddEntry(keyIdx);
+ return null;
+ }
+
+ /** Returns whether the requested slot is free for a new entry */
+ protected boolean isFree(int index, int key) {
+ byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(int index) {
+ if ((_used + 1) >= _threshold) {// the load-factor threshold is reached; grow
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ protected void postAddEntry(int index) {}
+
+ private int findKey(int key) {
+ int[] keys = _keys;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public V remove(int key) {
+ int[] keys = _keys;
+ V[] values = _values;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ V old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ recordRemoval(keyIdx);
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return null;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ V old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ recordRemoval(keyIdx);
+ return old;
+ }
+ }
+ }
+ return null;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator<V> entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator<V> i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ private void ensureCapacity(int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ @SuppressWarnings("unchecked")
+ protected void rehash(int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final int[] oldKeys = _keys;
+ final V[] oldValues = _values;
+ final byte[] oldStates = _states;
+ int[] newkeys = new int[newCapacity];
+ V[] newValues = (V[]) new Object[newCapacity];
+ byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (oldStates[i] == FULL) {
+ used++;
+ int k = oldKeys[i];
+ V v = oldValues[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(int key) {
+ return key & 0x7fffffff;
+ }
+
+ protected void recordAccess(int idx) {}
+
+ protected void recordRemoval(int idx) {}
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator<V> i = entries();
+ while (i.next() != -1) {
+ out.writeInt(i.getKey());
+ out.writeObject(i.getValue());
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ int keylen = in.readInt();
+ int[] keys = new int[keylen];
+ V[] values = (V[]) new Object[keylen];
+ byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ int k = in.readInt();
+ V v = (V) in.readObject();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator<V> {
+
+ public boolean hasNext();
+
+ public int next();
+
+ public int getKey();
+
+ public V getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator<V> {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public int getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public V getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+
+}
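A short, hypothetical sketch of how the IntOpenHashMap above behaves; note that put returns the previous value and putIfAbsent keeps the existing one:

    import hivemall.utils.collections.maps.IntOpenHashMap;

    public class IntOpenHashMapDemo {
        public static void main(String[] args) {
            IntOpenHashMap<String> m = new IntOpenHashMap<String>(64);
            System.out.println(m.put(1, "one"));         // null (no previous value)
            System.out.println(m.putIfAbsent(1, "uno")); // "one" (existing value kept)
            System.out.println(m.get(1));                // "one"
            System.out.println(m.remove(1));             // "one"
            System.out.println(m.size());                // 0
        }
    }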
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
new file mode 100644
index 0000000..dcb64d1
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+import java.util.HashMap;
+
+import javax.annotation.Nonnull;
+
+/**
+ * An open-addressing hash table with double hashing that requires less memory than {@link HashMap}.
+ */
+public final class IntOpenHashTable<V> implements Externalizable {
+
+ public static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ public static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ public static final byte FREE = 0;
+ public static final byte FULL = 1;
+ public static final byte REMOVED = 2;
+
+ protected /* final */ float _loadFactor;
+ protected /* final */ float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+
+ protected int[] _keys;
+ protected V[] _values;
+ protected byte[] _states;
+
+ public IntOpenHashTable() {} // for Externalizable
+
+ public IntOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR);
+ }
+
+ @SuppressWarnings("unchecked")
+ public IntOpenHashTable(int size, float loadFactor, float growFactor) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = Primes.findLeastPrimeNumber(size);
+ this._keys = new int[actualSize];
+ this._values = (V[]) new Object[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = Math.round(actualSize * _loadFactor);
+ }
+
+ public IntOpenHashTable(@Nonnull int[] keys, @Nonnull V[] values, @Nonnull byte[] states,
+ int used) {
+ this._used = used;
+ this._threshold = keys.length;
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public int[] getKeys() {
+ return _keys;
+ }
+
+ public Object[] getValues() {
+ return _values;
+ }
+
+ public byte[] getStates() {
+ return _states;
+ }
+
+ public boolean containsKey(final int key) {
+ return findKey(key) >= 0;
+ }
+
+ public V get(final int key) {
+ final int i = findKey(key);
+ if (i < 0) {
+ return null;
+ }
+ return _values[i];
+ }
+
+ public V put(final int key, final V value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final int[] keys = _keys;
+ final V[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {
+ if (keys[keyIdx] == key) {
+ V old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ V old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return null;
+ }
+
+
+ /** Returns whether the requested slot is free for a new entry */
+ protected boolean isFree(int index, int key) {
+ byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(int index) {
+ if ((_used + 1) >= _threshold) {// the load-factor threshold is reached; grow
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ protected int findKey(final int key) {
+ final int[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public V remove(final int key) {
+ final int[] keys = _keys;
+ final V[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ V old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return null;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ V old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return null;
+ }
+
+ @Nonnull
+ public IMapIterator<V> entries() {
+ return new MapIterator();
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public int capacity() {
+ return _keys.length;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ protected void ensureCapacity(int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ @SuppressWarnings("unchecked")
+ private void rehash(int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final int[] newkeys = new int[newCapacity];
+ final V[] newValues = (V[]) new Object[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ int k = _keys[i];
+ V v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newStates[keyIdx] = FULL;
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final int key) {
+ return key & 0x7fffffff;
+ }
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeFloat(_loadFactor);
+ out.writeFloat(_growFactor);
+ out.writeInt(_used);
+
+ final int size = _keys.length;
+ out.writeInt(size);
+
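+ // every slot is written verbatim, including FREE and REMOVED ones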
+ for (int i = 0; i < size; i++) {
+ out.writeInt(_keys[i]);
+ out.writeObject(_values[i]);
+ out.writeByte(_states[i]);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._loadFactor = in.readFloat();
+ this._growFactor = in.readFloat();
+ this._used = in.readInt();
+
+ final int size = in.readInt();
+ final int[] keys = new int[size];
+ final Object[] values = new Object[size];
+ final byte[] states = new byte[size];
+ for (int i = 0; i < size; i++) {
+ keys[i] = in.readInt();
+ values[i] = in.readObject();
+ states[i] = in.readByte();
+ }
+ this._threshold = size;
+ this._keys = keys;
+ this._values = (V[]) values;
+ this._states = states;
+ }
+
+ public interface IMapIterator<V> {
+
+ public boolean hasNext();
+
+ /**
+ * @return the index of the next entry, or -1 if there are no more entries
+ */
+ public int next();
+
+ public int getKey();
+
+ public V getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator<V> {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public int getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public V getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+
+}
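Entries of the IntOpenHashTable above are walked with the nested IMapIterator, whose next() returns a slot index or -1 when exhausted; a minimal sketch (values are illustrative):

    import hivemall.utils.collections.maps.IntOpenHashTable;

    public class IntOpenHashTableDemo {
        public static void main(String[] args) {
            IntOpenHashTable<Double> table = new IntOpenHashTable<Double>(32);
            table.put(1, 0.5d);
            table.put(2, 1.5d);
            IntOpenHashTable.IMapIterator<Double> itor = table.entries();
            while (itor.next() != -1) {  // -1 signals exhaustion, not a lookup failure
                System.out.println(itor.getKey() + " -> " + itor.getValue());
            }
        }
    }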
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/LRUMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/LRUMap.java b/core/src/main/java/hivemall/utils/collections/maps/LRUMap.java
new file mode 100644
index 0000000..84679c7
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/LRUMap.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+public class LRUMap<K, V> extends LinkedHashMap<K, V> {
+ private static final long serialVersionUID = -7708264099645977733L;
+
+ private final int cacheSize;
+
+ public LRUMap(int cacheSize) {
+ this(cacheSize, 0.75f, cacheSize);
+ }
+
+ public LRUMap(int capacity, float loadFactor, int cacheSize) {
+ super(capacity, loadFactor, true);
+ this.cacheSize = cacheSize;
+ }
+
+ @Override
+ protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
+ return size() > cacheSize;
+ }
+}
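The LRUMap above relies on LinkedHashMap's access order and removeEldestEntry for eviction; a small hypothetical sketch of the resulting behavior:

    import hivemall.utils.collections.maps.LRUMap;

    public class LRUMapDemo {
        public static void main(String[] args) {
            LRUMap<String, Integer> cache = new LRUMap<String, Integer>(2);
            cache.put("a", 1);
            cache.put("b", 2);
            cache.get("a");                     // touch "a": now most-recently-used
            cache.put("c", 3);                  // size() > 2, so eldest ("b") is evicted
            System.out.println(cache.keySet()); // [a, c]
        }
    }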
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
new file mode 100644
index 0000000..c758824
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
@@ -0,0 +1,445 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see <a href="http://en.wikipedia.org/wiki/Double_hashing">Double hashing</a>
+ */
+public final class Long2DoubleOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+ protected double defaultReturnValue = 0.d;
+
+ protected long[] _keys;
+ protected double[] _values;
+ protected byte[] _states;
+
+ protected Long2DoubleOpenHashTable(int size, float loadFactor, float growFactor,
+ boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new long[actualSize];
+ this._values = new double[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Long2DoubleOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public Long2DoubleOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ public Long2DoubleOpenHashTable() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public void defaultReturnValue(double v) {
+ this.defaultReturnValue = v;
+ }
+
+ public boolean containsKey(final long key) {
+ return _findKey(key) >= 0;
+ }
+
+ /**
+ * @return defaultReturnValue if not found
+ */
+ public double get(final long key) {
+ return get(key, defaultReturnValue);
+ }
+
+ public double get(final long key, final double defaultValue) {
+ final int i = _findKey(key);
+ if (i < 0) {
+ return defaultValue;
+ }
+ return _values[i];
+ }
+
+ public double _get(final int index) {
+ if (index < 0) {
+ return defaultReturnValue;
+ }
+ return _values[index];
+ }
+
+ public double _set(final int index, final double value) {
+ double old = _values[index];
+ _values[index] = value;
+ return old;
+ }
+
+ public double _remove(final int index) {
+ _states[index] = REMOVED;
+ --_used;
+ return _values[index];
+ }
+
+ public double put(final long key, final double value) {
+ return put(key, value, defaultReturnValue);
+ }
+
+ public double put(final long key, final double value, final double defaultValue) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final long[] keys = _keys;
+ final double[] values = _values;
+ final byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultValue;
+ }
+
+ /** Returns whether the requested slot is free for a new entry */
+ protected boolean isFree(final int index, final long key) {
+ final byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// the load-factor threshold is reached; grow
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * @return -1 if not found
+ */
+ public int _findKey(final long key) {
+ final long[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public double remove(final long key) {
+ final long[] keys = _keys;
+ final double[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(final int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(final int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final long[] newkeys = new long[newCapacity];
+ final double[] newValues = new double[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ long k = _keys[i];
+ double v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final long key) {
+ return (int) (key ^ (key >>> 32)) & 0x7FFFFFFF;
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ out.writeLong(i.getKey());
+ out.writeDouble(i.getValue());
+ }
+ }
+
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ final int keylen = in.readInt();
+ final long[] keys = new long[keylen];
+ final double[] values = new double[keylen];
+ final byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ long k = in.readLong();
+ double v = in.readDouble();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return the index of the next entry, or -1 if there are no more entries
+ */
+ public int next();
+
+ public long getKey();
+
+ public double getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public long getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public double getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
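A hypothetical sketch for the Long2DoubleOpenHashTable above; setting the default return value to NaN makes misses distinguishable from a stored 0.0:

    import hivemall.utils.collections.maps.Long2DoubleOpenHashTable;

    public class Long2DoubleOpenHashTableDemo {
        public static void main(String[] args) {
            Long2DoubleOpenHashTable t = new Long2DoubleOpenHashTable(1024);
            t.defaultReturnValue(Double.NaN);
            t.put(123456789012L, 0.25d);
            System.out.println(t.get(123456789012L)); // 0.25
            System.out.println(t.get(42L));           // NaN (missing key)
        }
    }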
[02/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75]
Support Sparse Vector Format as the input of RandomForest
Posted by my...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
index 5f8518b..d682093 100644
--- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
@@ -18,15 +18,23 @@
*/
package hivemall.smile.classification;
+import hivemall.classifier.KernelExpansionPassiveAggressiveUDTF;
+import hivemall.utils.codec.Base91;
import hivemall.utils.lang.mutable.MutableInt;
import java.io.BufferedInputStream;
+import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
+import java.io.InputStreamReader;
import java.net.URL;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;
+import java.util.StringTokenizer;
+import java.util.zip.GZIPInputStream;
+
+import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.Collector;
@@ -34,6 +42,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
@@ -43,7 +53,7 @@ import smile.data.parser.ArffParser;
public class RandomForestClassifierUDTFTest {
@Test
- public void testIris() throws IOException, ParseException, HiveException {
+ public void testIrisDense() throws IOException, ParseException, HiveException {
URL url = new URL(
"https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
InputStream is = new BufferedInputStream(url.openStream());
@@ -85,4 +95,278 @@ public class RandomForestClassifierUDTFTest {
Assert.assertEquals(49, count.getValue());
}
+ @Test
+ public void testIrisSparse() throws IOException, ParseException, HiveException {
+ URL url = new URL(
+ "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+ int[] y = iris.toArray(new int[size]);
+
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ final List<String> xi = new ArrayList<String>(x[0].length);
+ for (int i = 0; i < size; i++) {
+ double[] row = x[i];
+ for (int j = 0; j < row.length; j++) {
+ xi.add(j + ":" + row[j]);
+ }
+ udtf.process(new Object[] {xi, y[i]});
+ xi.clear();
+ }
+
+ final MutableInt count = new MutableInt(0);
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ count.addValue(1);
+ }
+ };
+
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Assert.assertEquals(49, count.getValue());
+ }
+
+ @Test
+ public void testIrisSparseDenseEquals() throws IOException, ParseException, HiveException {
+ String urlString = "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
+ DecisionTree.Node denseNode = getDecisionTreeFromDenseInput(urlString);
+ DecisionTree.Node sparseNode = getDecisionTreeFromSparseInput(urlString);
+
+ URL url = new URL(urlString);
+ InputStream is = new BufferedInputStream(url.openStream());
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+
+ int diff = 0;
+ for (int i = 0; i < size; i++) {
+ if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) {
+ diff++;
+ }
+ }
+
+ Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
+ }
+
+ private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
+ throws IOException, ParseException, HiveException {
+ URL url = new URL(urlString);
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+ int[] y = iris.toArray(new int[size]);
+
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ final List<Double> xi = new ArrayList<Double>(x[0].length);
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < x[i].length; j++) {
+ xi.add(j, x[i][j]);
+ }
+ udtf.process(new Object[] {xi, y[i]});
+ xi.clear();
+ }
+
+ final Text[] placeholder = new Text[1];
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ Object[] forward = (Object[]) input;
+ placeholder[0] = (Text) forward[2];
+ }
+ };
+
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Text modelTxt = placeholder[0];
+ Assert.assertNotNull(modelTxt);
+
+ byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
+ DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true);
+ return node;
+ }
+
+ private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString)
+ throws IOException, ParseException, HiveException {
+ URL url = new URL(urlString);
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(4);
+
+ AttributeDataset iris = arffParser.parse(is);
+ int size = iris.size();
+ double[][] x = iris.toArray(new double[size][]);
+ int[] y = iris.toArray(new int[size]);
+
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ final List<String> xi = new ArrayList<String>(x[0].length);
+ for (int i = 0; i < size; i++) {
+ final double[] row = x[i];
+ for (int j = 0; j < row.length; j++) {
+ xi.add(j + ":" + row[j]);
+ }
+ udtf.process(new Object[] {xi, y[i]});
+ xi.clear();
+ }
+
+ final Text[] placeholder = new Text[1];
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ Object[] forward = (Object[]) input;
+ placeholder[0] = (Text) forward[2];
+ }
+ };
+
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Text modelTxt = placeholder[0];
+ Assert.assertNotNull(modelTxt);
+
+ byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
+ DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true);
+ return node;
+ }
+
+ @Test
+ public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException {
+ final int numTrees = 10;
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ "-stratified_sampling -seed 71 -trees " + numTrees);
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+
+ BufferedReader news20 = readFile("news20-multiclass.gz");
+ ArrayList<String> features = new ArrayList<String>();
+ String line = news20.readLine();
+ while (line != null) {
+ StringTokenizer tokens = new StringTokenizer(line, " ");
+ int label = Integer.parseInt(tokens.nextToken());
+ while (tokens.hasMoreTokens()) {
+ features.add(tokens.nextToken());
+ }
+ Assert.assertFalse(features.isEmpty());
+ udtf.process(new Object[] {features, label});
+
+ features.clear();
+ line = news20.readLine();
+ }
+ news20.close();
+
+ final MutableInt count = new MutableInt(0);
+ final MutableInt oobErrors = new MutableInt(0);
+ final MutableInt oobTests = new MutableInt(0);
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ Object[] forward = (Object[]) input;
+ oobErrors.addValue(((IntWritable) forward[4]).get());
+ oobTests.addValue(((IntWritable) forward[5]).get());
+ count.addValue(1);
+ }
+ };
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Assert.assertEquals(numTrees, count.getValue());
+ float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
+ // TODO: investigate why multi-class classification accuracy is this poor
+ Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8);
+ }
+
+ @Test
+ public void testNews20BinarySparse() throws IOException, ParseException, HiveException {
+ final int numTrees = 10;
+ RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+ ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-seed 71 -trees "
+ + numTrees);
+ udtf.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+ BufferedReader news20 = readFile("news20-small.binary.gz");
+ ArrayList<String> features = new ArrayList<String>();
+ String line = news20.readLine();
+ while (line != null) {
+ StringTokenizer tokens = new StringTokenizer(line, " ");
+ int label = Integer.parseInt(tokens.nextToken());
+ if (label == -1) {
+ label = 0;
+ }
+ while (tokens.hasMoreTokens()) {
+ features.add(tokens.nextToken());
+ }
+ if (!features.isEmpty()) {
+ udtf.process(new Object[] {features, label});
+ features.clear();
+ }
+ line = news20.readLine();
+ }
+ news20.close();
+
+ final MutableInt count = new MutableInt(0);
+ final MutableInt oobErrors = new MutableInt(0);
+ final MutableInt oobTests = new MutableInt(0);
+ Collector collector = new Collector() {
+ public void collect(Object input) throws HiveException {
+ Object[] forward = (Object[]) input;
+ oobErrors.addValue(((IntWritable) forward[4]).get());
+ oobTests.addValue(((IntWritable) forward[5]).get());
+ count.addValue(1);
+ }
+ };
+ udtf.setCollector(collector);
+ udtf.close();
+
+ Assert.assertEquals(numTrees, count.getValue());
+ float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
+ Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.3);
+ }
+
+ @Nonnull
+ private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+ InputStream is = KernelExpansionPassiveAggressiveUDTF.class.getResourceAsStream(fileName);
+ if (fileName.endsWith(".gz")) {
+ is = new GZIPInputStream(is);
+ }
+ return new BufferedReader(new InputStreamReader(is));
+ }
}
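For context on the feature encoding exercised above: RandomForestClassifierUDTF consumes each feature as an "index:value" string, so the dense-format helper emits every column ("j:value") while the news20 tests pass the already-sparse tokens straight through. A minimal, self-contained sketch of converting a dense row while dropping zeros (plain Java; the class and method names here are illustrative only, not part of this patch):

    import java.util.ArrayList;
    import java.util.List;

    public final class SparseRowSketch {
        // Encode only the non-zero entries of a dense row as "index:value" strings.
        static List<String> toSparseFeatures(final double[] row) {
            final List<String> features = new ArrayList<String>(row.length);
            for (int j = 0; j < row.length; j++) {
                if (row[j] != 0.d) {
                    features.add(j + ":" + row[j]);
                }
            }
            return features;
        }

        public static void main(String[] args) {
            System.out.println(toSparseFeatures(new double[] {0.d, 1.5d, 0.d, 2.d}));
            // prints [1:1.5, 3:2.0]
        }
    }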
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
index 20f44b3..eae625d 100644
--- a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
@@ -18,7 +18,16 @@
*/
package hivemall.smile.regression;
+import hivemall.math.matrix.Matrix;
+import hivemall.math.matrix.builders.CSRMatrixBuilder;
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
+import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.smile.data.Attribute;
+import hivemall.smile.data.Attribute.NumericAttribute;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.junit.Assert;
@@ -30,7 +39,7 @@ import smile.validation.LOOCV;
public class RegressionTreeTest {
@Test
- public void testPredict() {
+ public void testPredictDense() {
double[][] longley = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323},
{259.426, 232.5, 145.6, 108.632, 1948, 61.122},
@@ -53,10 +62,51 @@ public class RegressionTreeTest {
112.6, 114.2, 115.7, 116.9};
Attribute[] attrs = new Attribute[longley[0].length];
- for (int i = 0; i < attrs.length; i++) {
- attrs[i] = new Attribute.NumericAttribute(i);
+ Arrays.fill(attrs, new NumericAttribute());
+
+ int n = longley.length;
+ LOOCV loocv = new LOOCV(n);
+ double rss = 0.0;
+ for (int i = 0; i < n; i++) {
+ double[][] trainx = Math.slice(longley, loocv.train[i]);
+ double[] trainy = Math.slice(y, loocv.train[i]);
+ int maxLeafs = 10;
+ RegressionTree tree = new RegressionTree(attrs, matrix(trainx, true), trainy, maxLeafs,
+ RandomNumberGeneratorFactory.createPRNG(i));
+
+ double r = y[loocv.test[i]] - tree.predict(longley[loocv.test[i]]);
+ rss += r * r;
}
+ Assert.assertTrue("MSE = " + (rss / n), (rss / n) < 42);
+ }
+
+ @Test
+ public void testPredictSparse() {
+ double[][] longley = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+ {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+ {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+ {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+ {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+ {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+ {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+ {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+ {397.469, 290.4, 304.8, 117.388, 1955, 66.019},
+ {419.180, 282.2, 285.7, 118.734, 1956, 67.857},
+ {442.769, 293.6, 279.8, 120.445, 1957, 68.169},
+ {444.546, 468.1, 263.7, 121.950, 1958, 66.513},
+ {482.704, 381.3, 255.2, 123.366, 1959, 68.655},
+ {502.601, 393.1, 251.4, 125.368, 1960, 69.564},
+ {518.173, 480.6, 257.2, 127.852, 1961, 69.331},
+ {554.894, 400.7, 282.7, 130.081, 1962, 70.551}};
+
+ double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
+ 112.6, 114.2, 115.7, 116.9};
+
+ Attribute[] attrs = new Attribute[longley[0].length];
+ Arrays.fill(attrs, new NumericAttribute());
+
int n = longley.length;
LOOCV loocv = new LOOCV(n);
double rss = 0.0;
@@ -64,8 +114,8 @@ public class RegressionTreeTest {
double[][] trainx = Math.slice(longley, loocv.train[i]);
double[] trainy = Math.slice(y, loocv.train[i]);
int maxLeafs = 10;
- smile.math.Random rand = new smile.math.Random(i);
- RegressionTree tree = new RegressionTree(attrs, trainx, trainy, maxLeafs, rand);
+ RegressionTree tree = new RegressionTree(attrs, matrix(trainx, false), trainy,
+ maxLeafs, RandomNumberGeneratorFactory.createPRNG(i));
double r = y[loocv.test[i]] - tree.predict(longley[loocv.test[i]]);
rss += r * r;
@@ -98,9 +148,7 @@ public class RegressionTreeTest {
112.6, 114.2, 115.7, 116.9};
Attribute[] attrs = new Attribute[longley[0].length];
- for (int i = 0; i < attrs.length; i++) {
- attrs[i] = new Attribute.NumericAttribute(i);
- }
+ Arrays.fill(attrs, new NumericAttribute());
int n = longley.length;
LOOCV loocv = new LOOCV(n);
@@ -108,7 +156,7 @@ public class RegressionTreeTest {
double[][] trainx = Math.slice(longley, loocv.train[i]);
double[] trainy = Math.slice(y, loocv.train[i]);
int maxLeafs = Integer.MAX_VALUE;
- RegressionTree tree = new RegressionTree(attrs, trainx, trainy, maxLeafs);
+ RegressionTree tree = new RegressionTree(attrs, matrix(trainx, true), trainy, maxLeafs);
byte[] b = tree.predictSerCodegen(true);
RegressionTree.Node node = RegressionTree.deserializeNode(b, b.length, true);
@@ -119,4 +167,19 @@ public class RegressionTreeTest {
Assert.assertEquals(expected, actual, 0.d);
}
}
+
+ @Nonnull
+ private static Matrix matrix(@Nonnull final double[][] x, boolean dense) {
+ if (dense) {
+ return new RowMajorDenseMatrix2d(x, x[0].length);
+ } else {
+ int numRows = x.length;
+ CSRMatrixBuilder builder = new CSRMatrixBuilder(1024);
+ for (int i = 0; i < numRows; i++) {
+ builder.nextRow(x[i]);
+ }
+ return builder.buildMatrix();
+ }
+ }
+
}
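For background on the sparse path taken by matrix(trainx, false): a CSR (Compressed Sparse Row) matrix stores only the non-zero values together with their column indices and per-row offsets. A small, self-contained illustration of that layout (plain Java, independent of Hivemall's CSRMatrixBuilder, which this sketch does not reimplement):

    import java.util.Arrays;

    public final class CsrLayoutSketch {
        public static void main(String[] args) {
            final double[][] x = { {1.0, 0.0, 2.0}, {0.0, 3.0, 0.0} };
            final int numRows = x.length;
            int nnz = 0;
            for (double[] row : x) {
                for (double v : row) {
                    if (v != 0.d) {
                        nnz++;
                    }
                }
            }
            final double[] values = new double[nnz];        // non-zero values, row by row
            final int[] columnIndices = new int[nnz];       // column of each stored value
            final int[] rowPointers = new int[numRows + 1]; // start offset of each row
            int pos = 0;
            for (int i = 0; i < numRows; i++) {
                rowPointers[i] = pos;
                for (int j = 0; j < x[i].length; j++) {
                    if (x[i][j] != 0.d) {
                        values[pos] = x[i][j];
                        columnIndices[pos] = j;
                        pos++;
                    }
                }
            }
            rowPointers[numRows] = pos;
            System.out.println(Arrays.toString(values));        // [1.0, 2.0, 3.0]
            System.out.println(Arrays.toString(columnIndices)); // [0, 2, 1]
            System.out.println(Arrays.toString(rowPointers));   // [0, 2, 3]
        }
    }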
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
index 504ea86..65feeeb 100644
--- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
+++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
@@ -18,13 +18,12 @@
*/
package hivemall.smile.tools;
-import static org.junit.Assert.assertEquals;
-import hivemall.smile.ModelType;
+import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.data.Attribute;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
-import hivemall.smile.vm.StackMachine;
+import hivemall.utils.codec.Base91;
import hivemall.utils.lang.ArrayUtils;
import java.io.BufferedInputStream;
@@ -42,6 +41,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.junit.Assert;
import org.junit.Test;
import smile.data.AttributeDataset;
@@ -49,7 +50,7 @@ import smile.data.parser.ArffParser;
import smile.math.Math;
import smile.validation.CrossValidation;
import smile.validation.LOOCV;
-import smile.validation.Validation;
+import smile.validation.RMSE;
public class TreePredictUDFTest {
private static final boolean DEBUG = false;
@@ -76,8 +77,9 @@ public class TreePredictUDFTest {
int[] trainy = Math.slice(y, loocv.train[i]);
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
- DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4);
- assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]]));
+ DecisionTree tree = new DecisionTree(attrs, new RowMajorDenseMatrix2d(trainx,
+ x[0].length), trainy, 4);
+ Assert.assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]]));
}
}
@@ -103,10 +105,11 @@ public class TreePredictUDFTest {
double[][] testx = Math.slice(datax, cv.test[i]);
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
- RegressionTree tree = new RegressionTree(attrs, trainx, trainy, 20);
+ RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx,
+ trainx[0].length), trainy, 20);
for (int j = 0; j < testx.length; j++) {
- assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0);
+ Assert.assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0);
}
}
}
@@ -142,52 +145,60 @@ public class TreePredictUDFTest {
}
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
- RegressionTree tree = new RegressionTree(attrs, trainx, trainy, 20);
- debugPrint(String.format("RMSE = %.4f\n", Validation.test(tree, testx, testy)));
+ RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx,
+ trainx[0].length), trainy, 20);
+ debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));
for (int i = m; i < n; i++) {
- assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0);
+ Assert.assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0);
}
}
+ private static double rmse(RegressionTree regression, double[][] x, double[] y) {
+ final int n = x.length;
+ final double[] predictions = new double[n];
+ for (int i = 0; i < n; i++) {
+ predictions[i] = regression.predict(x[i]);
+ }
+ return new RMSE().measure(y, predictions);
+ }
+
private static int evalPredict(DecisionTree tree, double[] x) throws HiveException, IOException {
- String opScript = tree.predictOpCodegen(StackMachine.SEP);
- debugPrint(opScript);
+ byte[] b = tree.predictSerCodegen(true);
+ byte[] encoded = Base91.encode(b);
+ Text model = new Text(encoded);
TreePredictUDF udf = new TreePredictUDF();
udf.initialize(new ObjectInspector[] {
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- PrimitiveObjectInspectorFactory.javaIntObjectInspector,
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ PrimitiveObjectInspectorFactory.writableStringObjectInspector,
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
ObjectInspectorUtils.getConstantObjectInspector(
PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)});
DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"),
- new DeferredJavaObject(ModelType.opscode.getId()),
- new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)),
+ new DeferredJavaObject(model), new DeferredJavaObject(ArrayUtils.toList(x)),
new DeferredJavaObject(true)};
- IntWritable result = (IntWritable) udf.evaluate(arguments);
+ Object[] result = (Object[]) udf.evaluate(arguments);
udf.close();
- return result.get();
+ return ((IntWritable) result[0]).get();
}
private static double evalPredict(RegressionTree tree, double[] x) throws HiveException,
IOException {
- String opScript = tree.predictOpCodegen(StackMachine.SEP);
- debugPrint(opScript);
+ byte[] b = tree.predictSerCodegen(true);
+ byte[] encoded = Base91.encode(b);
+ Text model = new Text(encoded);
TreePredictUDF udf = new TreePredictUDF();
udf.initialize(new ObjectInspector[] {
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- PrimitiveObjectInspectorFactory.javaIntObjectInspector,
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ PrimitiveObjectInspectorFactory.writableStringObjectInspector,
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
ObjectInspectorUtils.getConstantObjectInspector(
PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)});
DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"),
- new DeferredJavaObject(ModelType.opscode.getId()),
- new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)),
+ new DeferredJavaObject(model), new DeferredJavaObject(ArrayUtils.toList(x)),
new DeferredJavaObject(false)};
DoubleWritable result = (DoubleWritable) udf.evaluate(arguments);
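The rewritten evalPredict helpers above replace the old opcode-script model (StackMachine) with a serialized tree that is Base91-encoded into a Hadoop Text. Condensed into one helper, the round trip uses only calls that already appear in this patch (a sketch for readability, not part of the diff; imports as in the test above):

    // Serialize a trained tree, encode it for transport as Text, then rebuild the node.
    static DecisionTree.Node roundTrip(DecisionTree tree) throws HiveException, IOException {
        byte[] serialized = tree.predictSerCodegen(true);
        Text model = new Text(Base91.encode(serialized));
        byte[] decoded = Base91.decode(model.getBytes(), 0, model.getLength());
        return DecisionTree.deserializeNode(decoded, decoded.length, true);
    }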
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/vm/StackMachineTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/vm/StackMachineTest.java b/core/src/test/java/hivemall/smile/vm/StackMachineTest.java
deleted file mode 100644
index 4a2dcd8..0000000
--- a/core/src/test/java/hivemall/smile/vm/StackMachineTest.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile.vm;
-
-import static org.junit.Assert.assertEquals;
-import hivemall.utils.io.IOUtils;
-
-import java.io.BufferedInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.net.URL;
-import java.text.ParseException;
-import java.util.ArrayList;
-
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.junit.Assert;
-import org.junit.Test;
-
-public class StackMachineTest {
- private static final boolean DEBUG = false;
-
- @Test
- public void testFindInfinteLoop() throws IOException, ParseException, HiveException,
- VMRuntimeException {
- // Sample of machine code having infinite loop
- ArrayList<String> opScript = new ArrayList<String>();
- opScript.add("push 2.0");
- opScript.add("push 1.0");
- opScript.add("iflt 0");
- opScript.add("push 1");
- opScript.add("call end");
- debugPrint(opScript);
- double[] x = new double[0];
- StackMachine sm = new StackMachine();
- try {
- sm.run(opScript, x);
- Assert.fail("VMRuntimeException is expected");
- } catch (VMRuntimeException ex) {
- assertEquals("There is a infinite loop in the Machine code.", ex.getMessage());
- }
- }
-
- @Test
- public void testLargeOpcodes() throws IOException, ParseException, HiveException,
- VMRuntimeException {
- URL url = new URL(
- "https://gist.githubusercontent.com/myui/b1a8e588f5750e3b658c/raw/a4074d37400dab2b13a2f43d81f5166188d3461a/vmtest01.txt");
- InputStream is = new BufferedInputStream(url.openStream());
- String opScript = IOUtils.toString(is);
-
- StackMachine sm = new StackMachine();
- sm.compile(opScript);
-
- double[] x1 = new double[] {36, 2, 1, 2, 0, 436, 1, 0, 0, 13, 0, 567, 1, 595, 2, 1};
- sm.eval(x1);
- assertEquals(0.d, sm.getResult().doubleValue(), 0d);
-
- double[] x2 = {31, 2, 1, 2, 0, 354, 1, 0, 0, 30, 0, 502, 1, 9, 2, 2};
- sm.eval(x2);
- assertEquals(1.d, sm.getResult().doubleValue(), 0d);
-
- double[] x3 = {39, 0, 0, 0, 0, 1756, 0, 0, 0, 3, 0, 939, 1, 0, 0, 0};
- sm.eval(x3);
- assertEquals(0.d, sm.getResult().doubleValue(), 0d);
- }
-
- private static void debugPrint(Object msg) {
- if (DEBUG) {
- System.out.println(msg);
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java b/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java
deleted file mode 100644
index 177a345..0000000
--- a/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.util.Random;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class DoubleArray3DTest {
-
- @Test
- public void test() {
- final int size_i = 50, size_j = 50, size_k = 5;
-
- final DoubleArray3D mdarray = new DoubleArray3D();
- mdarray.configure(size_i, size_j, size_k);
-
- final Random rand = new Random(31L);
- final double[][][] data = new double[size_i][size_j][size_j];
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- data[i][j][k] = v;
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
-
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
- }
- }
- }
- }
-
- @Test
- public void testConfigureExpand() {
- int size_i = 50, size_j = 50, size_k = 5;
-
- final DoubleArray3D mdarray = new DoubleArray3D();
- mdarray.configure(size_i, size_j, size_k);
-
- final Random rand = new Random(31L);
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- size_i = 101;
- size_j = 101;
- size_k = 11;
- mdarray.configure(size_i, size_j, size_k);
- Assert.assertEquals(size_i * size_j * size_k, mdarray.getCapacity());
- Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
-
- final double[][][] data = new double[size_i][size_j][size_j];
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- data[i][j][k] = v;
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
- }
- }
- }
- }
-
- @Test
- public void testConfigureShrink() {
- int size_i = 50, size_j = 50, size_k = 5;
-
- final DoubleArray3D mdarray = new DoubleArray3D();
- mdarray.configure(size_i, size_j, size_k);
-
- final Random rand = new Random(31L);
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- int capacity = mdarray.getCapacity();
- size_i = 49;
- size_j = 49;
- size_k = 4;
- mdarray.configure(size_i, size_j, size_k);
- Assert.assertEquals(capacity, mdarray.getCapacity());
- Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
-
- final double[][][] data = new double[size_i][size_j][size_j];
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- double v = rand.nextDouble();
- data[i][j][k] = v;
- mdarray.set(i, j, k, v);
- }
- }
- }
-
- for (int i = 0; i < size_i; i++) {
- for (int j = 0; j < size_j; j++) {
- for (int k = 0; k < size_k; k++) {
- Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
- }
- }
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java b/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java
deleted file mode 100644
index 72e76e8..0000000
--- a/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class DoubleArrayTest {
-
- @Test
- public void testSparseDoubleArrayToArray() {
- SparseDoubleArray array = new SparseDoubleArray(3);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- Assert.assertEquals(10, array.size());
- Assert.assertEquals(10, array.toArray(false).length);
-
- double[] copied = array.toArray(true);
- Assert.assertEquals(10, copied.length);
- for (int i = 0; i < 10; i++) {
- Assert.assertEquals(10 + i, copied[i], 0.d);
- }
- }
-
- @Test
- public void testSparseDoubleArrayClear() {
- SparseDoubleArray array = new SparseDoubleArray(3);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- array.clear();
- Assert.assertEquals(0, array.size());
- Assert.assertEquals(0, array.get(0), 0.d);
- for (int i = 0; i < 5; i++) {
- array.put(i, 100 + i);
- }
- Assert.assertEquals(5, array.size());
- for (int i = 0; i < 5; i++) {
- Assert.assertEquals(100 + i, array.get(i), 0.d);
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java
deleted file mode 100644
index 8a8a68d..0000000
--- a/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class Int2FloatOpenHashMapTest {
-
- @Test
- public void testSize() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
- map.put(1, 3.f);
- Assert.assertEquals(3.f, map.get(1), 0.d);
- map.put(1, 5.f);
- Assert.assertEquals(5.f, map.get(1), 0.d);
- Assert.assertEquals(1, map.size());
- }
-
- @Test
- public void testDefaultReturnValue() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
- Assert.assertEquals(0, map.size());
- Assert.assertEquals(-1.f, map.get(1), 0.d);
- float ret = Float.MIN_VALUE;
- map.defaultReturnValue(ret);
- Assert.assertEquals(ret, map.get(1), 0.d);
- }
-
- @Test
- public void testPutAndGet() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d);
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Float v = map.get(i);
- Assert.assertEquals(i + 0.1f, v.floatValue(), 0.d);
- }
- }
-
- @Test
- public void testIterator() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(1000);
- Int2FloatOpenHashTable.IMapIterator itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d);
- }
- Assert.assertEquals(numEntries, map.size());
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- int k = itor.getKey();
- Float v = itor.getValue();
- Assert.assertEquals(k + 0.1f, v.floatValue(), 0.d);
- }
- Assert.assertEquals(-1, itor.next());
- }
-
- @Test
- public void testIterator2() {
- Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(100);
- map.put(33, 3.16f);
-
- Int2FloatOpenHashTable.IMapIterator itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- Assert.assertNotEquals(-1, itor.next());
- Assert.assertEquals(33, itor.getKey());
- Assert.assertEquals(3.16f, itor.getValue(), 0.d);
- Assert.assertEquals(-1, itor.next());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java
deleted file mode 100644
index 1186bdf..0000000
--- a/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.ObjectUtils;
-
-import java.io.IOException;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class Int2LongOpenHashMapTest {
-
- @Test
- public void testSize() {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
- map.put(1, 3L);
- Assert.assertEquals(3L, map.get(1));
- map.put(1, 5L);
- Assert.assertEquals(5L, map.get(1));
- Assert.assertEquals(1, map.size());
- }
-
- @Test
- public void testDefaultReturnValue() {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
- Assert.assertEquals(0, map.size());
- Assert.assertEquals(-1L, map.get(1));
- long ret = Long.MIN_VALUE;
- map.defaultReturnValue(ret);
- Assert.assertEquals(ret, map.get(1));
- }
-
- @Test
- public void testPutAndGet() {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1L, map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- long v = map.get(i);
- Assert.assertEquals(i, v);
- }
- }
-
- @Test
- public void testSerde() throws IOException, ClassNotFoundException {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1L, map.put(i, i));
- }
-
- byte[] b = ObjectUtils.toCompressedBytes(map);
- map = new Int2LongOpenHashTable(16384);
- ObjectUtils.readCompressedObject(b, map);
-
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- long v = map.get(i);
- Assert.assertEquals(i, v);
- }
- }
-
- @Test
- public void testIterator() {
- Int2LongOpenHashTable map = new Int2LongOpenHashTable(1000);
- Int2LongOpenHashTable.IMapIterator itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertEquals(-1L, map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- int k = itor.getKey();
- long v = itor.getValue();
- Assert.assertEquals(k, v);
- }
- Assert.assertEquals(-1, itor.next());
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/IntArrayTest.java b/core/src/test/java/hivemall/utils/collections/IntArrayTest.java
deleted file mode 100644
index 42852ea..0000000
--- a/core/src/test/java/hivemall/utils/collections/IntArrayTest.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class IntArrayTest {
-
- @Test
- public void testFixedIntArrayToArray() {
- FixedIntArray array = new FixedIntArray(11);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- Assert.assertEquals(11, array.size());
- Assert.assertEquals(11, array.toArray(false).length);
-
- int[] copied = array.toArray(true);
- Assert.assertEquals(11, copied.length);
- for (int i = 0; i < 10; i++) {
- Assert.assertEquals(10 + i, copied[i]);
- }
- }
-
- @Test
- public void testSparseIntArrayToArray() {
- SparseIntArray array = new SparseIntArray(3);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- Assert.assertEquals(10, array.size());
- Assert.assertEquals(10, array.toArray(false).length);
-
- int[] copied = array.toArray(true);
- Assert.assertEquals(10, copied.length);
- for (int i = 0; i < 10; i++) {
- Assert.assertEquals(10 + i, copied[i]);
- }
- }
-
- @Test
- public void testSparseIntArrayClear() {
- SparseIntArray array = new SparseIntArray(3);
- for (int i = 0; i < 10; i++) {
- array.put(i, 10 + i);
- }
- array.clear();
- Assert.assertEquals(0, array.size());
- Assert.assertEquals(0, array.get(0));
- for (int i = 0; i < 5; i++) {
- array.put(i, 100 + i);
- }
- Assert.assertEquals(5, array.size());
- for (int i = 0; i < 5; i++) {
- Assert.assertEquals(100 + i, array.get(i));
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java
deleted file mode 100644
index 29a5a81..0000000
--- a/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class IntOpenHashMapTest {
-
- @Test
- public void testSize() {
- IntOpenHashMap<Float> map = new IntOpenHashMap<Float>(16384);
- map.put(1, Float.valueOf(3.f));
- Assert.assertEquals(Float.valueOf(3.f), map.get(1));
- map.put(1, Float.valueOf(5.f));
- Assert.assertEquals(Float.valueOf(5.f), map.get(1));
- Assert.assertEquals(1, map.size());
- }
-
- @Test
- public void testPutAndGet() {
- IntOpenHashMap<Integer> map = new IntOpenHashMap<Integer>(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertNull(map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Integer v = map.get(i);
- Assert.assertEquals(i, v.intValue());
- }
- }
-
- @Test
- public void testIterator() {
- IntOpenHashMap<Integer> map = new IntOpenHashMap<Integer>(1000);
- IntOpenHashMap.IMapIterator<Integer> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertNull(map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- int k = itor.getKey();
- Integer v = itor.getValue();
- Assert.assertEquals(k, v.intValue());
- }
- Assert.assertEquals(-1, itor.next());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java
deleted file mode 100644
index 3babb3d..0000000
--- a/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class IntOpenHashTableTest {
-
- @Test
- public void testSize() {
- IntOpenHashTable<Float> map = new IntOpenHashTable<Float>(16384);
- map.put(1, Float.valueOf(3.f));
- Assert.assertEquals(Float.valueOf(3.f), map.get(1));
- map.put(1, Float.valueOf(5.f));
- Assert.assertEquals(Float.valueOf(5.f), map.get(1));
- Assert.assertEquals(1, map.size());
- }
-
- @Test
- public void testPutAndGet() {
- IntOpenHashTable<Integer> map = new IntOpenHashTable<Integer>(16384);
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- Assert.assertNull(map.put(i, i));
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Integer v = map.get(i);
- Assert.assertEquals(i, v.intValue());
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java
deleted file mode 100644
index e3cc018..0000000
--- a/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.mutable.MutableInt;
-
-import java.util.Map;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class OpenHashMapTest {
-
- @Test
- public void testPutAndGet() {
- Map<Object, Object> map = new OpenHashMap<Object, Object>(16384);
- final int numEntries = 5000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Object v = map.get(Integer.toString(i));
- Assert.assertEquals(i, v);
- }
- map.put(Integer.toString(1), Integer.MAX_VALUE);
- Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
- Assert.assertEquals(numEntries, map.size());
- }
-
- @Test
- public void testIterator() {
- OpenHashMap<String, Integer> map = new OpenHashMap<String, Integer>(1000);
- IMapIterator<String, Integer> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- String k = itor.getKey();
- Integer v = itor.getValue();
- Assert.assertEquals(Integer.valueOf(k), v);
- }
- Assert.assertEquals(-1, itor.next());
- }
-
- @Test
- public void testIteratorGetProbe() {
- OpenHashMap<String, MutableInt> map = new OpenHashMap<String, MutableInt>(100);
- IMapIterator<String, MutableInt> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), new MutableInt(i));
- }
-
- final MutableInt probe = new MutableInt();
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- String k = itor.getKey();
- itor.getValue(probe);
- Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue());
- }
- Assert.assertEquals(-1, itor.next());
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java
deleted file mode 100644
index d5a465c..0000000
--- a/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.ObjectUtils;
-import hivemall.utils.lang.mutable.MutableInt;
-
-import java.io.IOException;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class OpenHashTableTest {
-
- @Test
- public void testPutAndGet() {
- OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
- final int numEntries = 5000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Object v = map.get(Integer.toString(i));
- Assert.assertEquals(i, v);
- }
- map.put(Integer.toString(1), Integer.MAX_VALUE);
- Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
- Assert.assertEquals(numEntries, map.size());
- }
-
- @Test
- public void testIterator() {
- OpenHashTable<String, Integer> map = new OpenHashTable<String, Integer>(1000);
- IMapIterator<String, Integer> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
-
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- String k = itor.getKey();
- Integer v = itor.getValue();
- Assert.assertEquals(Integer.valueOf(k), v);
- }
- Assert.assertEquals(-1, itor.next());
- }
-
- @Test
- public void testIteratorGetProbe() {
- OpenHashTable<String, MutableInt> map = new OpenHashTable<String, MutableInt>(100);
- IMapIterator<String, MutableInt> itor = map.entries();
- Assert.assertFalse(itor.hasNext());
-
- final int numEntries = 1000000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), new MutableInt(i));
- }
-
- final MutableInt probe = new MutableInt();
- itor = map.entries();
- Assert.assertTrue(itor.hasNext());
- while (itor.hasNext()) {
- Assert.assertFalse(itor.next() == -1);
- String k = itor.getKey();
- itor.getValue(probe);
- Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue());
- }
- Assert.assertEquals(-1, itor.next());
- }
-
- @Test
- public void testSerDe() throws IOException, ClassNotFoundException {
- OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
- final int numEntries = 100000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
-
- byte[] serialized = ObjectUtils.toBytes(map);
- map = new OpenHashTable<Object, Object>();
- ObjectUtils.readObject(serialized, map);
-
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Object v = map.get(Integer.toString(i));
- Assert.assertEquals(i, v);
- }
- map.put(Integer.toString(1), Integer.MAX_VALUE);
- Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
- Assert.assertEquals(numEntries, map.size());
- }
-
-
- @Test
- public void testCompressedSerDe() throws IOException, ClassNotFoundException {
- OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384);
- final int numEntries = 100000;
- for (int i = 0; i < numEntries; i++) {
- map.put(Integer.toString(i), i);
- }
-
- byte[] serialized = ObjectUtils.toCompressedBytes(map);
- map = new OpenHashTable<Object, Object>();
- ObjectUtils.readCompressedObject(serialized, map);
-
- Assert.assertEquals(numEntries, map.size());
- for (int i = 0; i < numEntries; i++) {
- Object v = map.get(Integer.toString(i));
- Assert.assertEquals(i, v);
- }
- map.put(Integer.toString(1), Integer.MAX_VALUE);
- Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1)));
- Assert.assertEquals(numEntries, map.size());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java b/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java
deleted file mode 100644
index 68d0f6d..0000000
--- a/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.util.Random;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class SparseIntArrayTest {
-
- @Test
- public void testDense() {
- int size = 1000;
- Random rand = new Random(31);
- int[] expected = new int[size];
- IntArray actual = new SparseIntArray(10);
- for (int i = 0; i < size; i++) {
- int r = rand.nextInt(size);
- expected[i] = r;
- actual.put(i, r);
- }
- for (int i = 0; i < size; i++) {
- Assert.assertEquals(expected[i], actual.get(i));
- }
- }
-
- @Test
- public void testSparse() {
- int size = 1000;
- Random rand = new Random(31);
- int[] expected = new int[size];
- SparseIntArray actual = new SparseIntArray(10);
- for (int i = 0; i < size; i++) {
- int key = rand.nextInt(size);
- int v = rand.nextInt();
- expected[key] = v;
- actual.put(key, v);
- }
- for (int i = 0; i < actual.size(); i++) {
- int key = actual.keyAt(i);
- Assert.assertEquals(expected[key], actual.get(key, 0));
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java
new file mode 100644
index 0000000..4fdb43e
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DoubleArray3DTest {
+
+ @Test
+ public void test() {
+ final int size_i = 50, size_j = 50, size_k = 5;
+
+ final DoubleArray3D mdarray = new DoubleArray3D();
+ mdarray.configure(size_i, size_j, size_k);
+
+ final Random rand = new Random(31L);
+ final double[][][] data = new double[size_i][size_j][size_k];
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ data[i][j][k] = v;
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
+
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testConfigureExpand() {
+ int size_i = 50, size_j = 50, size_k = 5;
+
+ final DoubleArray3D mdarray = new DoubleArray3D();
+ mdarray.configure(size_i, size_j, size_k);
+
+ final Random rand = new Random(31L);
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ size_i = 101;
+ size_j = 101;
+ size_k = 11;
+ mdarray.configure(size_i, size_j, size_k);
+ Assert.assertEquals(size_i * size_j * size_k, mdarray.getCapacity());
+ Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
+
+ final double[][][] data = new double[size_i][size_j][size_k];
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ data[i][j][k] = v;
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testConfigureShrink() {
+ int size_i = 50, size_j = 50, size_k = 5;
+
+ final DoubleArray3D mdarray = new DoubleArray3D();
+ mdarray.configure(size_i, size_j, size_k);
+
+ final Random rand = new Random(31L);
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ int capacity = mdarray.getCapacity();
+ size_i = 49;
+ size_j = 49;
+ size_k = 4;
+ mdarray.configure(size_i, size_j, size_k);
+ Assert.assertEquals(capacity, mdarray.getCapacity());
+ Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize());
+
+ final double[][][] data = new double[size_i][size_j][size_k];
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ double v = rand.nextDouble();
+ data[i][j][k] = v;
+ mdarray.set(i, j, k, v);
+ }
+ }
+ }
+
+ for (int i = 0; i < size_i; i++) {
+ for (int j = 0; j < size_j; j++) {
+ for (int k = 0; k < size_k; k++) {
+ Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d);
+ }
+ }
+ }
+ }
+
+}
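The configure/expand/shrink tests above treat DoubleArray3D as a flattened 3-D array indexed by (i, j, k). A typical row-major flattening, stated here as an assumption about the internal layout for illustration and not something this patch asserts, maps the triple to a single offset:

    public final class Flatten3DSketch {
        // Row-major offset of (i, j, k) in an array of size_i x size_j x size_k.
        static int offset(int i, int j, int k, int sizeJ, int sizeK) {
            return (i * sizeJ + j) * sizeK + k;
        }

        public static void main(String[] args) {
            // With sizeJ = 50 and sizeK = 5, (1, 0, 0) starts the second i-slab.
            System.out.println(offset(1, 0, 0, 50, 5)); // 250
        }
    }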
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java
new file mode 100644
index 0000000..ab52717
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DoubleArrayTest {
+
+ @Test
+ public void testSparseDoubleArrayToArray() {
+ SparseDoubleArray array = new SparseDoubleArray(3);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ Assert.assertEquals(10, array.size());
+ Assert.assertEquals(10, array.toArray(false).length);
+
+ double[] copied = array.toArray(true);
+ Assert.assertEquals(10, copied.length);
+ for (int i = 0; i < 10; i++) {
+ Assert.assertEquals(10 + i, copied[i], 0.d);
+ }
+ }
+
+ @Test
+ public void testSparseDoubleArrayClear() {
+ SparseDoubleArray array = new SparseDoubleArray(3);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ array.clear();
+ Assert.assertEquals(0, array.size());
+ Assert.assertEquals(0, array.get(0), 0.d);
+ for (int i = 0; i < 5; i++) {
+ array.put(i, 100 + i);
+ }
+ Assert.assertEquals(5, array.size());
+ for (int i = 0; i < 5; i++) {
+ Assert.assertEquals(100 + i, array.get(i), 0.d);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
new file mode 100644
index 0000000..0ce3912
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.collections.arrays.DenseIntArray;
+import hivemall.utils.collections.arrays.SparseIntArray;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class IntArrayTest {
+
+ @Test
+ public void testFixedIntArrayToArray() {
+ DenseIntArray array = new DenseIntArray(11);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ Assert.assertEquals(11, array.size());
+ Assert.assertEquals(11, array.toArray(false).length);
+
+ int[] copied = array.toArray(true);
+ Assert.assertEquals(11, copied.length);
+ for (int i = 0; i < 10; i++) {
+ Assert.assertEquals(10 + i, copied[i]);
+ }
+ }
+
+ @Test
+ public void testSparseIntArrayToArray() {
+ SparseIntArray array = new SparseIntArray(3);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ Assert.assertEquals(10, array.size());
+ Assert.assertEquals(10, array.toArray(false).length);
+
+ int[] copied = array.toArray(true);
+ Assert.assertEquals(10, copied.length);
+ for (int i = 0; i < 10; i++) {
+ Assert.assertEquals(10 + i, copied[i]);
+ }
+ }
+
+ @Test
+ public void testSparseIntArrayClear() {
+ SparseIntArray array = new SparseIntArray(3);
+ for (int i = 0; i < 10; i++) {
+ array.put(i, 10 + i);
+ }
+ array.clear();
+ Assert.assertEquals(0, array.size());
+ Assert.assertEquals(0, array.get(0));
+ for (int i = 0; i < 5; i++) {
+ array.put(i, 100 + i);
+ }
+ Assert.assertEquals(5, array.size());
+ for (int i = 0; i < 5; i++) {
+ Assert.assertEquals(100 + i, array.get(i));
+ }
+ }
+
+}
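The assertions above expose a behavioral difference between the two IntArray implementations: DenseIntArray reports its fixed capacity as its size (11, even though only indices 0..9 were written), whereas SparseIntArray reports the number of entries actually put (10, despite an initial capacity of 3). A short sketch of the consequence for callers:

    DenseIntArray dense = new DenseIntArray(11);
    SparseIntArray sparse = new SparseIntArray(3);
    for (int i = 0; i < 10; i++) {
        dense.put(i, i);
        sparse.put(i, i);
    }
    // dense.size() == 11 (capacity), sparse.size() == 10 (entry count);
    // size() is therefore not interchangeable across IntArray implementations.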
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
new file mode 100644
index 0000000..db3c8eb
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.arrays;
+
+import hivemall.utils.collections.arrays.IntArray;
+import hivemall.utils.collections.arrays.SparseIntArray;
+
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class SparseIntArrayTest {
+
+ @Test
+ public void testDense() {
+ int size = 1000;
+ Random rand = new Random(31);
+ int[] expected = new int[size];
+ IntArray actual = new SparseIntArray(10);
+ for (int i = 0; i < size; i++) {
+ int r = rand.nextInt(size);
+ expected[i] = r;
+ actual.put(i, r);
+ }
+ for (int i = 0; i < size; i++) {
+ Assert.assertEquals(expected[i], actual.get(i));
+ }
+ }
+
+ @Test
+ public void testSparse() {
+ int size = 1000;
+ Random rand = new Random(31);
+ int[] expected = new int[size];
+ SparseIntArray actual = new SparseIntArray(10);
+ for (int i = 0; i < size; i++) {
+ int key = rand.nextInt(size);
+ int v = rand.nextInt();
+ expected[key] = v;
+ actual.put(key, v);
+ }
+ for (int i = 0; i < actual.size(); i++) {
+ int key = actual.keyAt(i);
+ Assert.assertEquals(expected[key], actual.get(key, 0));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java b/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java
new file mode 100644
index 0000000..c40ea7e
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.lists;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class LongArrayListTest {
+
+ @Test
+ public void testRemoveIndex() {
+ LongArrayList list = new LongArrayList();
+ list.add(0).add(1).add(2).add(3);
+ Assert.assertEquals(1, list.remove(1));
+ Assert.assertEquals(3, list.size());
+ Assert.assertArrayEquals(new long[] {0, 2, 3}, list.toArray());
+ Assert.assertEquals(3, list.remove(2));
+ Assert.assertArrayEquals(new long[] {0, 2}, list.toArray());
+ Assert.assertEquals(0, list.remove(0));
+ Assert.assertArrayEquals(new long[] {2}, list.toArray());
+ list.add(0).add(1);
+ Assert.assertEquals(3, list.size());
+ Assert.assertArrayEquals(new long[] {2, 0, 1}, list.toArray());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java
new file mode 100644
index 0000000..6a2ff96
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class Int2FloatOpenHashMapTest {
+
+ @Test
+ public void testSize() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
+ map.put(1, 3.f);
+ Assert.assertEquals(3.f, map.get(1), 0.d);
+ map.put(1, 5.f);
+ Assert.assertEquals(5.f, map.get(1), 0.d);
+ Assert.assertEquals(1, map.size());
+ }
+
+ @Test
+ public void testDefaultReturnValue() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
+ Assert.assertEquals(0, map.size());
+ Assert.assertEquals(-1.f, map.get(1), 0.d);
+ float ret = Float.MIN_VALUE;
+ map.defaultReturnValue(ret);
+ Assert.assertEquals(ret, map.get(1), 0.d);
+ }
+
+ @Test
+ public void testPutAndGet() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384);
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d);
+ }
+ Assert.assertEquals(numEntries, map.size());
+ for (int i = 0; i < numEntries; i++) {
+ Float v = map.get(i);
+ Assert.assertEquals(i + 0.1f, v.floatValue(), 0.d);
+ }
+ }
+
+ @Test
+ public void testIterator() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(1000);
+ Int2FloatOpenHashTable.IMapIterator itor = map.entries();
+ Assert.assertFalse(itor.hasNext());
+
+ final int numEntries = 1000000;
+ for (int i = 0; i < numEntries; i++) {
+ Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d);
+ }
+ Assert.assertEquals(numEntries, map.size());
+
+ itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ while (itor.hasNext()) {
+ Assert.assertFalse(itor.next() == -1);
+ int k = itor.getKey();
+ Float v = itor.getValue();
+ Assert.assertEquals(k + 0.1f, v.floatValue(), 0.d);
+ }
+ Assert.assertEquals(-1, itor.next());
+ }
+
+ @Test
+ public void testIterator2() {
+ Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(100);
+ map.put(33, 3.16f);
+
+ Int2FloatOpenHashTable.IMapIterator itor = map.entries();
+ Assert.assertTrue(itor.hasNext());
+ Assert.assertNotEquals(-1, itor.next());
+ Assert.assertEquals(33, itor.getKey());
+ Assert.assertEquals(3.16f, itor.getValue(), 0.d);
+ Assert.assertEquals(-1, itor.next());
+ }
+
+}
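The two iterator tests document the cursor-style protocol of IMapIterator: next() advances to the following occupied slot and returns its index, or -1 once the table is exhausted, and getKey()/getValue() are only valid after a successful next(). The idiomatic traversal therefore loops on the return value of next(); a minimal sketch against the table under test:

    Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(100);
    map.put(33, 3.16f);
    Int2FloatOpenHashTable.IMapIterator itor = map.entries();
    while (itor.next() != -1) {       // -1 signals exhaustion
        int key = itor.getKey();      // valid only after next() succeeded
        float value = itor.getValue();
        System.out.println(key + " -> " + value);
    }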
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java
deleted file mode 100644
index a06cdb0..0000000
--- a/core/src/main/java/hivemall/utils/collections/Int2FloatOpenHashTable.java
+++ /dev/null
@@ -1,418 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.math.Primes;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-
-/**
- * An open-addressing hash table with double hashing
- *
- * @see http://en.wikipedia.org/wiki/Double_hashing
- */
-public class Int2FloatOpenHashTable implements Externalizable {
-
- protected static final byte FREE = 0;
- protected static final byte FULL = 1;
- protected static final byte REMOVED = 2;
-
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
- private static final float DEFAULT_GROW_FACTOR = 2.0f;
-
- protected final transient float _loadFactor;
- protected final transient float _growFactor;
-
- protected int _used = 0;
- protected int _threshold;
- protected float defaultReturnValue = -1.f;
-
- protected int[] _keys;
- protected float[] _values;
- protected byte[] _states;
-
- protected Int2FloatOpenHashTable(int size, float loadFactor, float growFactor,
- boolean forcePrime) {
- if (size < 1) {
- throw new IllegalArgumentException();
- }
- this._loadFactor = loadFactor;
- this._growFactor = growFactor;
- int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
- this._keys = new int[actualSize];
- this._values = new float[actualSize];
- this._states = new byte[actualSize];
- this._threshold = (int) (actualSize * _loadFactor);
- }
-
- public Int2FloatOpenHashTable(int size, float loadFactor, float growFactor) {
- this(size, loadFactor, growFactor, true);
- }
-
- public Int2FloatOpenHashTable(int size) {
- this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
- }
-
- /**
- * Only for {@link Externalizable}
- */
- public Int2FloatOpenHashTable() {// required for serialization
- this._loadFactor = DEFAULT_LOAD_FACTOR;
- this._growFactor = DEFAULT_GROW_FACTOR;
- }
-
- public void defaultReturnValue(float v) {
- this.defaultReturnValue = v;
- }
-
- public boolean containsKey(int key) {
- return findKey(key) >= 0;
- }
-
- /**
- * @return -1.f if not found
- */
- public float get(int key) {
- int i = findKey(key);
- if (i < 0) {
- return defaultReturnValue;
- }
- return _values[i];
- }
-
- public float put(int key, float value) {
- int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- int[] keys = _keys;
- float[] values = _values;
- byte[] states = _states;
-
- if (states[keyIdx] == FULL) {// double hashing
- if (keys[keyIdx] == key) {
- float old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- float old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- ++_used;
- return defaultReturnValue;
- }
-
- /** Return whether the required slot is free for a new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
- if (stat == FREE) {
- return true;
- }
- if (stat == REMOVED && _keys[index] == key) {
- return true;
- }
- return false;
- }
-
- /** @return expanded or not */
- protected boolean preAddEntry(int index) {
- if ((_used + 1) >= _threshold) {// too filled
- int newCapacity = Math.round(_keys.length * _growFactor);
- ensureCapacity(newCapacity);
- return true;
- }
- return false;
- }
-
- protected int findKey(int key) {
- int[] keys = _keys;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- return -1;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- }
- }
- return -1;
- }
-
- public float remove(int key) {
- int[] keys = _keys;
- float[] values = _values;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- float old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- // second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (states[keyIdx] == FREE) {
- return defaultReturnValue;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- float old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- }
- }
- return defaultReturnValue;
- }
-
- public int size() {
- return _used;
- }
-
- public void clear() {
- Arrays.fill(_states, FREE);
- this._used = 0;
- }
-
- public IMapIterator entries() {
- return new MapIterator();
- }
-
- @Override
- public String toString() {
- int len = size() * 10 + 2;
- StringBuilder buf = new StringBuilder(len);
- buf.append('{');
- IMapIterator i = entries();
- while (i.next() != -1) {
- buf.append(i.getKey());
- buf.append('=');
- buf.append(i.getValue());
- if (i.hasNext()) {
- buf.append(',');
- }
- }
- buf.append('}');
- return buf.toString();
- }
-
- protected void ensureCapacity(int newCapacity) {
- int prime = Primes.findLeastPrimeNumber(newCapacity);
- rehash(prime);
- this._threshold = Math.round(prime * _loadFactor);
- }
-
- private void rehash(int newCapacity) {
- int oldCapacity = _keys.length;
- if (newCapacity <= oldCapacity) {
- throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
- }
- int[] newkeys = new int[newCapacity];
- float[] newValues = new float[newCapacity];
- byte[] newStates = new byte[newCapacity];
- int used = 0;
- for (int i = 0; i < oldCapacity; i++) {
- if (_states[i] == FULL) {
- used++;
- int k = _keys[i];
- float v = _values[i];
- int hash = keyHash(k);
- int keyIdx = hash % newCapacity;
- if (newStates[keyIdx] == FULL) {// second hashing
- int decr = 1 + (hash % (newCapacity - 2));
- while (newStates[keyIdx] != FREE) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += newCapacity;
- }
- }
- }
- newkeys[keyIdx] = k;
- newValues[keyIdx] = v;
- newStates[keyIdx] = FULL;
- }
- }
- this._keys = newkeys;
- this._values = newValues;
- this._states = newStates;
- this._used = used;
- }
-
- private static int keyHash(int key) {
- return key & 0x7fffffff;
- }
-
- public void writeExternal(ObjectOutput out) throws IOException {
- out.writeInt(_threshold);
- out.writeInt(_used);
-
- out.writeInt(_keys.length);
- IMapIterator i = entries();
- while (i.next() != -1) {
- out.writeInt(i.getKey());
- out.writeFloat(i.getValue());
- }
- }
-
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this._threshold = in.readInt();
- this._used = in.readInt();
-
- int keylen = in.readInt();
- int[] keys = new int[keylen];
- float[] values = new float[keylen];
- byte[] states = new byte[keylen];
- for (int i = 0; i < _used; i++) {
- int k = in.readInt();
- float v = in.readFloat();
- int hash = keyHash(k);
- int keyIdx = hash % keylen;
- if (states[keyIdx] != FREE) {// second hash
- int decr = 1 + (hash % (keylen - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keylen;
- }
- if (states[keyIdx] == FREE) {
- break;
- }
- }
- }
- states[keyIdx] = FULL;
- keys[keyIdx] = k;
- values[keyIdx] = v;
- }
- this._keys = keys;
- this._values = values;
- this._states = states;
- }
-
- public interface IMapIterator {
-
- public boolean hasNext();
-
- /**
- * @return -1 if not found
- */
- public int next();
-
- public int getKey();
-
- public float getValue();
-
- }
-
- private final class MapIterator implements IMapIterator {
-
- int nextEntry;
- int lastEntry = -1;
-
- MapIterator() {
- this.nextEntry = nextEntry(0);
- }
-
- /** find the index of the next full entry */
- int nextEntry(int index) {
- while (index < _keys.length && _states[index] != FULL) {
- index++;
- }
- return index;
- }
-
- public boolean hasNext() {
- return nextEntry < _keys.length;
- }
-
- public int next() {
- if (!hasNext()) {
- return -1;
- }
- int curEntry = nextEntry;
- this.lastEntry = curEntry;
- this.nextEntry = nextEntry(curEntry + 1);
- return curEntry;
- }
-
- public int getKey() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _keys[lastEntry];
- }
-
- public float getValue() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _values[lastEntry];
- }
- }
-}
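As the put()/findKey()/remove() methods above show, the double-hashing scheme computes a primary slot keyHash(key) % capacity and, on collision, repeatedly steps backwards by decr = 1 + (hash % (capacity - 2)), wrapping at zero. A standalone sketch of that probe sequence, assuming a prime capacity of at least 3 (which the constructors enforce via Primes.findLeastPrimeNumber):

    // Probe-sequence sketch for the double-hashing scheme above.
    static int[] probes(int key, int capacity, int maxProbes) {
        int hash = key & 0x7fffffff;             // keyHash(): clear the sign bit
        int idx = hash % capacity;               // primary slot
        int decr = 1 + (hash % (capacity - 2));  // secondary step, never 0
        int[] seq = new int[maxProbes];
        for (int i = 0; i < maxProbes; i++) {
            seq[i] = idx;
            idx -= decr;                         // step backwards...
            if (idx < 0) {
                idx += capacity;                 // ...wrapping modulo capacity
            }
        }
        return seq;
    }

Because decr lies in [1, capacity - 2] and the capacity is prime, the step is coprime with the table size, so the sequence visits every slot before repeating; that full cycle is what guarantees the unbounded for (;;) loops in put() and findKey() terminate.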
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/Int2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/Int2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/Int2IntOpenHashTable.java
deleted file mode 100644
index 211157e..0000000
--- a/core/src/main/java/hivemall/utils/collections/Int2IntOpenHashTable.java
+++ /dev/null
@@ -1,414 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.math.Primes;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-
-/**
- * An open-addressing hash table with double hashing
- *
- * @see http://en.wikipedia.org/wiki/Double_hashing
- */
-public final class Int2IntOpenHashTable implements Externalizable {
-
- protected static final byte FREE = 0;
- protected static final byte FULL = 1;
- protected static final byte REMOVED = 2;
-
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
- private static final float DEFAULT_GROW_FACTOR = 2.0f;
-
- protected final transient float _loadFactor;
- protected final transient float _growFactor;
-
- protected int _used = 0;
- protected int _threshold;
- protected int defaultReturnValue = -1;
-
- protected int[] _keys;
- protected int[] _values;
- protected byte[] _states;
-
- protected Int2IntOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
- if (size < 1) {
- throw new IllegalArgumentException();
- }
- this._loadFactor = loadFactor;
- this._growFactor = growFactor;
- int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
- this._keys = new int[actualSize];
- this._values = new int[actualSize];
- this._states = new byte[actualSize];
- this._threshold = (int) (actualSize * _loadFactor);
- }
-
- public Int2IntOpenHashTable(int size, float loadFactor, float growFactor) {
- this(size, loadFactor, growFactor, true);
- }
-
- public Int2IntOpenHashTable(int size) {
- this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
- }
-
- public Int2IntOpenHashTable() {// required for serialization
- this._loadFactor = DEFAULT_LOAD_FACTOR;
- this._growFactor = DEFAULT_GROW_FACTOR;
- }
-
- public void defaultReturnValue(int v) {
- this.defaultReturnValue = v;
- }
-
- public boolean containsKey(final int key) {
- return findKey(key) >= 0;
- }
-
- /**
- * @return -1 if not found
- */
- public int get(final int key) {
- final int i = findKey(key);
- if (i < 0) {
- return defaultReturnValue;
- }
- return _values[i];
- }
-
- public int put(final int key, final int value) {
- final int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- final int[] keys = _keys;
- final int[] values = _values;
- final byte[] states = _states;
-
- if (states[keyIdx] == FULL) {// double hashing
- if (keys[keyIdx] == key) {
- int old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- int old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- ++_used;
- return defaultReturnValue;
- }
-
- /** Return whether the required slot is free for a new entry */
- protected boolean isFree(final int index, final int key) {
- final byte stat = _states[index];
- if (stat == FREE) {
- return true;
- }
- if (stat == REMOVED && _keys[index] == key) {
- return true;
- }
- return false;
- }
-
- /** @return expanded or not */
- protected boolean preAddEntry(final int index) {
- if ((_used + 1) >= _threshold) {// too filled
- int newCapacity = Math.round(_keys.length * _growFactor);
- ensureCapacity(newCapacity);
- return true;
- }
- return false;
- }
-
- protected int findKey(final int key) {
- final int[] keys = _keys;
- final byte[] states = _states;
- final int keyLength = keys.length;
-
- final int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- return -1;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- }
- }
- return -1;
- }
-
- public int remove(final int key) {
- final int[] keys = _keys;
- final int[] values = _values;
- final byte[] states = _states;
- final int keyLength = keys.length;
-
- final int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- int old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- // second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (states[keyIdx] == FREE) {
- return defaultReturnValue;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- int old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- }
- }
- return defaultReturnValue;
- }
-
- public int size() {
- return _used;
- }
-
- public void clear() {
- Arrays.fill(_states, FREE);
- this._used = 0;
- }
-
- public IMapIterator entries() {
- return new MapIterator();
- }
-
- @Override
- public String toString() {
- int len = size() * 10 + 2;
- StringBuilder buf = new StringBuilder(len);
- buf.append('{');
- IMapIterator i = entries();
- while (i.next() != -1) {
- buf.append(i.getKey());
- buf.append('=');
- buf.append(i.getValue());
- if (i.hasNext()) {
- buf.append(',');
- }
- }
- buf.append('}');
- return buf.toString();
- }
-
- protected void ensureCapacity(final int newCapacity) {
- int prime = Primes.findLeastPrimeNumber(newCapacity);
- rehash(prime);
- this._threshold = Math.round(prime * _loadFactor);
- }
-
- private void rehash(final int newCapacity) {
- int oldCapacity = _keys.length;
- if (newCapacity <= oldCapacity) {
- throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
- }
- final int[] newkeys = new int[newCapacity];
- final int[] newValues = new int[newCapacity];
- final byte[] newStates = new byte[newCapacity];
- int used = 0;
- for (int i = 0; i < oldCapacity; i++) {
- if (_states[i] == FULL) {
- used++;
- int k = _keys[i];
- int v = _values[i];
- int hash = keyHash(k);
- int keyIdx = hash % newCapacity;
- if (newStates[keyIdx] == FULL) {// second hashing
- int decr = 1 + (hash % (newCapacity - 2));
- while (newStates[keyIdx] != FREE) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += newCapacity;
- }
- }
- }
- newkeys[keyIdx] = k;
- newValues[keyIdx] = v;
- newStates[keyIdx] = FULL;
- }
- }
- this._keys = newkeys;
- this._values = newValues;
- this._states = newStates;
- this._used = used;
- }
-
- private static int keyHash(final int key) {
- return key & 0x7fffffff;
- }
-
- public void writeExternal(ObjectOutput out) throws IOException {
- out.writeInt(_threshold);
- out.writeInt(_used);
-
- out.writeInt(_keys.length);
- IMapIterator i = entries();
- while (i.next() != -1) {
- out.writeInt(i.getKey());
- out.writeInt(i.getValue());
- }
- }
-
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this._threshold = in.readInt();
- this._used = in.readInt();
-
- final int keylen = in.readInt();
- final int[] keys = new int[keylen];
- final int[] values = new int[keylen];
- final byte[] states = new byte[keylen];
- for (int i = 0; i < _used; i++) {
- int k = in.readInt();
- int v = in.readInt();
- int hash = keyHash(k);
- int keyIdx = hash % keylen;
- if (states[keyIdx] != FREE) {// second hash
- int decr = 1 + (hash % (keylen - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keylen;
- }
- if (states[keyIdx] == FREE) {
- break;
- }
- }
- }
- states[keyIdx] = FULL;
- keys[keyIdx] = k;
- values[keyIdx] = v;
- }
- this._keys = keys;
- this._values = values;
- this._states = states;
- }
-
- public interface IMapIterator {
-
- public boolean hasNext();
-
- /**
- * @return -1 if not found
- */
- public int next();
-
- public int getKey();
-
- public int getValue();
-
- }
-
- private final class MapIterator implements IMapIterator {
-
- int nextEntry;
- int lastEntry = -1;
-
- MapIterator() {
- this.nextEntry = nextEntry(0);
- }
-
- /** find the index of the next full entry */
- int nextEntry(int index) {
- while (index < _keys.length && _states[index] != FULL) {
- index++;
- }
- return index;
- }
-
- public boolean hasNext() {
- return nextEntry < _keys.length;
- }
-
- public int next() {
- if (!hasNext()) {
- return -1;
- }
- int curEntry = nextEntry;
- this.lastEntry = curEntry;
- this.nextEntry = nextEntry(curEntry + 1);
- return curEntry;
- }
-
- public int getKey() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _keys[lastEntry];
- }
-
- public int getValue() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _values[lastEntry];
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/Int2LongOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/Int2LongOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/Int2LongOpenHashTable.java
deleted file mode 100644
index 2c229a4..0000000
--- a/core/src/main/java/hivemall/utils/collections/Int2LongOpenHashTable.java
+++ /dev/null
@@ -1,500 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.codec.VariableByteCodec;
-import hivemall.utils.codec.ZigZagLEB128Codec;
-import hivemall.utils.math.Primes;
-
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-
-/**
- * An open-addressing hash table with double hashing
- *
- * @see http://en.wikipedia.org/wiki/Double_hashing
- */
-public class Int2LongOpenHashTable implements Externalizable {
-
- protected static final byte FREE = 0;
- protected static final byte FULL = 1;
- protected static final byte REMOVED = 2;
-
- public static final int DEFAULT_SIZE = 65536;
- public static final float DEFAULT_LOAD_FACTOR = 0.7f;
- public static final float DEFAULT_GROW_FACTOR = 2.0f;
-
- protected final transient float _loadFactor;
- protected final transient float _growFactor;
-
- protected int[] _keys;
- protected long[] _values;
- protected byte[] _states;
-
- protected int _used;
- protected int _threshold;
- protected long defaultReturnValue = -1L;
-
- /**
- * Constructor for Externalizable. Should not be called otherwise.
- */
- public Int2LongOpenHashTable() {// for Externalizable
- this._loadFactor = DEFAULT_LOAD_FACTOR;
- this._growFactor = DEFAULT_GROW_FACTOR;
- }
-
- public Int2LongOpenHashTable(int size) {
- this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
- }
-
- public Int2LongOpenHashTable(int size, float loadFactor, float growFactor) {
- this(size, loadFactor, growFactor, true);
- }
-
- protected Int2LongOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
- if (size < 1) {
- throw new IllegalArgumentException();
- }
- this._loadFactor = loadFactor;
- this._growFactor = growFactor;
- int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
- this._keys = new int[actualSize];
- this._values = new long[actualSize];
- this._states = new byte[actualSize];
- this._used = 0;
- this._threshold = (int) (actualSize * _loadFactor);
- }
-
- public Int2LongOpenHashTable(@Nonnull int[] keys, @Nonnull long[] values,
- @Nonnull byte[] states, int used) {
- this._loadFactor = DEFAULT_LOAD_FACTOR;
- this._growFactor = DEFAULT_GROW_FACTOR;
- this._keys = keys;
- this._values = values;
- this._states = states;
- this._used = used;
- this._threshold = keys.length;
- }
-
- @Nonnull
- public static Int2LongOpenHashTable newInstance() {
- return new Int2LongOpenHashTable(DEFAULT_SIZE);
- }
-
- public void defaultReturnValue(long v) {
- this.defaultReturnValue = v;
- }
-
- @Nonnull
- public int[] getKeys() {
- return _keys;
- }
-
- @Nonnull
- public long[] getValues() {
- return _values;
- }
-
- @Nonnull
- public byte[] getStates() {
- return _states;
- }
-
- public boolean containsKey(int key) {
- return findKey(key) >= 0;
- }
-
- /**
- * @return -1L if not found
- */
- public long get(int key) {
- int i = findKey(key);
- if (i < 0) {
- return defaultReturnValue;
- }
- return _values[i];
- }
-
- public long put(int key, long value) {
- int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- int[] keys = _keys;
- long[] values = _values;
- byte[] states = _states;
-
- if (states[keyIdx] == FULL) {// double hashing
- if (keys[keyIdx] == key) {
- long old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- long old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- ++_used;
- return defaultReturnValue;
- }
-
- /** Return whether the required slot is free for a new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
- if (stat == FREE) {
- return true;
- }
- if (stat == REMOVED && _keys[index] == key) {
- return true;
- }
- return false;
- }
-
- /** @return expanded or not */
- protected boolean preAddEntry(int index) {
- if ((_used + 1) >= _threshold) {// too filled
- int newCapacity = Math.round(_keys.length * _growFactor);
- ensureCapacity(newCapacity);
- return true;
- }
- return false;
- }
-
- protected int findKey(int key) {
- int[] keys = _keys;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- return -1;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- }
- }
- return -1;
- }
-
- public long remove(int key) {
- int[] keys = _keys;
- long[] values = _values;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- long old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- // second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (states[keyIdx] == FREE) {
- return defaultReturnValue;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- long old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- }
- }
- return defaultReturnValue;
- }
-
- public int size() {
- return _used;
- }
-
- public int capacity() {
- return _keys.length;
- }
-
- public void clear() {
- Arrays.fill(_states, FREE);
- this._used = 0;
- }
-
- public IMapIterator entries() {
- return new MapIterator();
- }
-
- @Override
- public String toString() {
- int len = size() * 10 + 2;
- StringBuilder buf = new StringBuilder(len);
- buf.append('{');
- IMapIterator i = entries();
- while (i.next() != -1) {
- buf.append(i.getKey());
- buf.append('=');
- buf.append(i.getValue());
- if (i.hasNext()) {
- buf.append(',');
- }
- }
- buf.append('}');
- return buf.toString();
- }
-
- protected void ensureCapacity(int newCapacity) {
- int prime = Primes.findLeastPrimeNumber(newCapacity);
- rehash(prime);
- this._threshold = Math.round(prime * _loadFactor);
- }
-
- private void rehash(int newCapacity) {
- int oldCapacity = _keys.length;
- if (newCapacity <= oldCapacity) {
- throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
- }
- int[] newkeys = new int[newCapacity];
- long[] newValues = new long[newCapacity];
- byte[] newStates = new byte[newCapacity];
- int used = 0;
- for (int i = 0; i < oldCapacity; i++) {
- if (_states[i] == FULL) {
- used++;
- int k = _keys[i];
- long v = _values[i];
- int hash = keyHash(k);
- int keyIdx = hash % newCapacity;
- if (newStates[keyIdx] == FULL) {// second hashing
- int decr = 1 + (hash % (newCapacity - 2));
- while (newStates[keyIdx] != FREE) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += newCapacity;
- }
- }
- }
- newkeys[keyIdx] = k;
- newValues[keyIdx] = v;
- newStates[keyIdx] = FULL;
- }
- }
- this._keys = newkeys;
- this._values = newValues;
- this._states = newStates;
- this._used = used;
- }
-
- private static int keyHash(int key) {
- return key & 0x7fffffff;
- }
-
- @Override
- public void writeExternal(ObjectOutput out) throws IOException {
- out.writeInt(_threshold);
- out.writeInt(_used);
-
- final int[] keys = _keys;
- final int size = keys.length;
- out.writeInt(size);
-
- final byte[] states = _states;
- writeStates(states, out);
-
- final long[] values = _values;
- for (int i = 0; i < size; i++) {
- if (states[i] != FULL) {
- continue;
- }
- ZigZagLEB128Codec.writeSignedInt(keys[i], out);
- ZigZagLEB128Codec.writeSignedLong(values[i], out);
- }
- }
-
- private static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out)
- throws IOException {
- // write the indexes of empty states differentially
- final int size = status.length;
- int cardinality = 0;
- for (int i = 0; i < size; i++) {
- if (status[i] != FULL) {
- cardinality++;
- }
- }
- out.writeInt(cardinality);
- if (cardinality == 0) {
- return;
- }
- int prev = 0;
- for (int i = 0; i < size; i++) {
- if (status[i] != FULL) {
- int diff = i - prev;
- assert (diff >= 0);
- VariableByteCodec.encodeUnsignedInt(diff, out);
- prev = i;
- }
- }
- }
-
- @Override
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this._threshold = in.readInt();
- this._used = in.readInt();
-
- final int size = in.readInt();
- final int[] keys = new int[size];
- final long[] values = new long[size];
- final byte[] states = new byte[size];
- readStates(in, states);
-
- for (int i = 0; i < size; i++) {
- if (states[i] != FULL) {
- continue;
- }
- keys[i] = ZigZagLEB128Codec.readSignedInt(in);
- values[i] = ZigZagLEB128Codec.readSignedLong(in);
- }
-
- this._keys = keys;
- this._values = values;
- this._states = states;
- }
-
- private static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status)
- throws IOException {
- // read the indexes of empty states differentially
- final int cardinality = in.readInt();
- Arrays.fill(status, IntOpenHashTable.FULL);
- int prev = 0;
- for (int j = 0; j < cardinality; j++) {
- int i = VariableByteCodec.decodeUnsignedInt(in) + prev;
- status[i] = IntOpenHashTable.FREE;
- prev = i;
- }
- }
-
- public interface IMapIterator {
-
- public boolean hasNext();
-
- /**
- * @return -1 if not found
- */
- public int next();
-
- public int getKey();
-
- public long getValue();
-
- }
-
- private final class MapIterator implements IMapIterator {
-
- int nextEntry;
- int lastEntry = -1;
-
- MapIterator() {
- this.nextEntry = nextEntry(0);
- }
-
- /** find the index of the next full entry */
- int nextEntry(int index) {
- while (index < _keys.length && _states[index] != FULL) {
- index++;
- }
- return index;
- }
-
- public boolean hasNext() {
- return nextEntry < _keys.length;
- }
-
- public int next() {
- if (!hasNext()) {
- return -1;
- }
- int curEntry = nextEntry;
- this.lastEntry = curEntry;
- this.nextEntry = nextEntry(curEntry + 1);
- return curEntry;
- }
-
- public int getKey() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _keys[lastEntry];
- }
-
- public long getValue() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _values[lastEntry];
- }
- }
-}
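Unlike the other tables in this change, the Externalizable logic above compresses its output: writeStates() gap-encodes only the indexes of non-FULL slots (each index written as the unsigned delta from the previous one), readStates() fills the array with FULL and resets the decoded positions to FREE (REMOVED slots carry no payload, so collapsing them to FREE is lossless), and the surviving key/value pairs are written with the ZigZagLEB128Codec helpers. A self-contained sketch of the gap-encoding idea, using a plain int[] in place of VariableByteCodec:

    // Mirrors writeStates(): record non-FULL slot indexes as deltas.
    static int[] encodeFreeSlots(byte[] states, byte full) {
        java.util.List<Integer> gaps = new java.util.ArrayList<>();
        int prev = 0;
        for (int i = 0; i < states.length; i++) {
            if (states[i] != full) {
                gaps.add(i - prev); // delta to the previously recorded index
                prev = i;
            }
        }
        int[] out = new int[gaps.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = gaps.get(i);
        }
        return out;
    }

    // Mirrors readStates(): assume FULL everywhere, then mark decoded slots free.
    static void decodeFreeSlots(int[] gaps, byte[] states, byte full, byte free) {
        java.util.Arrays.fill(states, full);
        int prev = 0;
        for (int gap : gaps) {
            int i = prev + gap; // reconstruct the absolute index
            states[i] = free;
            prev = i;
        }
    }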
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/IntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/IntArray.java b/core/src/main/java/hivemall/utils/collections/IntArray.java
deleted file mode 100644
index cb6b0b8..0000000
--- a/core/src/main/java/hivemall/utils/collections/IntArray.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.io.Serializable;
-
-import javax.annotation.Nonnull;
-
-public interface IntArray extends Serializable {
-
- public int get(int key);
-
- public int get(int key, int valueIfKeyNotFound);
-
- public void put(int key, int value);
-
- public int size();
-
- public int keyAt(int index);
-
- @Nonnull
- public int[] toArray();
-
- @Nonnull
- public int[] toArray(boolean copy);
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/IntArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/IntArrayList.java b/core/src/main/java/hivemall/utils/collections/IntArrayList.java
deleted file mode 100644
index 0716ca8..0000000
--- a/core/src/main/java/hivemall/utils/collections/IntArrayList.java
+++ /dev/null
@@ -1,183 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.ArrayUtils;
-
-import java.io.Closeable;
-import java.io.Serializable;
-
-import javax.annotation.Nonnull;
-
-public final class IntArrayList implements Serializable, Closeable {
- private static final long serialVersionUID = -2147675120406747488L;
- public static final int DEFAULT_CAPACITY = 12;
-
- /** array entity */
- private int[] data;
- private int used;
-
- public IntArrayList() {
- this(DEFAULT_CAPACITY);
- }
-
- public IntArrayList(int size) {
- this.data = new int[size];
- this.used = 0;
- }
-
- public IntArrayList(int[] initValues) {
- this.data = initValues;
- this.used = initValues.length;
- }
-
- public void add(final int value) {
- if (used >= data.length) {
- expand(used + 1);
- }
- data[used++] = value;
- }
-
- public void add(final int[] values) {
- final int needs = used + values.length;
- if (needs >= data.length) {
- expand(needs);
- }
- System.arraycopy(values, 0, data, used, values.length);
- this.used = needs;
- }
-
- /**
- * dynamic expansion.
- */
- private void expand(final int max) {
- while (data.length < max) {
- final int len = data.length;
- int[] newArray = new int[len * 2];
- System.arraycopy(data, 0, newArray, 0, len);
- this.data = newArray;
- }
- }
-
- public int remove() {
- return data[--used];
- }
-
- public int remove(final int index) {
- final int ret;
- if (index > used) {
- throw new IndexOutOfBoundsException();
- } else if (index == used) {
- ret = data[--used];
- } else { // index < used
- // removed value
- ret = data[index];
- final int[] newarray = new int[--used];
- // prefix
- System.arraycopy(data, 0, newarray, 0, index);
- // appendix
- System.arraycopy(data, index + 1, newarray, index, used - index);
- // set fields.
- this.data = newarray;
- }
- return ret;
- }
-
- public void set(final int index, final int value) {
- if (index > used) {
- throw new IllegalArgumentException("Index " + index + " MUST be less than size() "
- + used);
- } else if (index == used) {
- ++used;
- }
- data[index] = value;
- }
-
- public int get(final int index) {
- if (index >= used) {
- throw new IndexOutOfBoundsException("Index " + index + " out of bounds " + used);
- }
- return data[index];
- }
-
- public int fastGet(final int index) {
- return data[index];
- }
-
- /**
- * @return -1 if not found.
- */
- public int indexOf(final int key) {
- return ArrayUtils.indexOf(data, key, 0, used);
- }
-
- public boolean contains(final int key) {
- return ArrayUtils.indexOf(data, key, 0, used) != -1;
- }
-
- public int size() {
- return used;
- }
-
- public boolean isEmpty() {
- return used == 0;
- }
-
- public void clear() {
- used = 0;
- }
-
- @Nonnull
- public int[] toArray() {
- return toArray(false);
- }
-
- @Nonnull
- public int[] toArray(boolean close) {
- final int[] newArray = new int[used];
- System.arraycopy(data, 0, newArray, 0, used);
- if (close) {
- close();
- }
- return newArray;
- }
-
- public int[] array() {
- return data;
- }
-
- @Override
- public String toString() {
- final StringBuilder buf = new StringBuilder();
- buf.append('[');
- for (int i = 0; i < used; i++) {
- if (i != 0) {
- buf.append(", ");
- }
- buf.append(data[i]);
- }
- buf.append(']');
- return buf.toString();
- }
-
- @Override
- public void close() {
- this.data = null;
- }
-}
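The expand() helper above doubles the backing array until the requested capacity fits, so a long run of add() calls costs amortized O(1) per element: appending n values from the default capacity of 12 triggers only O(log n) reallocations (12 -> 24 -> 48 -> ...). A sketch of the same growth policy in isolation, computing the final length first so the data is copied once rather than once per doubling:

    // Growth-policy sketch mirroring IntArrayList.expand().
    // Assumes the current capacity is at least 1.
    static int[] grow(int[] data, int requiredCapacity) {
        int len = data.length;
        while (len < requiredCapacity) {
            len *= 2; // geometric growth keeps appends amortized O(1)
        }
        return java.util.Arrays.copyOf(data, len);
    }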
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/IntOpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/IntOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/IntOpenHashMap.java
deleted file mode 100644
index 4621e6d..0000000
--- a/core/src/main/java/hivemall/utils/collections/IntOpenHashMap.java
+++ /dev/null
@@ -1,467 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.math.Primes;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-
-/**
- * An open-addressing hash table with double hashing
- *
- * @see http://en.wikipedia.org/wiki/Double_hashing
- */
-public class IntOpenHashMap<V> implements Externalizable {
- private static final long serialVersionUID = -8162355845665353513L;
-
- protected static final byte FREE = 0;
- protected static final byte FULL = 1;
- protected static final byte REMOVED = 2;
-
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
- private static final float DEFAULT_GROW_FACTOR = 2.0f;
-
- protected final transient float _loadFactor;
- protected final transient float _growFactor;
-
- protected int _used = 0;
- protected int _threshold;
-
- protected int[] _keys;
- protected V[] _values;
- protected byte[] _states;
-
- @SuppressWarnings("unchecked")
- protected IntOpenHashMap(int size, float loadFactor, float growFactor, boolean forcePrime) {
- if (size < 1) {
- throw new IllegalArgumentException();
- }
- this._loadFactor = loadFactor;
- this._growFactor = growFactor;
- int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
- this._keys = new int[actualSize];
- this._values = (V[]) new Object[actualSize];
- this._states = new byte[actualSize];
- this._threshold = Math.round(actualSize * _loadFactor);
- }
-
- public IntOpenHashMap(int size, float loadFactor, float growFactor) {
- this(size, loadFactor, growFactor, true);
- }
-
- public IntOpenHashMap(int size) {
- this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
- }
-
- public IntOpenHashMap() {// required for serialization
- this._loadFactor = DEFAULT_LOAD_FACTOR;
- this._growFactor = DEFAULT_GROW_FACTOR;
- }
-
- public boolean containsKey(int key) {
- return findKey(key) >= 0;
- }
-
- public final V get(final int key) {
- final int i = findKey(key);
- if (i < 0) {
- return null;
- }
- recordAccess(i);
- return _values[i];
- }
-
- public V put(final int key, final V value) {
- final int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- final boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- final int[] keys = _keys;
- final V[] values = _values;
- final byte[] states = _states;
-
- if (states[keyIdx] == FULL) {// double hashing
- if (keys[keyIdx] == key) {
- V old = values[keyIdx];
- values[keyIdx] = value;
- recordAccess(keyIdx);
- return old;
- }
- // try second hash
- final int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- values[keyIdx] = value;
- recordAccess(keyIdx);
- return old;
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- ++_used;
- postAddEntry(keyIdx);
- return null;
- }
-
- public V putIfAbsent(final int key, final V value) {
- final int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- final boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- final int[] keys = _keys;
- final V[] values = _values;
- final byte[] states = _states;
-
- if (states[keyIdx] == FULL) {// second hashing
- if (keys[keyIdx] == key) {
- return values[keyIdx];
- }
- // try second hash
- final int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return values[keyIdx];
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- _used++;
- postAddEntry(keyIdx);
- return null;
- }
-
- /** Return whether the required slot is free for a new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
- if (stat == FREE) {
- return true;
- }
- if (stat == REMOVED && _keys[index] == key) {
- return true;
- }
- return false;
- }
-
- /** @return expanded or not */
- protected boolean preAddEntry(int index) {
- if ((_used + 1) >= _threshold) {// too filled
- int newCapacity = Math.round(_keys.length * _growFactor);
- ensureCapacity(newCapacity);
- return true;
- }
- return false;
- }
-
- protected void postAddEntry(int index) {}
-
- private int findKey(int key) {
- int[] keys = _keys;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- return -1;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- }
- }
- return -1;
- }
-
- public V remove(int key) {
- int[] keys = _keys;
- V[] values = _values;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- recordRemoval(keyIdx);
- return old;
- }
- // second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (states[keyIdx] == FREE) {
- return null;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- recordRemoval(keyIdx);
- return old;
- }
- }
- }
- return null;
- }
-
- public int size() {
- return _used;
- }
-
- public void clear() {
- Arrays.fill(_states, FREE);
- this._used = 0;
- }
-
- @SuppressWarnings("unchecked")
- public IMapIterator<V> entries() {
- return new MapIterator();
- }
-
- @Override
- public String toString() {
- int len = size() * 10 + 2;
- StringBuilder buf = new StringBuilder(len);
- buf.append('{');
- IMapIterator<V> i = entries();
- while (i.next() != -1) {
- buf.append(i.getKey());
- buf.append('=');
- buf.append(i.getValue());
- if (i.hasNext()) {
- buf.append(',');
- }
- }
- buf.append('}');
- return buf.toString();
- }
-
- private void ensureCapacity(int newCapacity) {
- int prime = Primes.findLeastPrimeNumber(newCapacity);
- rehash(prime);
- this._threshold = Math.round(prime * _loadFactor);
- }
-
- @SuppressWarnings("unchecked")
- protected void rehash(int newCapacity) {
- int oldCapacity = _keys.length;
- if (newCapacity <= oldCapacity) {
- throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
- }
- final int[] oldKeys = _keys;
- final V[] oldValues = _values;
- final byte[] oldStates = _states;
- int[] newkeys = new int[newCapacity];
- V[] newValues = (V[]) new Object[newCapacity];
- byte[] newStates = new byte[newCapacity];
- int used = 0;
- for (int i = 0; i < oldCapacity; i++) {
- if (oldStates[i] == FULL) {
- used++;
- int k = oldKeys[i];
- V v = oldValues[i];
- int hash = keyHash(k);
- int keyIdx = hash % newCapacity;
- if (newStates[keyIdx] == FULL) {// second hashing
- int decr = 1 + (hash % (newCapacity - 2));
- while (newStates[keyIdx] != FREE) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += newCapacity;
- }
- }
- }
- newkeys[keyIdx] = k;
- newValues[keyIdx] = v;
- newStates[keyIdx] = FULL;
- }
- }
- this._keys = newkeys;
- this._values = newValues;
- this._states = newStates;
- this._used = used;
- }
-
- private static int keyHash(int key) {
- return key & 0x7fffffff;
- }
-
- protected void recordAccess(int idx) {}
-
- protected void recordRemoval(int idx) {}
-
- public void writeExternal(ObjectOutput out) throws IOException {
- out.writeInt(_threshold);
- out.writeInt(_used);
-
- out.writeInt(_keys.length);
- IMapIterator<V> i = entries();
- while (i.next() != -1) {
- out.writeInt(i.getKey());
- out.writeObject(i.getValue());
- }
- }
-
- @SuppressWarnings("unchecked")
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this._threshold = in.readInt();
- this._used = in.readInt();
-
- int keylen = in.readInt();
- int[] keys = new int[keylen];
- V[] values = (V[]) new Object[keylen];
- byte[] states = new byte[keylen];
- for (int i = 0; i < _used; i++) {
- int k = in.readInt();
- V v = (V) in.readObject();
- int hash = keyHash(k);
- int keyIdx = hash % keylen;
- if (states[keyIdx] != FREE) {// second hash
- int decr = 1 + (hash % (keylen - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keylen;
- }
- if (states[keyIdx] == FREE) {
- break;
- }
- }
- }
- states[keyIdx] = FULL;
- keys[keyIdx] = k;
- values[keyIdx] = v;
- }
- this._keys = keys;
- this._values = values;
- this._states = states;
- }
-
- public interface IMapIterator<V> {
-
- public boolean hasNext();
-
- public int next();
-
- public int getKey();
-
- public V getValue();
-
- }
-
- @SuppressWarnings("rawtypes")
- private final class MapIterator implements IMapIterator {
-
- int nextEntry;
- int lastEntry = -1;
-
- MapIterator() {
- this.nextEntry = nextEntry(0);
- }
-
- /** Finds the index of the next full entry */
- int nextEntry(int index) {
- while (index < _keys.length && _states[index] != FULL) {
- index++;
- }
- return index;
- }
-
- public boolean hasNext() {
- return nextEntry < _keys.length;
- }
-
- public int next() {
- if (!hasNext()) {
- return -1;
- }
- int curEntry = nextEntry;
- this.lastEntry = curEntry;
- this.nextEntry = nextEntry(curEntry + 1);
- return curEntry;
- }
-
- public int getKey() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _keys[lastEntry];
- }
-
- public V getValue() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _values[lastEntry];
- }
- }
-
-}
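
Every probing method in the file above (put, putIfAbsent, findKey, remove) walks the same double-hashing sequence: the first probe is hash % capacity, and each collision steps the index back by decr = 1 + (hash % (capacity - 2)), wrapping modulo capacity. Because the capacity is prime (via Primes.findLeastPrimeNumber), decr is always co-prime to it, so the probe eventually visits every slot. A minimal standalone sketch of just that sequence (the class and method below are hypothetical, not part of this patch):

    public final class DoubleHashingProbe {

        // Clears the sign bit, as keyHash() does in the classes in this commit.
        static int keyHash(final int key) {
            return key & 0x7fffffff;
        }

        /** Prints the first n probe indices for a key in a table of prime capacity. */
        static void printProbes(final int key, final int capacity, final int n) {
            final int hash = keyHash(key);
            int keyIdx = hash % capacity;
            final int decr = 1 + (hash % (capacity - 2)); // second hash: step in [1, capacity - 2]
            for (int i = 0; i < n; i++) {
                System.out.println(keyIdx);
                keyIdx -= decr;
                if (keyIdx < 0) { // wrap around, as in the probe loops above
                    keyIdx += capacity;
                }
            }
        }

        public static void main(String[] args) {
            printProbes(42, 11, 5); // capacity 11 is prime; prints 5 distinct slots
        }
    }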
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/IntOpenHashTable.java
deleted file mode 100644
index 8d0cdf2..0000000
--- a/core/src/main/java/hivemall/utils/collections/IntOpenHashTable.java
+++ /dev/null
@@ -1,338 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import hivemall.utils.math.Primes;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-import java.util.HashMap;
-
-import javax.annotation.Nonnull;
-
-/**
- * An open-addressing hash table with double hashing that requires less memory than {@link HashMap}.
- */
-public final class IntOpenHashTable<V> implements Externalizable {
-
- public static final float DEFAULT_LOAD_FACTOR = 0.7f;
- public static final float DEFAULT_GROW_FACTOR = 2.0f;
-
- public static final byte FREE = 0;
- public static final byte FULL = 1;
- public static final byte REMOVED = 2;
-
- protected/* final */float _loadFactor;
- protected/* final */float _growFactor;
-
- protected int _used = 0;
- protected int _threshold;
-
- protected int[] _keys;
- protected V[] _values;
- protected byte[] _states;
-
- public IntOpenHashTable() {} // for Externalizable
-
- public IntOpenHashTable(int size) {
- this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR);
- }
-
- @SuppressWarnings("unchecked")
- public IntOpenHashTable(int size, float loadFactor, float growFactor) {
- if (size < 1) {
- throw new IllegalArgumentException();
- }
- this._loadFactor = loadFactor;
- this._growFactor = growFactor;
- int actualSize = Primes.findLeastPrimeNumber(size);
- this._keys = new int[actualSize];
- this._values = (V[]) new Object[actualSize];
- this._states = new byte[actualSize];
- this._threshold = Math.round(actualSize * _loadFactor);
- }
-
- public IntOpenHashTable(@Nonnull int[] keys, @Nonnull V[] values, @Nonnull byte[] states,
- int used) {
- this._used = used;
- this._threshold = keys.length;
- this._keys = keys;
- this._values = values;
- this._states = states;
- }
-
- public int[] getKeys() {
- return _keys;
- }
-
- public Object[] getValues() {
- return _values;
- }
-
- public byte[] getStates() {
- return _states;
- }
-
- public boolean containsKey(final int key) {
- return findKey(key) >= 0;
- }
-
- public V get(final int key) {
- final int i = findKey(key);
- if (i < 0) {
- return null;
- }
- return _values[i];
- }
-
- public V put(final int key, final V value) {
- final int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- final int[] keys = _keys;
- final V[] values = _values;
- final byte[] states = _states;
-
- if (states[keyIdx] == FULL) {
- if (keys[keyIdx] == key) {
- V old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- values[keyIdx] = value;
- return old;
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- ++_used;
- return null;
- }
-
-
- /** Returns whether the required slot is free for a new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
- if (stat == FREE) {
- return true;
- }
- if (stat == REMOVED && _keys[index] == key) {
- return true;
- }
- return false;
- }
-
- /** @return whether the table was expanded */
- protected boolean preAddEntry(int index) {
- if ((_used + 1) >= _threshold) {// too full; grow the table
- int newCapacity = Math.round(_keys.length * _growFactor);
- ensureCapacity(newCapacity);
- return true;
- }
- return false;
- }
-
- protected int findKey(final int key) {
- final int[] keys = _keys;
- final byte[] states = _states;
- final int keyLength = keys.length;
-
- final int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- return -1;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- }
- }
- return -1;
- }
-
- public V remove(final int key) {
- final int[] keys = _keys;
- final V[] values = _values;
- final byte[] states = _states;
- final int keyLength = keys.length;
-
- final int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- // second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (states[keyIdx] == FREE) {
- return null;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- return old;
- }
- }
- }
- return null;
- }
-
- public int size() {
- return _used;
- }
-
- public int capacity() {
- return _keys.length;
- }
-
- public void clear() {
- Arrays.fill(_states, FREE);
- this._used = 0;
- }
-
- protected void ensureCapacity(int newCapacity) {
- int prime = Primes.findLeastPrimeNumber(newCapacity);
- rehash(prime);
- this._threshold = Math.round(prime * _loadFactor);
- }
-
- @SuppressWarnings("unchecked")
- private void rehash(int newCapacity) {
- int oldCapacity = _keys.length;
- if (newCapacity <= oldCapacity) {
- throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
- }
- final int[] newkeys = new int[newCapacity];
- final V[] newValues = (V[]) new Object[newCapacity];
- final byte[] newStates = new byte[newCapacity];
- int used = 0;
- for (int i = 0; i < oldCapacity; i++) {
- if (_states[i] == FULL) {
- used++;
- int k = _keys[i];
- V v = _values[i];
- int hash = keyHash(k);
- int keyIdx = hash % newCapacity;
- if (newStates[keyIdx] == FULL) {// second hashing
- int decr = 1 + (hash % (newCapacity - 2));
- while (newStates[keyIdx] != FREE) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += newCapacity;
- }
- }
- }
- newStates[keyIdx] = FULL;
- newkeys[keyIdx] = k;
- newValues[keyIdx] = v;
- }
- }
- this._keys = newkeys;
- this._values = newValues;
- this._states = newStates;
- this._used = used;
- }
-
- private static int keyHash(final int key) {
- return key & 0x7fffffff;
- }
-
- @Override
- public void writeExternal(ObjectOutput out) throws IOException {
- out.writeFloat(_loadFactor);
- out.writeFloat(_growFactor);
- out.writeInt(_used);
-
- final int size = _keys.length;
- out.writeInt(size);
-
- for (int i = 0; i < size; i++) {
- out.writeInt(_keys[i]);
- out.writeObject(_values[i]);
- out.writeByte(_states[i]);
- }
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this._loadFactor = in.readFloat();
- this._growFactor = in.readFloat();
- this._used = in.readInt();
-
- final int size = in.readInt();
- final int[] keys = new int[size];
- final Object[] values = new Object[size];
- final byte[] states = new byte[size];
- for (int i = 0; i < size; i++) {
- keys[i] = in.readInt();
- values[i] = in.readObject();
- states[i] = in.readByte();
- }
- this._threshold = size;
- this._keys = keys;
- this._values = (V[]) values;
- this._states = states;
- }
-
-}
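
For reference, the deleted table's API in use. This commit moves the collections classes into subpackages (the new tests live under hivemall.utils.collections.maps), so the import below uses the old package name from this hunk and would need adjusting to the class's new location. A small usage sketch, not part of the patch:

    import hivemall.utils.collections.IntOpenHashTable;

    public class IntOpenHashTableExample {
        public static void main(String[] args) {
            IntOpenHashTable<String> table = new IntOpenHashTable<String>(16);
            table.put(1, "one");
            table.put(2, "two");
            System.out.println(table.get(1));        // one
            System.out.println(table.put(1, "uno")); // one (put returns the previous value)
            System.out.println(table.remove(2));     // two (the slot is marked REMOVED, not FREE)
            System.out.println(table.size());        // 1
            System.out.println(table.get(3));        // null for an absent key
        }
    }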
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/LRUMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/LRUMap.java b/core/src/main/java/hivemall/utils/collections/LRUMap.java
deleted file mode 100644
index bfae4d7..0000000
--- a/core/src/main/java/hivemall/utils/collections/LRUMap.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections;
-
-import java.util.LinkedHashMap;
-import java.util.Map;
-
-public class LRUMap<K, V> extends LinkedHashMap<K, V> {
- private static final long serialVersionUID = -7708264099645977733L;
-
- private final int cacheSize;
-
- public LRUMap(int cacheSize) {
- this(cacheSize, 0.75f, cacheSize);
- }
-
- public LRUMap(int capacity, float loadFactor, int cacheSize) {
- super(capacity, loadFactor, true);
- this.cacheSize = cacheSize;
- }
-
- protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
- return size() > cacheSize;
- }
-}
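
LRUMap gets its eviction behavior entirely from LinkedHashMap: the super(capacity, loadFactor, true) call enables access ordering, and removeEldestEntry() fires after each put, dropping the least recently used entry once size() exceeds cacheSize. A short behavioral sketch (the driver class is hypothetical, not part of the patch):

    import hivemall.utils.collections.LRUMap;

    public class LRUMapExample {
        public static void main(String[] args) {
            LRUMap<String, Integer> cache = new LRUMap<String, Integer>(2); // keep at most 2 entries
            cache.put("a", 1);
            cache.put("b", 2);
            cache.get("a");    // touching "a" makes "b" the eldest entry
            cache.put("c", 3); // exceeds cacheSize, so "b" is evicted
            System.out.println(cache.keySet()); // [a, c] in access order
        }
    }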
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/OpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/OpenHashMap.java b/core/src/main/java/hivemall/utils/collections/OpenHashMap.java
deleted file mode 100644
index b1f5765..0000000
--- a/core/src/main/java/hivemall/utils/collections/OpenHashMap.java
+++ /dev/null
@@ -1,350 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-//
-// Copyright (C) 2010 catchpole.net
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-package hivemall.utils.collections;
-
-import hivemall.utils.lang.Copyable;
-import hivemall.utils.math.MathUtils;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * An optimized open-addressing hash map implementation.
- * <p>
- * This map does not allow nulls to be used as keys or values.
- * <p>
- * It uses open hashing arrays sized to binary powers (256, 512, etc.) rather than sizes
- * divisible by prime numbers, which allows the hash offset calculation to be a simple binary
- * masking operation.
- */
-public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
- private K[] keys;
- private V[] values;
-
- // total number of entries in this table
- private int size;
- // number of bits for the value table (e.g. 8 bits = 256 entries)
- private int bits;
- // the number of bits in each sweep zone.
- private int sweepbits;
- // the size of a sweep (2 to the power of sweepbits)
- private int sweep;
- // the sweepmask used to create sweep zone offsets
- private int sweepmask;
-
- public OpenHashMap() {}// for Externalizable
-
- public OpenHashMap(int size) {
- resize(MathUtils.bitsRequired(size < 256 ? 256 : size));
- }
-
- public V put(K key, V value) {
- if (key == null) {
- throw new NullPointerException(this.getClass().getName() + " key");
- }
-
- for (;;) {
- int off = getBucketOffset(key);
- int end = off + sweep;
- for (; off < end; off++) {
- K searchKey = keys[off];
- if (searchKey == null) {
- // insert
- keys[off] = key;
- size++;
-
- V previous = values[off];
- values[off] = value;
- return previous;
- } else if (compare(searchKey, key)) {
- // replace
- V previous = values[off];
- values[off] = value;
- return previous;
- }
- }
- resize(this.bits + 1);
- }
- }
-
- public V get(Object key) {
- int off = getBucketOffset(key);
- int end = sweep + off;
- for (; off < end; off++) {
- if (keys[off] != null && compare(keys[off], key)) {
- return values[off];
- }
- }
- return null;
- }
-
- public V remove(Object key) {
- int off = getBucketOffset(key);
- int end = sweep + off;
- for (; off < end; off++) {
- if (keys[off] != null && compare(keys[off], key)) {
- keys[off] = null;
- V previous = values[off];
- values[off] = null;
- size--;
- return previous;
- }
- }
- return null;
- }
-
- public int size() {
- return size;
- }
-
- public void putAll(Map<? extends K, ? extends V> m) {
- for (K key : m.keySet()) {
- put(key, m.get(key));
- }
- }
-
- public boolean isEmpty() {
- return size == 0;
- }
-
- public boolean containsKey(Object key) {
- return get(key) != null;
- }
-
- public boolean containsValue(Object value) {
- for (V v : values) {
- if (v != null && compare(v, value)) {
- return true;
- }
- }
- return false;
- }
-
- public void clear() {
- Arrays.fill(keys, null);
- Arrays.fill(values, null);
- size = 0;
- }
-
- public Set<K> keySet() {
- Set<K> set = new HashSet<K>();
- for (K key : keys) {
- if (key != null) {
- set.add(key);
- }
- }
- return set;
- }
-
- public Collection<V> values() {
- Collection<V> list = new ArrayList<V>();
- for (V value : values) {
- if (value != null) {
- list.add(value);
- }
- }
- return list;
- }
-
- public Set<Entry<K, V>> entrySet() {
- Set<Entry<K, V>> set = new HashSet<Entry<K, V>>();
- for (K key : keys) {
- if (key != null) {
- set.add(new MapEntry<K, V>(this, key));
- }
- }
- return set;
- }
-
- private static final class MapEntry<K, V> implements Map.Entry<K, V> {
- private final Map<K, V> map;
- private final K key;
-
- public MapEntry(Map<K, V> map, K key) {
- this.map = map;
- this.key = key;
- }
-
- public K getKey() {
- return key;
- }
-
- public V getValue() {
- return map.get(key);
- }
-
- public V setValue(V value) {
- return map.put(key, value);
- }
- }
-
- public void writeExternal(ObjectOutput out) throws IOException {
- // remember the number of bits
- out.writeInt(this.bits);
- // remember the total number of entries
- out.writeInt(this.size);
- // write all entries
- for (int x = 0; x < this.keys.length; x++) {
- if (keys[x] != null) {
- out.writeObject(keys[x]);
- out.writeObject(values[x]);
- }
- }
- }
-
- @SuppressWarnings("unchecked")
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- // resize to old bit size
- int bitSize = in.readInt();
- if (bitSize != bits) {
- resize(bitSize);
- }
- // read all entries
- int size = in.readInt();
- for (int x = 0; x < size; x++) {
- this.put((K) in.readObject(), (V) in.readObject());
- }
- }
-
- @Override
- public String toString() {
- return this.getClass().getSimpleName() + ' ' + this.size;
- }
-
- @SuppressWarnings("unchecked")
- private void resize(int bits) {
- this.bits = bits;
- this.sweepbits = bits / 4;
- this.sweep = MathUtils.powerOf(2, sweepbits) * 4;
- this.sweepmask = MathUtils.bitMask(bits - this.sweepbits) << sweepbits;
-
- // remember old values so we can recreate the entries
- K[] existingKeys = this.keys;
- V[] existingValues = this.values;
-
- // create the arrays
- this.values = (V[]) new Object[MathUtils.powerOf(2, bits) + sweep];
- this.keys = (K[]) new Object[values.length];
- this.size = 0;
-
- // re-add the previous entries if resizing
- if (existingKeys != null) {
- for (int x = 0; x < existingKeys.length; x++) {
- if (existingKeys[x] != null) {
- put(existingKeys[x], existingValues[x]);
- }
- }
- }
- }
-
- private int getBucketOffset(Object key) {
- return (key.hashCode() << this.sweepbits) & this.sweepmask;
- }
-
- private static boolean compare(final Object v1, final Object v2) {
- return v1 == v2 || v1.equals(v2);
- }
-
- public IMapIterator<K, V> entries() {
- return new MapIterator();
- }
-
- private final class MapIterator implements IMapIterator<K, V> {
-
- int nextEntry;
- int lastEntry = -1;
-
- MapIterator() {
- this.nextEntry = nextEntry(0);
- }
-
- /** Finds the index of the next full entry */
- int nextEntry(int index) {
- while (index < keys.length && keys[index] == null) {
- index++;
- }
- return index;
- }
-
- @Override
- public boolean hasNext() {
- return nextEntry < keys.length;
- }
-
- @Override
- public int next() {
- free(lastEntry);
- if (!hasNext()) {
- return -1;
- }
- int curEntry = nextEntry;
- this.lastEntry = curEntry;
- this.nextEntry = nextEntry(curEntry + 1);
- return curEntry;
- }
-
- @Override
- public K getKey() {
- return keys[lastEntry];
- }
-
- @Override
- public V getValue() {
- return values[lastEntry];
- }
-
- @Override
- public <T extends Copyable<V>> void getValue(T probe) {
- probe.copyFrom(getValue());
- }
-
- private void free(int index) {
- if (index >= 0) {
- keys[index] = null;
- values[index] = null;
- }
- }
-
- }
-}
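
The sweep-zone addressing that the class comment describes is easiest to see with the derived fields written out: resize() sets sweepbits = bits / 4, a sweep of 2^sweepbits * 4 contiguous slots, and a sweepmask that aligns zone starts to multiples of 2^sweepbits; getBucketOffset() masks the shifted hash to pick a zone, put() scans at most sweep slots linearly before growing the table, and get()/remove() scan the same window. A standalone sketch of that arithmetic, with plain shifts in place of the MathUtils helpers (hypothetical, not part of the patch):

    public final class SweepZoneDemo {
        public static void main(String[] args) {
            final int bits = 8;                           // 2^8 = 256 base slots
            final int sweepbits = bits / 4;               // 2
            final int sweep = (1 << sweepbits) * 4;       // 16-slot linear scan per zone
            final int sweepmask = ((1 << (bits - sweepbits)) - 1) << sweepbits;
            // arrays are sized 2^bits + sweep so a scan never runs off the end
            final int tableLength = (1 << bits) + sweep;  // 272

            final int hash = "someKey".hashCode();
            final int off = (hash << sweepbits) & sweepmask; // zone start, a multiple of 2^sweepbits
            System.out.println("scan slots " + off + ".." + (off + sweep - 1)
                    + " of " + tableLength);
        }
    }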