You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/04/09 21:32:16 UTC
[04/12] incubator-hivemall git commit: Close #51: [HIVEMALL-75]
Support Sparse Vector Format as the input of RandomForest
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
new file mode 100644
index 0000000..6a7f39f
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
+ */
+public final class Long2FloatOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+ protected float defaultReturnValue = 0.f;
+
+ protected long[] _keys;
+ protected float[] _values;
+ protected byte[] _states;
+
+ protected Long2FloatOpenHashTable(int size, float loadFactor, float growFactor,
+ boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new long[actualSize];
+ this._values = new float[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Long2FloatOpenHashTable(int size, float loadFactor, float growFactor) {// float factors, so fractional values such as 0.7f can be passed; size is rounded up to a prime
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public Long2FloatOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ public Long2FloatOpenHashTable() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public void defaultReturnValue(float v) {
+ this.defaultReturnValue = v;
+ }
+
+ public boolean containsKey(final long key) {
+ return _findKey(key) >= 0;
+ }
+
+ /**
+ * @return defaultReturnValue if not found
+ */
+ public float get(final long key) {
+ return get(key, defaultReturnValue);
+ }
+
+ public float get(final long key, final float defaultValue) {
+ final int i = _findKey(key);
+ if (i < 0) {
+ return defaultValue;
+ }
+ return _values[i];
+ }
+
+ public float _get(final int index) {
+ if (index < 0) {
+ return defaultReturnValue;
+ }
+ return _values[index];
+ }
+
+ /** Maps key to value; returns the replaced value, or defaultReturnValue if a new entry was inserted. */
+ public float put(final long key, final float value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ if (preAddEntry(keyIdx)) {// the table grew, so recompute the home slot
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final long[] keys = _keys;
+ final float[] values = _values;
+ final byte[] states = _states;
+ // Probe unless the home slot is reusable: overwriting another key's tombstone could duplicate this key.
+ if (!isFree(keyIdx, key)) {// double hashing
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Return whether the required slot is free for new entry */
+ protected boolean isFree(final int index, final long key) {
+ final byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// too filled
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * @return -1 if not found
+ */
+ public int _findKey(final long key) {
+ final long[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public float remove(final long key) {
+ final long[] keys = _keys;
+ final float[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ float old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(final int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(final int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final long[] newkeys = new long[newCapacity];
+ final float[] newValues = new float[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ long k = _keys[i];
+ float v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final long key) {
+ return (int) (key ^ (key >>> 32)) & 0x7FFFFFFF;
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ out.writeLong(i.getKey());
+ out.writeFloat(i.getValue());
+ }
+ }
+
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ final int keylen = in.readInt();
+ final long[] keys = new long[keylen];
+ final float[] values = new float[keylen];
+ final byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ long k = in.readLong();
+ float v = in.readFloat();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return -1 if not found
+ */
+ public int next();
+
+ public long getKey();
+
+ public float getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public long getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public float getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
new file mode 100644
index 0000000..51b8f12
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
@@ -0,0 +1,473 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table with double hashing
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
+ */
+public final class Long2IntOpenHashTable implements Externalizable {
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected final transient float _loadFactor;
+ protected final transient float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+ protected int defaultReturnValue = -1;
+
+ protected long[] _keys;
+ protected int[] _values;
+ protected byte[] _states;
+
+ protected Long2IntOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._keys = new long[actualSize];
+ this._values = new int[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ public Long2IntOpenHashTable(int size, float loadFactor, float growFactor) {// float factors, so fractional values such as 0.7f can be passed; size is rounded up to a prime
+ this(size, loadFactor, growFactor, true);
+ }
+
+ public Long2IntOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ public Long2IntOpenHashTable() {// required for serialization
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ public void defaultReturnValue(int v) {
+ this.defaultReturnValue = v;
+ }
+
+ public boolean containsKey(final long key) {
+ return _findKey(key) >= 0;
+ }
+
+ /**
+ * @return defaultReturnValue if not found
+ */
+ public int get(final long key) {
+ return get(key, defaultReturnValue);
+ }
+
+ public int get(final long key, final int defaultValue) {
+ final int i = _findKey(key);
+ if (i < 0) {
+ return defaultValue;
+ }
+ return _values[i];
+ }
+
+ public int _get(final int index) {
+ if (index < 0) {
+ return defaultReturnValue;
+ }
+ return _values[index];
+ }
+
+ /** Maps key to value; returns the replaced value, or defaultReturnValue if a new entry was inserted. */
+ public int put(final long key, final int value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ if (preAddEntry(keyIdx)) {// the table grew, so recompute the home slot
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final long[] keys = _keys;
+ final int[] values = _values;
+ final byte[] states = _states;
+ // Probe unless the home slot is reusable: overwriting another key's tombstone could duplicate this key.
+ if (!isFree(keyIdx, key)) {// double hashing
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Adds delta to the value of key (a new key starts at delta); returns the prior value, or defaultReturnValue for a new entry. */
+ public int incr(final long key, final int delta) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ if (preAddEntry(keyIdx)) {// the table grew, so recompute the home slot
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final long[] keys = _keys;
+ final int[] values = _values;
+ final byte[] states = _states;
+ // Probe unless the home slot is reusable: overwriting another key's tombstone could duplicate this key.
+ if (!isFree(keyIdx, key)) {// double hashing
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] += delta;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ values[keyIdx] += delta;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = delta;// assign, not add: remove() and clear() reset only the state, so the slot may hold a stale value
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /** Return whether the required slot is free for new entry */
+ protected boolean isFree(final int index, final long key) {
+ final byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && _keys[index] == key) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// too filled
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * @return -1 if not found
+ */
+ public int _findKey(final long key) {
+ final long[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public int remove(final long key) {
+ final long[] keys = _keys;
+ final int[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ int old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ buf.append(i.getKey());
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(final int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ private void rehash(final int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final long[] newkeys = new long[newCapacity];
+ final int[] newValues = new int[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ long k = _keys[i];
+ int v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final long key) {
+ return (int) (key ^ (key >>> 32)) & 0x7FFFFFFF;
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+
+ out.writeInt(_keys.length);
+ IMapIterator i = entries();
+ while (i.next() != -1) {
+ out.writeLong(i.getKey());
+ out.writeInt(i.getValue());
+ }
+ }
+
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ final int keylen = in.readInt();
+ final long[] keys = new long[keylen];
+ final int[] values = new int[keylen];
+ final byte[] states = new byte[keylen];
+ for (int i = 0; i < _used; i++) {
+ long k = in.readLong();
+ int v = in.readInt();
+ int hash = keyHash(k);
+ int keyIdx = hash % keylen;
+ if (states[keyIdx] != FREE) {// second hash
+ int decr = 1 + (hash % (keylen - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keylen;
+ }
+ if (states[keyIdx] == FREE) {
+ break;
+ }
+ }
+ }
+ states[keyIdx] = FULL;
+ keys[keyIdx] = k;
+ values[keyIdx] = v;
+ }
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public interface IMapIterator {
+
+ public boolean hasNext();
+
+ /**
+ * @return -1 if not found
+ */
+ public int next();
+
+ public long getKey();
+
+ public int getValue();
+
+ }
+
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;
+ this.nextEntry = nextEntry(curEntry + 1);
+ return curEntry;
+ }
+
+ public long getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public int getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
new file mode 100644
index 0000000..152447a
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+//
+// Copyright (C) 2010 catchpole.net
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.math.MathUtils;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * An optimized Hashed Map implementation.
+ * <p/>
+ * <p>
+ * This Hashmap does not allow nulls to be used as keys or values.
+ * <p/>
+ * <p>
+ * It uses single open hashing arrays sized to binary powers (256, 512 etc) rather than those
+ * divisible by prime numbers. This allows the hash offset calculation to be a simple binary masking
+ * operation.
+ */
+public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
+ private K[] keys;
+ private V[] values;
+
+ // total number of entries in this table
+ private int size;
+ // number of bits for the value table (eg. 8 bits = 256 entries)
+ private int bits;
+ // the number of bits in each sweep zone.
+ private int sweepbits;
+ // the size of a sweep (2 to the power of sweepbits)
+ private int sweep;
+ // the sweepmask used to create sweep zone offsets
+ private int sweepmask;
+
+ public OpenHashMap() {}// for Externalizable
+
+ public OpenHashMap(int size) {
+ resize(MathUtils.bitsRequired(size < 256 ? 256 : size));
+ }
+
+ public V put(K key, V value) {
+ if (key == null) {
+ throw new NullPointerException(this.getClass().getName() + " key");
+ }
+
+ for (;;) {
+ int off = getBucketOffset(key);
+ int end = off + sweep;
+ for (; off < end; off++) {
+ K searchKey = keys[off];
+ if (searchKey == null) {
+ // insert
+ keys[off] = key;
+ size++;
+
+ V previous = values[off];
+ values[off] = value;
+ return previous;
+ } else if (compare(searchKey, key)) {
+ // replace
+ V previous = values[off];
+ values[off] = value;
+ return previous;
+ }
+ }
+ resize(this.bits + 1);
+ }
+ }
+
+ public V get(Object key) {
+ int off = getBucketOffset(key);
+ int end = sweep + off;
+ for (; off < end; off++) {
+ if (keys[off] != null && compare(keys[off], key)) {
+ return values[off];
+ }
+ }
+ return null;
+ }
+
+ public V remove(Object key) {
+ int off = getBucketOffset(key);
+ int end = sweep + off;
+ for (; off < end; off++) {
+ if (keys[off] != null && compare(keys[off], key)) {
+ keys[off] = null;
+ V previous = values[off];
+ values[off] = null;
+ size--;
+ return previous;
+ }
+ }
+ return null;
+ }
+
+ public int size() {
+ return size;
+ }
+
+ public void putAll(Map<? extends K, ? extends V> m) {
+ for (K key : m.keySet()) {
+ put(key, m.get(key));
+ }
+ }
+
+ public boolean isEmpty() {
+ return size == 0;
+ }
+
+ public boolean containsKey(Object key) {
+ return get(key) != null;
+ }
+
+ public boolean containsValue(Object value) {
+ for (V v : values) {
+ if (v != null && compare(v, value)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public void clear() {
+ Arrays.fill(keys, null);
+ Arrays.fill(values, null);
+ size = 0;
+ }
+
+ public Set<K> keySet() {
+ Set<K> set = new HashSet<K>();
+ for (K key : keys) {
+ if (key != null) {
+ set.add(key);
+ }
+ }
+ return set;
+ }
+
+ public Collection<V> values() {
+ Collection<V> list = new ArrayList<V>();
+ for (V value : values) {
+ if (value != null) {
+ list.add(value);
+ }
+ }
+ return list;
+ }
+
+ public Set<Entry<K, V>> entrySet() {
+ Set<Entry<K, V>> set = new HashSet<Entry<K, V>>();
+ for (K key : keys) {
+ if (key != null) {
+ set.add(new MapEntry<K, V>(this, key));
+ }
+ }
+ return set;
+ }
+
+ private static final class MapEntry<K, V> implements Map.Entry<K, V> {
+ private final Map<K, V> map;
+ private final K key;
+
+ public MapEntry(Map<K, V> map, K key) {
+ this.map = map;
+ this.key = key;
+ }
+
+ public K getKey() {
+ return key;
+ }
+
+ public V getValue() {
+ return map.get(key);
+ }
+
+ public V setValue(V value) {
+ return map.put(key, value);
+ }
+ }
+
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // remember the number of bits
+ out.writeInt(this.bits);
+ // remember the total number of entries
+ out.writeInt(this.size);
+ // write all entries
+ for (int x = 0; x < this.keys.length; x++) {
+ if (keys[x] != null) {
+ out.writeObject(keys[x]);
+ out.writeObject(values[x]);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ // resize to old bit size
+ int bitSize = in.readInt();
+ if (bitSize != bits) {
+ resize(bitSize);
+ }
+ // read all entries
+ int size = in.readInt();
+ for (int x = 0; x < size; x++) {
+ this.put((K) in.readObject(), (V) in.readObject());
+ }
+ }
+
+ @Override
+ public String toString() {
+ return this.getClass().getSimpleName() + ' ' + this.size;
+ }
+
+ @SuppressWarnings("unchecked")
+ private void resize(int bits) {
+ this.bits = bits;
+ this.sweepbits = bits / 4;
+ this.sweep = MathUtils.powerOf(2, sweepbits) * 4;
+ this.sweepmask = MathUtils.bitMask(bits - this.sweepbits) << sweepbits;
+
+ // remember old values so we can recreate the entries
+ K[] existingKeys = this.keys;
+ V[] existingValues = this.values;
+
+ // create the arrays
+ this.values = (V[]) new Object[MathUtils.powerOf(2, bits) + sweep];
+ this.keys = (K[]) new Object[values.length];
+ this.size = 0;
+
+ // re-add the previous entries if resizing
+ if (existingKeys != null) {
+ for (int x = 0; x < existingKeys.length; x++) {
+ if (existingKeys[x] != null) {
+ put(existingKeys[x], existingValues[x]);
+ }
+ }
+ }
+ }
+
+ private int getBucketOffset(Object key) {
+ return (key.hashCode() << this.sweepbits) & this.sweepmask;
+ }
+
+ private static boolean compare(final Object v1, final Object v2) {
+ return v1 == v2 || v1.equals(v2);
+ }
+
+ public IMapIterator<K, V> entries() {
+ return new MapIterator();
+ }
+
+ private final class MapIterator implements IMapIterator<K, V> {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < keys.length && keys[index] == null) {
+ index++;
+ }
+ return index;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return nextEntry < keys.length;
+ }
+
+ @Override
+ public int next() {// advances to the next occupied slot; returns its index, or -1 when exhausted
+ free(lastEntry);// NOTE: nulls out the key/value of the entry returned by the previous call; size is not adjusted
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = curEntry;// remembered so getKey()/getValue() read this slot
+ this.nextEntry = nextEntry(curEntry + 1);// pre-compute the following occupied slot for hasNext()
+ return curEntry;
+ }
+
+ @Override
+ public K getKey() {
+ return keys[lastEntry];
+ }
+
+ @Override
+ public V getValue() {
+ return values[lastEntry];
+ }
+
+ @Override
+ public <T extends Copyable<V>> void getValue(T probe) {
+ probe.copyFrom(getValue());
+ }
+
+ private void free(int index) {
+ if (index >= 0) {
+ keys[index] = null;
+ values[index] = null;
+ }
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
new file mode 100644
index 0000000..7fec9b0
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+import java.util.HashMap;
+
+import javax.annotation.Nonnull;
+
+/**
+ * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}.
+ */
+public final class OpenHashTable<K, V> implements Externalizable {
+
+ public static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ public static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
+
+ protected/* final */float _loadFactor;
+ protected/* final */float _growFactor;
+
+ protected int _used = 0;
+ protected int _threshold;
+
+ protected K[] _keys;
+ protected V[] _values;
+ protected byte[] _states;
+
+ public OpenHashTable() {} // for Externalizable
+
+ public OpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR);
+ }
+
+ @SuppressWarnings("unchecked")
+ public OpenHashTable(int size, float loadFactor, float growFactor) {
+ if (size < 1) {
+ throw new IllegalArgumentException();
+ }
+ this._loadFactor = loadFactor;
+ this._growFactor = growFactor;
+ int actualSize = Primes.findLeastPrimeNumber(size);
+ this._keys = (K[]) new Object[actualSize];
+ this._values = (V[]) new Object[actualSize];
+ this._states = new byte[actualSize];
+ this._threshold = Math.round(actualSize * _loadFactor);
+ }
+
+ public OpenHashTable(@Nonnull K[] keys, @Nonnull V[] values, @Nonnull byte[] states, int used) {
+ this._used = used;
+ this._threshold = keys.length;
+ this._keys = keys;
+ this._values = values;
+ this._states = states;
+ }
+
+ public Object[] getKeys() {
+ return _keys;
+ }
+
+ public Object[] getValues() {
+ return _values;
+ }
+
+ public byte[] getStates() {
+ return _states;
+ }
+
+ public boolean containsKey(final K key) {
+ return findKey(key) >= 0;
+ }
+
+ public V get(final K key) {
+ final int i = findKey(key);
+ if (i < 0) {
+ return null;
+ }
+ return _values[i];
+ }
+
+ public V put(final K key, final V value) {
+ int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ K[] keys = _keys;
+ V[] values = _values;
+ byte[] states = _states;
+
+ if (states[keyIdx] == FULL) {
+ if (equals(keys[keyIdx], key)) {
+ V old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ V old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return null;
+ }
+
+ private static boolean equals(final Object k1, final Object k2) {
+ return k1 == k2 || k1.equals(k2);
+ }
+
+ /** Return weather the required slot is free for new entry */
+ protected boolean isFree(int index, K key) {
+ byte stat = _states[index];
+ if (stat == FREE) {
+ return true;
+ }
+ if (stat == REMOVED && equals(_keys[index], key)) {
+ return true;
+ }
+ return false;
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(int index) {
+ if ((_used + 1) >= _threshold) {// filled enough
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ protected int findKey(final K key) {
+ K[] keys = _keys;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ return keyIdx;
+ }
+ // try second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ public V remove(final K key) {
+ K[] keys = _keys;
+ V[] values = _values;
+ byte[] states = _states;
+ int keyLength = keys.length;
+
+ int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ V old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (states[keyIdx] == FREE) {
+ return null;
+ }
+ if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) {
+ V old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return null;
+ }
+
+ public int size() {
+ return _used;
+ }
+
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ public IMapIterator<K, V> entries() {
+ return new MapIterator();
+ }
+
+ @Override
+ public String toString() {
+ int len = size() * 10 + 2;
+ final StringBuilder buf = new StringBuilder(len);
+ buf.append('{');
+ final IMapIterator<K, V> i = entries();
+ while (i.next() != -1) {
+ String key = i.getKey().toString();
+ buf.append(key);
+ buf.append('=');
+ buf.append(i.getValue());
+ if (i.hasNext()) {
+ buf.append(',');
+ }
+ }
+ buf.append('}');
+ return buf.toString();
+ }
+
+ protected void ensureCapacity(int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ @SuppressWarnings("unchecked")
+ private void rehash(int newCapacity) {
+ int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final K[] newkeys = (K[]) new Object[newCapacity];
+ final V[] newValues = (V[]) new Object[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
+ int used = 0;
+ for (int i = 0; i < oldCapacity; i++) {
+ if (_states[i] == FULL) {
+ used++;
+ K k = _keys[i];
+ V v = _values[i];
+ int hash = keyHash(k);
+ int keyIdx = hash % newCapacity;
+ if (newStates[keyIdx] == FULL) {// second hashing
+ int decr = 1 + (hash % (newCapacity - 2));
+ while (newStates[keyIdx] != FREE) {
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += newCapacity;
+ }
+ }
+ }
+ newStates[keyIdx] = FULL;
+ newkeys[keyIdx] = k;
+ newValues[keyIdx] = v;
+ }
+ }
+ this._keys = newkeys;
+ this._values = newValues;
+ this._states = newStates;
+ this._used = used;
+ }
+
+ private static int keyHash(final Object key) {
+ int hash = key.hashCode();
+ return hash & 0x7fffffff;
+ }
+
+ private final class MapIterator implements IMapIterator<K, V> {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of next full entry */
+ int nextEntry(int index) {
+ while (index < _keys.length && _states[index] != FULL) {
+ index++;
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ public int next() {
+ if (!hasNext()) {
+ return -1;
+ }
+ int curEntry = nextEntry;
+ this.lastEntry = nextEntry;
+ this.nextEntry = nextEntry(nextEntry + 1);
+ return curEntry;
+ }
+
+ public K getKey() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _keys[lastEntry];
+ }
+
+ public V getValue() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return _values[lastEntry];
+ }
+
+ @Override
+ public <T extends Copyable<V>> void getValue(T probe) {
+ probe.copyFrom(getValue());
+ }
+ }
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeFloat(_loadFactor);
+ out.writeFloat(_growFactor);
+ out.writeInt(_used);
+
+ final int size = _keys.length;
+ out.writeInt(size);
+
+ for (int i = 0; i < size; i++) {
+ out.writeObject(_keys[i]);
+ out.writeObject(_values[i]);
+ out.writeByte(_states[i]);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._loadFactor = in.readFloat();
+ this._growFactor = in.readFloat();
+ this._used = in.readInt();
+
+ final int size = in.readInt();
+ final Object[] keys = new Object[size];
+ final Object[] values = new Object[size];
+ final byte[] states = new byte[size];
+ for (int i = 0; i < size; i++) {
+ keys[i] = in.readObject();
+ values[i] = in.readObject();
+ states[i] = in.readByte();
+ }
+ this._threshold = size;
+ this._keys = (K[]) keys;
+ this._values = (V[]) values;
+ this._states = states;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java b/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java
new file mode 100644
index 0000000..06b6a15
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.sets;
+
+import hivemall.utils.lang.ArrayUtils;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+public final class IntArraySet implements IntSet {
+
+ @Nonnull
+ private int[] mKeys;
+ private int mSize;
+
+ public IntArraySet() {
+ this(0);
+ }
+
+ public IntArraySet(int initSize) {
+ this.mKeys = new int[initSize];
+ this.mSize = 0;
+ }
+
+ @Override
+ public boolean add(final int k) {
+ final int i = Arrays.binarySearch(mKeys, 0, mSize, k);
+ if (i >= 0) {
+ return false;
+ }
+ mKeys = ArrayUtils.insert(mKeys, mSize, ~i, k);
+ mSize++;
+ return true;
+ }
+
+ @Override
+ public boolean remove(final int k) {
+ final int i = Arrays.binarySearch(mKeys, 0, mSize, k);
+ if (i < 0) {
+ return false;
+ }
+ System.arraycopy(mKeys, i + 1, mKeys, i, mSize - (i + 1));
+ mSize--;
+ return true;
+ }
+
+ @Override
+ public boolean contains(final int k) {
+ return Arrays.binarySearch(mKeys, 0, mSize, k) >= 0;
+ }
+
+ @Override
+ public int size() {
+ return mSize;
+ }
+
+ @Override
+ public void clear() {
+ this.mSize = 0;
+ }
+
+ @Override
+ public int[] toArray(final boolean copy) {
+ if (copy == false && mKeys.length == mSize) {
+ return mKeys;
+ }
+
+ return Arrays.copyOf(mKeys, mSize);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/sets/IntSet.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/sets/IntSet.java b/core/src/main/java/hivemall/utils/collections/sets/IntSet.java
new file mode 100644
index 0000000..398955c
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/sets/IntSet.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.sets;
+
+import javax.annotation.Nonnull;
+
+public interface IntSet {
+
+ public boolean add(int k);
+
+ public boolean remove(int k);
+
+ public boolean contains(int k);
+
+ public int size();
+
+ public void clear();
+
+ @Nonnull
+ public int[] toArray(boolean copy);
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 5423c9d..b3a2de1 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -56,6 +56,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.StandardConstantListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
@@ -256,6 +258,10 @@ public final class HiveUtils {
&& isNumberListOI(((ListObjectInspector) oi).getListElementObjectInspector());
}
+ public static boolean isConstString(@Nonnull final ObjectInspector oi) {
+ return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi);
+ }
+
public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) {
return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE;
}
@@ -308,20 +314,43 @@ public final class HiveUtils {
}
}
- public static boolean isStringTypeInfo(@Nonnull TypeInfo typeInfo) {
+ public static boolean isIntTypeInfo(@Nonnull TypeInfo typeInfo) {
+ if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) {
+ return false;
+ }
+ return ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory() == PrimitiveCategory.INT;
+ }
+
+ public static boolean isFloatingPointTypeInfo(@Nonnull TypeInfo typeInfo) {
if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) {
return false;
}
switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) {
- case STRING:
+ case DOUBLE:
+ case FLOAT:
return true;
default:
return false;
}
}
- public static boolean isConstString(@Nonnull final ObjectInspector oi) {
- return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi);
+ public static boolean isStringTypeInfo(@Nonnull TypeInfo typeInfo) {
+ if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) {
+ return false;
+ }
+ return ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory() == PrimitiveCategory.STRING;
+ }
+
+ public static boolean isListTypeInfo(@Nonnull TypeInfo typeInfo) {
+ return typeInfo.getCategory() == Category.LIST;
+ }
+
+ public static boolean isFloatingPointListTypeInfo(@Nonnull TypeInfo typeInfo) {
+ if (typeInfo.getCategory() != Category.LIST) {
+ return false;
+ }
+ TypeInfo elemTypeInfo = ((ListTypeInfo) typeInfo).getListElementTypeInfo();
+ return isFloatingPointTypeInfo(elemTypeInfo);
}
@Nonnull
@@ -387,6 +416,38 @@ public final class HiveUtils {
return ary;
}
+ @Nullable
+ public static double[] getConstDoubleArray(@Nonnull final ObjectInspector oi)
+ throws UDFArgumentException {
+ if (!ObjectInspectorUtils.isConstantObjectInspector(oi)) {
+ throw new UDFArgumentException("argument must be a constant value: "
+ + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
+ }
+ ConstantObjectInspector constOI = (ConstantObjectInspector) oi;
+ if (constOI.getCategory() != Category.LIST) {
+ throw new UDFArgumentException("argument must be an array: "
+ + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
+ }
+ StandardConstantListObjectInspector listOI = (StandardConstantListObjectInspector) constOI;
+ PrimitiveObjectInspector elemOI = HiveUtils.asDoubleCompatibleOI(listOI.getListElementObjectInspector());
+
+ final List<?> lst = listOI.getWritableConstantValue();
+ if (lst == null) {
+ return null;
+ }
+ final int size = lst.size();
+ final double[] ary = new double[size];
+ for (int i = 0; i < size; i++) {
+ Object o = lst.get(i);
+ if (o == null) {
+ ary[i] = Double.NaN;
+ } else {
+ ary[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
+ }
+ }
+ return ary;
+ }
+
public static String getConstString(@Nonnull final ObjectInspector oi)
throws UDFArgumentException {
if (!isStringOI(oi)) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 24ed7fc..e8e337d 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -170,6 +170,24 @@ public final class ArrayUtils {
arr[j] = tmp;
}
+ public static void swap(@Nonnull final long[] arr, final int i, final int j) {
+ long tmp = arr[i];
+ arr[i] = arr[j];
+ arr[j] = tmp;
+ }
+
+ public static void swap(@Nonnull final float[] arr, final int i, final int j) {
+ float tmp = arr[i];
+ arr[i] = arr[j];
+ arr[j] = tmp;
+ }
+
+ public static void swap(@Nonnull final double[] arr, final int i, final int j) {
+ double tmp = arr[i];
+ arr[i] = arr[j];
+ arr[j] = tmp;
+ }
+
@Nullable
public static Object[] subarray(@Nullable final Object[] array, int startIndexInclusive,
int endIndexExclusive) {
@@ -198,7 +216,7 @@ public final class ArrayUtils {
}
}
- public static int indexOf(@Nonnull final int[] array, final int valueToFind,
+ public static int indexOf(@Nullable final int[] array, final int valueToFind,
final int startIndex, final int endIndex) {
if (array == null) {
return INDEX_NOT_FOUND;
@@ -215,6 +233,36 @@ public final class ArrayUtils {
return INDEX_NOT_FOUND;
}
+ public static int lastIndexOf(@Nullable final int[] array, final int valueToFind, int startIndex) {
+ if (array == null) {
+ return INDEX_NOT_FOUND;
+ }
+ return lastIndexOf(array, valueToFind, startIndex, array.length);
+ }
+
+ /**
+ * @param startIndex inclusive start index
+ * @param endIndex exclusive end index
+ */
+ public static int lastIndexOf(@Nullable final int[] array, final int valueToFind,
+ int startIndex, int endIndex) {
+ if (array == null) {
+ return INDEX_NOT_FOUND;
+ }
+ if (startIndex < 0) {
+ throw new IllegalArgumentException("startIndex out of bound: " + startIndex);
+ }
+ if (endIndex >= array.length) {
+ throw new IllegalArgumentException("endIndex out of bound: " + endIndex);
+ }
+ for (int i = endIndex - 1; i >= startIndex; i--) {
+ if (valueToFind == array[i]) {
+ return i;
+ }
+ }
+ return INDEX_NOT_FOUND;
+ }
+
@Nonnull
public static byte[] copyOf(@Nonnull final byte[] original, final int newLength) {
final byte[] copy = new byte[newLength];
@@ -249,6 +297,17 @@ public final class ArrayUtils {
}
@Nonnull
+ public static float[] append(@Nonnull float[] array, final int currentSize, final float element) {
+ if (currentSize + 1 > array.length) {
+ float[] newArray = new float[currentSize * 2];
+ System.arraycopy(array, 0, newArray, 0, currentSize);
+ array = newArray;
+ }
+ array[currentSize] = element;
+ return array;
+ }
+
+ @Nonnull
public static double[] append(@Nonnull double[] array, final int currentSize,
final double element) {
if (currentSize + 1 > array.length) {
@@ -268,7 +327,22 @@ public final class ArrayUtils {
array[index] = element;
return array;
}
- int[] newArray = new int[currentSize * 2];
+ final int[] newArray = new int[currentSize * 2];
+ System.arraycopy(array, 0, newArray, 0, index);
+ newArray[index] = element;
+ System.arraycopy(array, index, newArray, index + 1, array.length - index);
+ return newArray;
+ }
+
+ @Nonnull
+ public static float[] insert(@Nonnull final float[] array, final int currentSize,
+ final int index, final float element) {
+ if (currentSize + 1 <= array.length) {
+ System.arraycopy(array, index, array, index + 1, currentSize - index);
+ array[index] = element;
+ return array;
+ }
+ final float[] newArray = new float[currentSize * 2];
System.arraycopy(array, 0, newArray, 0, index);
newArray[index] = element;
System.arraycopy(array, index, newArray, index + 1, array.length - index);
@@ -283,7 +357,7 @@ public final class ArrayUtils {
array[index] = element;
return array;
}
- double[] newArray = new double[currentSize * 2];
+ final double[] newArray = new double[currentSize * 2];
System.arraycopy(array, 0, newArray, 0, index);
newArray[index] = element;
System.arraycopy(array, index, newArray, index + 1, array.length - index);
@@ -314,4 +388,331 @@ public final class ArrayUtils {
return true;
}
+ public static void copy(@Nonnull final float[] src, @Nonnull final double[] dst) {
+ final int size = Math.min(src.length, dst.length);
+ for (int i = 0; i < size; i++) {
+ dst[i] = src[i];
+ }
+ }
+
+ public static void sort(final long[] arr, final double[] brr) { // sorts all of arr ascending and reorders brr the same way
+ sort(arr, brr, arr.length);
+ }
+
+ public static void sort(final long[] arr, final double[] brr, final int n) { // sorts arr[0..n) ascending in place; every move is mirrored on brr
+ final int NSTACK = 64; // capacity of the explicit stack of pending sub-ranges
+ final int M = 7; // sub-ranges with ir - l < M are finished by insertion sort
+ final int[] istack = new int[NSTACK]; // holds (l, ir) pairs; replaces recursion
+
+ int jstack = -1; // index of the top of istack, -1 when empty
+ int l = 0; // left bound (inclusive) of the active sub-range
+ int ir = n - 1; // right bound (inclusive); n is not checked against the array lengths here
+
+ int i, j, k;
+ long a; // key being inserted / partitioning element
+ double b; // companion value travelling with a
+ for (;;) {
+ if (ir - l < M) { // small sub-range: straight insertion sort
+ for (j = l + 1; j <= ir; j++) {
+ a = arr[j];
+ b = brr[j];
+ for (i = j - 1; i >= l; i--) {
+ if (arr[i] <= a) {
+ break;
+ }
+ arr[i + 1] = arr[i]; // shift larger keys one slot to the right
+ brr[i + 1] = brr[i];
+ }
+ arr[i + 1] = a;
+ brr[i + 1] = b;
+ }
+ if (jstack < 0) { // nothing pending: the whole range is sorted
+ break;
+ }
+ ir = istack[jstack--]; // pop the next pending sub-range
+ l = istack[jstack--];
+ } else {
+ k = (l + ir) >> 1; // middle index of the sub-range
+ swap(arr, k, l + 1); // move the middle element next to the left bound
+ swap(brr, k, l + 1);
+ if (arr[l] > arr[ir]) { // the three compares below establish arr[l] <= arr[l + 1] <= arr[ir]
+ swap(arr, l, ir);
+ swap(brr, l, ir);
+ }
+ if (arr[l + 1] > arr[ir]) {
+ swap(arr, l + 1, ir);
+ swap(brr, l + 1, ir);
+ }
+ if (arr[l] > arr[l + 1]) {
+ swap(arr, l, l + 1);
+ swap(brr, l, l + 1);
+ }
+ i = l + 1;
+ j = ir;
+ a = arr[l + 1]; // partitioning element (median of the three samples)
+ b = brr[l + 1];
+ for (;;) {
+ do {
+ i++;
+ } while (arr[i] < a); // arr[ir] >= a stops this scan
+ do {
+ j--;
+ } while (arr[j] > a); // arr[l] <= a stops this scan
+ if (j < i) { // scans crossed: partitioning is complete
+ break;
+ }
+ swap(arr, i, j);
+ swap(brr, i, j);
+ }
+ arr[l + 1] = arr[j]; // put the partitioning element into its final slot j
+ arr[j] = a;
+ brr[l + 1] = brr[j];
+ brr[j] = b;
+ jstack += 2;
+
+ if (jstack >= NSTACK) {
+ throw new IllegalStateException("NSTACK too small in sort.");
+ }
+
+ if (ir - i + 1 >= j - l) { // push the larger partition, continue with the smaller one
+ istack[jstack] = ir;
+ istack[jstack - 1] = i;
+ ir = j - 1;
+ } else {
+ istack[jstack] = j - 1;
+ istack[jstack - 1] = l;
+ l = i;
+ }
+ }
+ }
+ }
+
+ public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr,
+ @Nonnull final double[] crr) { // sorts all of arr ascending and reorders brr and crr the same way
+ sort(arr, brr, crr, arr.length);
+ }
+
+ public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr,
+ @Nonnull final double[] crr, final int n) { // sorts arr[0..n) ascending in place; every move is mirrored on brr and crr
+ Preconditions.checkArgument(arr.length >= n); // all three arrays must hold at least n elements
+ Preconditions.checkArgument(brr.length >= n);
+ Preconditions.checkArgument(crr.length >= n);
+
+ final int NSTACK = 64; // capacity of the explicit stack of pending sub-ranges
+ final int M = 7; // sub-ranges with ir - l < M are finished by insertion sort
+ final int[] istack = new int[NSTACK]; // holds (l, ir) pairs; replaces recursion
+
+ int jstack = -1; // index of the top of istack, -1 when empty
+ int l = 0; // left bound (inclusive) of the active sub-range
+ int ir = n - 1; // right bound (inclusive) of the active sub-range
+
+ int i, j, k;
+ int a, b; // a: key being inserted / partitioning element; b: its companion from brr
+ double c; // companion from crr
+ for (;;) {
+ if (ir - l < M) { // small sub-range: straight insertion sort
+ for (j = l + 1; j <= ir; j++) {
+ a = arr[j];
+ b = brr[j];
+ c = crr[j];
+ for (i = j - 1; i >= l; i--) {
+ if (arr[i] <= a) {
+ break;
+ }
+ arr[i + 1] = arr[i]; // shift larger keys one slot to the right
+ brr[i + 1] = brr[i];
+ crr[i + 1] = crr[i];
+ }
+ arr[i + 1] = a;
+ brr[i + 1] = b;
+ crr[i + 1] = c;
+ }
+ if (jstack < 0) { // nothing pending: the whole range is sorted
+ break;
+ }
+ ir = istack[jstack--]; // pop the next pending sub-range
+ l = istack[jstack--];
+ } else {
+ k = (l + ir) >> 1; // middle index of the sub-range
+ swap(arr, k, l + 1); // move the middle element next to the left bound
+ swap(brr, k, l + 1);
+ swap(crr, k, l + 1);
+ if (arr[l] > arr[ir]) { // the three compares below establish arr[l] <= arr[l + 1] <= arr[ir]
+ swap(arr, l, ir);
+ swap(brr, l, ir);
+ swap(crr, l, ir);
+ }
+ if (arr[l + 1] > arr[ir]) {
+ swap(arr, l + 1, ir);
+ swap(brr, l + 1, ir);
+ swap(crr, l + 1, ir);
+ }
+ if (arr[l] > arr[l + 1]) {
+ swap(arr, l, l + 1);
+ swap(brr, l, l + 1);
+ swap(crr, l, l + 1);
+ }
+ i = l + 1;
+ j = ir;
+ a = arr[l + 1]; // partitioning element (median of the three samples)
+ b = brr[l + 1];
+ c = crr[l + 1];
+ for (;;) {
+ do {
+ i++;
+ } while (arr[i] < a); // arr[ir] >= a stops this scan
+ do {
+ j--;
+ } while (arr[j] > a); // arr[l] <= a stops this scan
+ if (j < i) { // scans crossed: partitioning is complete
+ break;
+ }
+ swap(arr, i, j);
+ swap(brr, i, j);
+ swap(crr, i, j);
+ }
+ arr[l + 1] = arr[j]; // put the partitioning element into its final slot j
+ arr[j] = a;
+ brr[l + 1] = brr[j];
+ brr[j] = b;
+ crr[l + 1] = crr[j];
+ crr[j] = c;
+ jstack += 2;
+
+ if (jstack >= NSTACK) {
+ throw new IllegalStateException("NSTACK too small in sort.");
+ }
+
+ if (ir - i + 1 >= j - l) { // push the larger partition, continue with the smaller one
+ istack[jstack] = ir;
+ istack[jstack - 1] = i;
+ ir = j - 1;
+ } else {
+ istack[jstack] = j - 1;
+ istack[jstack - 1] = l;
+ l = i;
+ }
+ }
+ }
+ }
+
+ public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr,
+ @Nonnull final float[] crr) { // sorts all of arr ascending and reorders brr and crr the same way
+ sort(arr, brr, crr, arr.length);
+ }
+
+ public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr,
+ @Nonnull final float[] crr, final int n) { // sorts arr[0..n) ascending in place; every move is mirrored on brr and crr
+ Preconditions.checkArgument(arr.length >= n); // all three arrays must hold at least n elements
+ Preconditions.checkArgument(brr.length >= n);
+ Preconditions.checkArgument(crr.length >= n);
+
+ final int NSTACK = 64; // capacity of the explicit stack of pending sub-ranges
+ final int M = 7; // sub-ranges with ir - l < M are finished by insertion sort
+ final int[] istack = new int[NSTACK]; // holds (l, ir) pairs; replaces recursion
+
+ int jstack = -1; // index of the top of istack, -1 when empty
+ int l = 0; // left bound (inclusive) of the active sub-range
+ int ir = n - 1; // right bound (inclusive) of the active sub-range
+
+ int i, j, k;
+ int a, b; // a: key being inserted / partitioning element; b: its companion from brr
+ float c; // companion from crr
+ for (;;) {
+ if (ir - l < M) { // small sub-range: straight insertion sort
+ for (j = l + 1; j <= ir; j++) {
+ a = arr[j];
+ b = brr[j];
+ c = crr[j];
+ for (i = j - 1; i >= l; i--) {
+ if (arr[i] <= a) {
+ break;
+ }
+ arr[i + 1] = arr[i]; // shift larger keys one slot to the right
+ brr[i + 1] = brr[i];
+ crr[i + 1] = crr[i];
+ }
+ arr[i + 1] = a;
+ brr[i + 1] = b;
+ crr[i + 1] = c;
+ }
+ if (jstack < 0) { // nothing pending: the whole range is sorted
+ break;
+ }
+ ir = istack[jstack--]; // pop the next pending sub-range
+ l = istack[jstack--];
+ } else {
+ k = (l + ir) >> 1; // middle index of the sub-range
+ swap(arr, k, l + 1); // move the middle element next to the left bound
+ swap(brr, k, l + 1);
+ swap(crr, k, l + 1);
+ if (arr[l] > arr[ir]) { // the three compares below establish arr[l] <= arr[l + 1] <= arr[ir]
+ swap(arr, l, ir);
+ swap(brr, l, ir);
+ swap(crr, l, ir);
+ }
+ if (arr[l + 1] > arr[ir]) {
+ swap(arr, l + 1, ir);
+ swap(brr, l + 1, ir);
+ swap(crr, l + 1, ir);
+ }
+ if (arr[l] > arr[l + 1]) {
+ swap(arr, l, l + 1);
+ swap(brr, l, l + 1);
+ swap(crr, l, l + 1);
+ }
+ i = l + 1;
+ j = ir;
+ a = arr[l + 1]; // partitioning element (median of the three samples)
+ b = brr[l + 1];
+ c = crr[l + 1];
+ for (;;) {
+ do {
+ i++;
+ } while (arr[i] < a); // arr[ir] >= a stops this scan
+ do {
+ j--;
+ } while (arr[j] > a); // arr[l] <= a stops this scan
+ if (j < i) { // scans crossed: partitioning is complete
+ break;
+ }
+ swap(arr, i, j);
+ swap(brr, i, j);
+ swap(crr, i, j);
+ }
+ arr[l + 1] = arr[j]; // put the partitioning element into its final slot j
+ arr[j] = a;
+ brr[l + 1] = brr[j];
+ brr[j] = b;
+ crr[l + 1] = crr[j];
+ crr[j] = c;
+ jstack += 2;
+
+ if (jstack >= NSTACK) {
+ throw new IllegalStateException("NSTACK too small in sort.");
+ }
+
+ if (ir - i + 1 >= j - l) { // push the larger partition, continue with the smaller one
+ istack[jstack] = ir;
+ istack[jstack - 1] = i;
+ ir = j - 1;
+ } else {
+ istack[jstack] = j - 1;
+ istack[jstack - 1] = l;
+ l = i;
+ }
+ }
+ }
+ }
+
+ public static int count(@Nonnull final int[] values, final int valueToFind) {
+ int cnt = 0;
+ for (int i = 0; i < values.length; i++) {
+ if (values[i] == valueToFind) {
+ cnt++;
+ }
+ }
+ return cnt;
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/lang/Primitives.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java
index 8f018f0..31cd8a8 100644
--- a/core/src/main/java/hivemall/utils/lang/Primitives.java
+++ b/core/src/main/java/hivemall/utils/lang/Primitives.java
@@ -18,6 +18,8 @@
*/
package hivemall.utils.lang;
+import javax.annotation.Nonnull;
+
public final class Primitives {
public static final int INT_BYTES = Integer.SIZE / Byte.SIZE;
public static final int DOUBLE_BYTES = Double.SIZE / Byte.SIZE;
@@ -99,4 +101,30 @@ public final class Primitives {
return result;
}
+ public static long toLong(final int high, final int low) {
+ return ((long) high << 32) | ((long) low & 0xffffffffL);
+ }
+
+ public static int getHigh(final long key) {
+ return (int) (key >>> 32) & 0xffffffff;
+ }
+
+ public static int getLow(final long key) {
+ return (int) key & 0xffffffff;
+ }
+
+ @Nonnull
+ public static byte[] toBytes(long l) {
+ final byte[] retVal = new byte[8];
+ for (int i = 0; i < 8; i++) {
+ retVal[i] = (byte) l;
+ l >>= 8;
+ }
+ return retVal;
+ }
+
+ public static int hashCode(final long value) {
+ return (int) (value ^ (value >>> 32));
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 252ccf6..b71d165 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -36,6 +36,7 @@ package hivemall.utils.math;
import java.util.Random;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
public final class MathUtils {
@@ -250,6 +251,9 @@ public final class MathUtils {
}
public static boolean equals(@Nonnull final float value, final float expected, final float delta) {
+ if (Double.isNaN(value)) {
+ return false;
+ }
if (Math.abs(expected - value) > delta) {
return false;
}
@@ -258,19 +262,20 @@ public final class MathUtils {
public static boolean equals(@Nonnull final double value, final double expected,
final double delta) {
+ if (Double.isNaN(value)) {
+ return false;
+ }
if (Math.abs(expected - value) > delta) {
return false;
}
return true;
}
- public static boolean almostEquals(@Nonnull final float value, final float expected,
- final float delta) {
+ public static boolean almostEquals(@Nonnull final float value, final float expected) {
return equals(value, expected, 1E-15f);
}
- public static boolean almostEquals(@Nonnull final double value, final double expected,
- final double delta) {
+ public static boolean almostEquals(@Nonnull final double value, final double expected) {
return equals(value, expected, 1E-15d);
}
@@ -297,4 +302,13 @@ public final class MathUtils {
return 0; // 0 or NaN
}
+ @Nonnull
+ public static int[] permutation(@Nonnegative final int size) {
+ final int[] perm = new int[size];
+ for (int i = 0; i < size; i++) {
+ perm[i] = i;
+ }
+ return perm;
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/math/MatrixUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MatrixUtils.java b/core/src/main/java/hivemall/utils/math/MatrixUtils.java
index 66d6e8c..a0e5fc7 100644
--- a/core/src/main/java/hivemall/utils/math/MatrixUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MatrixUtils.java
@@ -18,7 +18,7 @@
*/
package hivemall.utils.math;
-import hivemall.utils.collections.DoubleArrayList;
+import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.lang.Preconditions;
import java.util.Arrays;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java b/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java
new file mode 100644
index 0000000..f86a788
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.sampling;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Vitter's reservoir sampling implementation that randomly chooses k items from a list containing n items.
+ *
+ * @link http://en.wikipedia.org/wiki/Reservoir_sampling
+ * @link http://portal.acm.org/citation.cfm?id=3165
+ */
+public final class IntReservoirSampler {
+
+ private final int[] samples;
+ private final int numSamples;
+ private int position;
+
+ private final Random rand;
+
+ public IntReservoirSampler(int sampleSize) {
+ if (sampleSize <= 0) {
+ throw new IllegalArgumentException("sampleSize must be greater than 1: " + sampleSize);
+ }
+ this.samples = new int[sampleSize];
+ this.numSamples = sampleSize;
+ this.position = 0;
+ this.rand = new Random();
+ }
+
+ public IntReservoirSampler(int sampleSize, long seed) {
+ this.samples = new int[sampleSize];
+ this.numSamples = sampleSize;
+ this.position = 0;
+ this.rand = new Random(seed);
+ }
+
+ public IntReservoirSampler(int[] samples) {
+ this.samples = samples;
+ this.numSamples = samples.length;
+ this.position = 0;
+ this.rand = new Random();
+ }
+
+ public IntReservoirSampler(int[] samples, long seed) {
+ this.samples = samples;
+ this.numSamples = samples.length;
+ this.position = 0;
+ this.rand = new Random(seed);
+ }
+
+ public int size() {
+ return position;
+ }
+
+ @Nonnull
+ public int[] getSample() {
+ if (position >= numSamples) {
+ return samples;
+ }
+ return Arrays.copyOf(samples, position);
+ }
+
+ public void add(final int item) {
+ if (position < numSamples) {// reservoir not yet full, just append
+ samples[position] = item;
+ } else {// find a item to replace
+ int replaceIndex = rand.nextInt(position + 1);
+ if (replaceIndex < numSamples) {
+ samples[replaceIndex] = item;
+ }
+ }
+ position++;
+ }
+
+ public void clear() {
+ Arrays.fill(samples, 0);
+ this.position = 0;
+ }
+}