You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/11 05:36:45 UTC
[2/4] incubator-hivemall git commit: Close #105: [HIVEMALL-24-2] Make
ffm_predict function more scalable by creating its UDAF implementation
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java
new file mode 100644
index 0000000..ffa80d0
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+//
+// Copyright (C) 2010 catchpole.net
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.hashing.HashUtils;
+import hivemall.utils.math.MathUtils;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.NotThreadSafe;
+
+/**
+ * A space efficient open-addressing HashMap implementation with integer keys and long values.
+ *
+ * Unlike {@link Int2LongOpenHashTable}, it keeps no per-slot state array: an empty slot is marked by
+ * key 0 in the keys array, and key 0 itself is stored outside the arrays.
+ *
+ * It uses single open hashing arrays sized to binary powers (256, 512 etc) rather than those
+ * divisible by prime numbers. This allows the hash offset calculation to be a simple binary masking
+ * operation.
+ *
+ * The index into the arrays is determined by masking a portion of the key's hash and shifting it to
+ * provide a series of small buckets within the array. To insert an entry, the sweep starting at that
+ * offset is searched for an empty key slot. A sweep is 4 times the length of a bucket, to reduce the
+ * need to rehash. If no empty slot is found within a sweep, the table size is doubled.
+ *
+ * Removing an entry (via {@link #remove(int, long)} or the draining {@link MapIterator}) can leave an
+ * empty slot in front of another key of the same sweep. Until the next resize or {@link #clear()},
+ * lookups and inserts therefore scan the entire sweep; otherwise they stop at the first empty slot.
+ *
+ * While performance is high, the slowest situation is a lookup of an absent key after removals, as
+ * an entire sweep area must then be searched. However, this HashMap is more space efficient than
+ * other open-addressing HashMap implementations as in fastutil.
+ */
+@NotThreadSafe
+public final class Int2LongOpenHashMap {
+
+    /** Returned by {@link #locate(int)} when the key is absent and its sweep has no empty slot. */
+    private static final int NO_ROOM = Integer.MIN_VALUE;
+
+    // special treatment for key=0, since 0 marks an empty slot in keys[]
+    private boolean hasKey0 = false;
+    private long value0 = 0L;
+
+    private int[] keys;
+    private long[] values;
+
+    // total number of entries in this table (including key=0)
+    private int size;
+    // number of bits for the value table (eg. 8 bits = 256 entries)
+    private int bits;
+    // the number of bits in each sweep zone.
+    private int sweepbits;
+    // the size of a sweep (4 times 2 to the power of sweepbits)
+    private int sweep;
+    // the sweepmask used to create sweep zone offsets
+    private int sweepmask;
+    // true once a slot was emptied since the last resize/clear; a key may then sit after an empty
+    // slot within its sweep, so searches must cover the whole sweep
+    private boolean hasHoles = false;
+
+    /**
+     * @param size expected number of entries; values below 256 are raised to 256 before the table
+     *        is sized via {@code MathUtils.bitsRequired}
+     */
+    public Int2LongOpenHashMap(int size) {
+        resize(MathUtils.bitsRequired(size < 256 ? 256 : size));
+    }
+
+    /**
+     * Associates {@code value} with {@code key}, replacing any existing mapping.
+     *
+     * @return the previous value, or 0L if the key was absent
+     */
+    public long put(final int key, final long value) {
+        if (key == 0) {
+            if (!hasKey0) {
+                this.hasKey0 = true;
+                size++;
+            }
+            final long old = value0;
+            this.value0 = value;
+            return old;
+        }
+
+        for (;;) {
+            final int slot = locate(key);
+            if (slot >= 0) { // replace
+                final long previous = values[slot];
+                values[slot] = value;
+                return previous;
+            }
+            if (slot != NO_ROOM) { // insert into the first empty slot of the sweep
+                return insertAt(-slot - 1, key, value);
+            }
+            resize(this.bits + 1);
+        }
+    }
+
+    /**
+     * Associates {@code value} with {@code key} only if the key has no mapping yet.
+     *
+     * @return the existing value if the key was present, otherwise 0L
+     */
+    public long putIfAbsent(final int key, final long value) {
+        if (key == 0) {
+            if (hasKey0) {
+                return value0;
+            }
+            this.hasKey0 = true;
+            final long old = value0;
+            this.value0 = value;
+            size++;
+            return old;
+        }
+
+        for (;;) {
+            final int slot = locate(key);
+            if (slot >= 0) { // already present
+                return values[slot];
+            }
+            if (slot != NO_ROOM) {
+                return insertAt(-slot - 1, key, value);
+            }
+            resize(this.bits + 1);
+        }
+    }
+
+    public long get(final int key) {
+        return get(key, 0L);
+    }
+
+    /**
+     * @return the value mapped to {@code key}, or {@code defaultValue} if there is none
+     */
+    public long get(final int key, final long defaultValue) {
+        if (key == 0) {
+            return hasKey0 ? value0 : defaultValue;
+        }
+        final int slot = locate(key);
+        return (slot >= 0) ? values[slot] : defaultValue;
+    }
+
+    /**
+     * Removes the mapping for {@code key}.
+     *
+     * @return the removed value, or {@code defaultValue} if the key was absent
+     */
+    public long remove(final int key, final long defaultValue) {
+        if (key == 0) {
+            if (hasKey0) {
+                this.hasKey0 = false;
+                final long old = value0;
+                this.value0 = 0L;
+                size--;
+                return old;
+            } else {
+                return defaultValue;
+            }
+        }
+
+        final int slot = locate(key);
+        if (slot < 0) {
+            return defaultValue;
+        }
+        final long previous = values[slot];
+        keys[slot] = 0;
+        values[slot] = 0L;
+        size--;
+        // the emptied slot may precede other keys of overlapping sweeps
+        this.hasHoles = true;
+        return previous;
+    }
+
+    public int size() {
+        return size;
+    }
+
+    public boolean isEmpty() {
+        return size == 0;
+    }
+
+    public boolean containsKey(final int key) {
+        if (key == 0) {
+            return hasKey0;
+        }
+        return locate(key) >= 0;
+    }
+
+    public void clear() {
+        this.hasKey0 = false;
+        this.value0 = 0L;
+        Arrays.fill(keys, 0);
+        Arrays.fill(values, 0L);
+        this.size = 0;
+        this.hasHoles = false;
+    }
+
+    @Override
+    public String toString() {
+        return this.getClass().getSimpleName() + ' ' + size;
+    }
+
+    /** Stores a new (absent) key into the empty slot {@code idx}; empty slots always hold 0L. */
+    private long insertAt(final int idx, final int key, final long value) {
+        keys[idx] = key;
+        final long previous = values[idx];
+        values[idx] = value;
+        size++;
+        return previous;
+    }
+
+    /**
+     * Searches the sweep of a non-zero {@code key}.
+     *
+     * @return the slot index holding {@code key} if present; otherwise {@code -(i + 1)} where
+     *         {@code i} is the first empty slot in the sweep; or {@link #NO_ROOM} when the key is
+     *         absent and the sweep has no empty slot
+     */
+    private int locate(final int key) {
+        final int start = getBucketOffset(key);
+        final int end = start + sweep;
+        int firstEmpty = -1;
+        for (int i = start; i < end; i++) {
+            final int k = keys[i];
+            if (k == key) {
+                return i;
+            }
+            if (k == 0 && firstEmpty < 0) {
+                firstEmpty = i;
+                if (!hasHoles) {
+                    // without holes, every key sits before the first empty slot of its sweep
+                    break;
+                }
+            }
+        }
+        return (firstEmpty >= 0) ? -(firstEmpty + 1) : NO_ROOM;
+    }
+
+    /** Rebuilds the arrays for a table of 2^bits entries (plus one sweep of slack). */
+    private void resize(final int bits) {
+        this.bits = bits;
+        this.sweepbits = bits / 4;
+        this.sweep = MathUtils.powerOf(2, sweepbits) * 4;
+        this.sweepmask = MathUtils.bitMask(bits - sweepbits) << sweepbits;
+
+        // remember old values so we can recreate the entries
+        final int[] existingKeys = this.keys;
+        final long[] existingValues = this.values;
+
+        // create the arrays
+        this.values = new long[MathUtils.powerOf(2, bits) + sweep];
+        this.keys = new int[values.length];
+        this.size = hasKey0 ? 1 : 0;
+        // freshly built arrays contain no holes
+        this.hasHoles = false;
+
+        // re-add the previous entries if resizing (put may trigger a nested resize)
+        if (existingKeys != null) {
+            for (int i = 0; i < existingKeys.length; i++) {
+                final int k = existingKeys[i];
+                if (k != 0) {
+                    put(k, existingValues[i]);
+                }
+            }
+        }
+    }
+
+    private int getBucketOffset(final int key) {
+        return (HashUtils.fnv1a(key) << sweepbits) & sweepmask;
+    }
+
+    /**
+     * Returns a draining iterator: each call to {@link MapIterator#next()} removes the entry the
+     * previous call positioned on.
+     */
+    @Nonnull
+    public MapIterator entries() {
+        return new MapIterator();
+    }
+
+    public final class MapIterator {
+
+        int nextEntry;
+        // -2: no current entry; -1: current entry is key 0; >= 0: current slot index
+        int lastEntry = -2;
+
+        MapIterator() {
+            this.nextEntry = nextEntry(-1);
+        }
+
+        /** find the index of next full entry; -1 denotes the out-of-line key 0 */
+        int nextEntry(int index) {
+            if (index == -1) {
+                if (hasKey0) {
+                    return -1;
+                } else {
+                    index = 0;
+                }
+            }
+            while (index < keys.length && keys[index] == 0) {
+                index++;
+            }
+            return index;
+        }
+
+        public boolean hasNext() {
+            return nextEntry < keys.length;
+        }
+
+        /**
+         * Removes the current entry (if any) from the map and advances to the next one.
+         *
+         * @return false when there are no more entries
+         */
+        public boolean next() {
+            free(lastEntry);
+            this.lastEntry = -2;
+            if (!hasNext()) {
+                return false;
+            }
+            final int curEntry = nextEntry;
+            this.lastEntry = curEntry;
+            this.nextEntry = nextEntry(curEntry + 1);
+            return true;
+        }
+
+        public int getKey() {
+            if (lastEntry >= 0 && lastEntry < keys.length) {
+                return keys[lastEntry];
+            } else if (lastEntry == -1) {
+                return 0;
+            } else {
+                throw new IllegalStateException(
+                    "next() should be called before getKey(). lastEntry=" + lastEntry
+                            + ", keys.length=" + keys.length);
+            }
+        }
+
+        public long getValue() {
+            if (lastEntry >= 0 && lastEntry < keys.length) {
+                return values[lastEntry];
+            } else if (lastEntry == -1) {
+                return value0;
+            } else {
+                throw new IllegalStateException(
+                    "next() should be called before getValue(). lastEntry=" + lastEntry
+                            + ", keys.length=" + keys.length);
+            }
+        }
+
+        /** Removes the entry at {@code index} from the map, keeping size/hole tracking in sync. */
+        private void free(final int index) {
+            if (index >= 0) {
+                if (index >= keys.length) {
+                    throw new IllegalStateException("index=" + index + ", keys.length="
+                            + keys.length);
+                }
+                if (keys[index] != 0) { // skip slots already emptied, e.g. by remove()
+                    keys[index] = 0;
+                    values[index] = 0L;
+                    size--;
+                    hasHoles = true;
+                }
+            } else if (index == -1) {
+                if (hasKey0) {
+                    hasKey0 = false;
+                    value0 = 0L;
+                    size--;
+                }
+            }
+            // index == -2: nothing to free
+        }
+
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java
index 68eb42f..22acdb4 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java
@@ -33,7 +33,12 @@ import java.util.Arrays;
import javax.annotation.Nonnull;
/**
- * An open-addressing hash table with double hashing
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod(m-2))
+ * </pre>
*
* @see http://en.wikipedia.org/wiki/Double_hashing
*/
@@ -44,7 +49,7 @@ public class Int2LongOpenHashTable implements Externalizable {
protected static final byte REMOVED = 2;
public static final int DEFAULT_SIZE = 65536;
- public static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ public static final float DEFAULT_LOAD_FACTOR = 0.75f;
public static final float DEFAULT_GROW_FACTOR = 2.0f;
protected final transient float _loadFactor;
@@ -123,23 +128,23 @@ public class Int2LongOpenHashTable implements Externalizable {
return _states;
}
- public boolean containsKey(int key) {
+ public boolean containsKey(final int key) {
return findKey(key) >= 0;
}
/**
* @return -1.f if not found
*/
- public long get(int key) {
- int i = findKey(key);
+ public long get(final int key) {
+ final int i = findKey(key);
if (i < 0) {
return defaultReturnValue;
}
return _values[i];
}
- public long put(int key, long value) {
- int hash = keyHash(key);
+ public long put(final int key, final long value) {
+ final int hash = keyHash(key);
int keyLength = _keys.length;
int keyIdx = hash % keyLength;
@@ -149,9 +154,9 @@ public class Int2LongOpenHashTable implements Externalizable {
keyIdx = hash % keyLength;
}
- int[] keys = _keys;
- long[] values = _values;
- byte[] states = _states;
+ final int[] keys = _keys;
+ final long[] values = _values;
+ final byte[] states = _states;
if (states[keyIdx] == FULL) {// double hashing
if (keys[keyIdx] == key) {
@@ -160,7 +165,7 @@ public class Int2LongOpenHashTable implements Externalizable {
return old;
}
// try second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -184,8 +189,8 @@ public class Int2LongOpenHashTable implements Externalizable {
}
/** Return weather the required slot is free for new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
+ protected boolean isFree(final int index, final int key) {
+ final byte stat = _states[index];
if (stat == FREE) {
return true;
}
@@ -196,7 +201,7 @@ public class Int2LongOpenHashTable implements Externalizable {
}
/** @return expanded or not */
- protected boolean preAddEntry(int index) {
+ protected boolean preAddEntry(final int index) {
if ((_used + 1) >= _threshold) {// too filled
int newCapacity = Math.round(_keys.length * _growFactor);
ensureCapacity(newCapacity);
@@ -205,19 +210,19 @@ public class Int2LongOpenHashTable implements Externalizable {
return false;
}
- protected int findKey(int key) {
- int[] keys = _keys;
- byte[] states = _states;
- int keyLength = keys.length;
+ protected int findKey(final int key) {
+ final int[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
- int hash = keyHash(key);
+ final int hash = keyHash(key);
int keyIdx = hash % keyLength;
if (states[keyIdx] != FREE) {
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
return keyIdx;
}
// try second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -234,13 +239,13 @@ public class Int2LongOpenHashTable implements Externalizable {
return -1;
}
- public long remove(int key) {
- int[] keys = _keys;
- long[] values = _values;
- byte[] states = _states;
- int keyLength = keys.length;
+ public long remove(final int key) {
+ final int[] keys = _keys;
+ final long[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
- int hash = keyHash(key);
+ final int hash = keyHash(key);
int keyIdx = hash % keyLength;
if (states[keyIdx] != FREE) {
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
@@ -250,7 +255,7 @@ public class Int2LongOpenHashTable implements Externalizable {
return old;
}
// second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -283,21 +288,22 @@ public class Int2LongOpenHashTable implements Externalizable {
this._used = 0;
}
- public IMapIterator entries() {
+ @Nonnull
+ public MapIterator entries() {
return new MapIterator();
}
@Override
public String toString() {
int len = size() * 10 + 2;
- StringBuilder buf = new StringBuilder(len);
+ final StringBuilder buf = new StringBuilder(len);
buf.append('{');
- IMapIterator i = entries();
- while (i.next() != -1) {
- buf.append(i.getKey());
+ final MapIterator itor = entries();
+ while (itor.next() != -1) {
+ buf.append(itor.getKey());
buf.append('=');
- buf.append(i.getValue());
- if (i.hasNext()) {
+ buf.append(itor.getValue());
+ if (itor.hasNext()) {
buf.append(',');
}
}
@@ -305,30 +311,30 @@ public class Int2LongOpenHashTable implements Externalizable {
return buf.toString();
}
- protected void ensureCapacity(int newCapacity) {
+ protected void ensureCapacity(final int newCapacity) {
int prime = Primes.findLeastPrimeNumber(newCapacity);
rehash(prime);
this._threshold = Math.round(prime * _loadFactor);
}
- private void rehash(int newCapacity) {
+ private void rehash(final int newCapacity) {
int oldCapacity = _keys.length;
if (newCapacity <= oldCapacity) {
throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
}
- int[] newkeys = new int[newCapacity];
- long[] newValues = new long[newCapacity];
- byte[] newStates = new byte[newCapacity];
+ final int[] newkeys = new int[newCapacity];
+ final long[] newValues = new long[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
int used = 0;
for (int i = 0; i < oldCapacity; i++) {
if (_states[i] == FULL) {
used++;
- int k = _keys[i];
- long v = _values[i];
- int hash = keyHash(k);
+ final int k = _keys[i];
+ final long v = _values[i];
+ final int hash = keyHash(k);
int keyIdx = hash % newCapacity;
if (newStates[keyIdx] == FULL) {// second hashing
- int decr = 1 + (hash % (newCapacity - 2));
+ final int decr = 1 + (hash % (newCapacity - 2));
while (newStates[keyIdx] != FREE) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -347,7 +353,7 @@ public class Int2LongOpenHashTable implements Externalizable {
this._used = used;
}
- private static int keyHash(int key) {
+ private static int keyHash(final int key) {
return key & 0x7fffffff;
}
@@ -437,22 +443,7 @@ public class Int2LongOpenHashTable implements Externalizable {
}
}
- public interface IMapIterator {
-
- public boolean hasNext();
-
- /**
- * @return -1 if not found
- */
- public int next();
-
- public int getKey();
-
- public long getValue();
-
- }
-
- private final class MapIterator implements IMapIterator {
+ public final class MapIterator {
int nextEntry;
int lastEntry = -1;
@@ -473,6 +464,9 @@ public class Int2LongOpenHashTable implements Externalizable {
return nextEntry < _keys.length;
}
+ /**
+ * @return -1 if not found
+ */
public int next() {
if (!hasNext()) {
return -1;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
deleted file mode 100644
index 5ce34a4..0000000
--- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
+++ /dev/null
@@ -1,467 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.utils.collections.maps;
-
-import hivemall.utils.math.Primes;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-
-/**
- * An open-addressing hash table with double hashing
- *
- * @see http://en.wikipedia.org/wiki/Double_hashing
- */
-public class IntOpenHashMap<V> implements Externalizable {
- private static final long serialVersionUID = -8162355845665353513L;
-
- protected static final byte FREE = 0;
- protected static final byte FULL = 1;
- protected static final byte REMOVED = 2;
-
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
- private static final float DEFAULT_GROW_FACTOR = 2.0f;
-
- protected final transient float _loadFactor;
- protected final transient float _growFactor;
-
- protected int _used = 0;
- protected int _threshold;
-
- protected int[] _keys;
- protected V[] _values;
- protected byte[] _states;
-
- @SuppressWarnings("unchecked")
- protected IntOpenHashMap(int size, float loadFactor, float growFactor, boolean forcePrime) {
- if (size < 1) {
- throw new IllegalArgumentException();
- }
- this._loadFactor = loadFactor;
- this._growFactor = growFactor;
- int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
- this._keys = new int[actualSize];
- this._values = (V[]) new Object[actualSize];
- this._states = new byte[actualSize];
- this._threshold = Math.round(actualSize * _loadFactor);
- }
-
- public IntOpenHashMap(int size, float loadFactor, float growFactor) {
- this(size, loadFactor, growFactor, true);
- }
-
- public IntOpenHashMap(int size) {
- this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
- }
-
- public IntOpenHashMap() {// required for serialization
- this._loadFactor = DEFAULT_LOAD_FACTOR;
- this._growFactor = DEFAULT_GROW_FACTOR;
- }
-
- public boolean containsKey(int key) {
- return findKey(key) >= 0;
- }
-
- public final V get(final int key) {
- final int i = findKey(key);
- if (i < 0) {
- return null;
- }
- recordAccess(i);
- return _values[i];
- }
-
- public V put(final int key, final V value) {
- final int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- final boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- final int[] keys = _keys;
- final V[] values = _values;
- final byte[] states = _states;
-
- if (states[keyIdx] == FULL) {// double hashing
- if (keys[keyIdx] == key) {
- V old = values[keyIdx];
- values[keyIdx] = value;
- recordAccess(keyIdx);
- return old;
- }
- // try second hash
- final int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- values[keyIdx] = value;
- recordAccess(keyIdx);
- return old;
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- ++_used;
- postAddEntry(keyIdx);
- return null;
- }
-
- public V putIfAbsent(final int key, final V value) {
- final int hash = keyHash(key);
- int keyLength = _keys.length;
- int keyIdx = hash % keyLength;
-
- final boolean expanded = preAddEntry(keyIdx);
- if (expanded) {
- keyLength = _keys.length;
- keyIdx = hash % keyLength;
- }
-
- final int[] keys = _keys;
- final V[] values = _values;
- final byte[] states = _states;
-
- if (states[keyIdx] == FULL) {// second hashing
- if (keys[keyIdx] == key) {
- return values[keyIdx];
- }
- // try second hash
- final int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- break;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return values[keyIdx];
- }
- }
- }
- keys[keyIdx] = key;
- values[keyIdx] = value;
- states[keyIdx] = FULL;
- _used++;
- postAddEntry(keyIdx);
- return null;
- }
-
- /** Return weather the required slot is free for new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
- if (stat == FREE) {
- return true;
- }
- if (stat == REMOVED && _keys[index] == key) {
- return true;
- }
- return false;
- }
-
- /** @return expanded or not */
- protected boolean preAddEntry(int index) {
- if ((_used + 1) >= _threshold) {// too filled
- int newCapacity = Math.round(_keys.length * _growFactor);
- ensureCapacity(newCapacity);
- return true;
- }
- return false;
- }
-
- protected void postAddEntry(int index) {}
-
- private int findKey(int key) {
- int[] keys = _keys;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- // try second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (isFree(keyIdx, key)) {
- return -1;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- return keyIdx;
- }
- }
- }
- return -1;
- }
-
- public V remove(int key) {
- int[] keys = _keys;
- V[] values = _values;
- byte[] states = _states;
- int keyLength = keys.length;
-
- int hash = keyHash(key);
- int keyIdx = hash % keyLength;
- if (states[keyIdx] != FREE) {
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- recordRemoval(keyIdx);
- return old;
- }
- // second hash
- int decr = 1 + (hash % (keyLength - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keyLength;
- }
- if (states[keyIdx] == FREE) {
- return null;
- }
- if (states[keyIdx] == FULL && keys[keyIdx] == key) {
- V old = values[keyIdx];
- states[keyIdx] = REMOVED;
- --_used;
- recordRemoval(keyIdx);
- return old;
- }
- }
- }
- return null;
- }
-
- public int size() {
- return _used;
- }
-
- public void clear() {
- Arrays.fill(_states, FREE);
- this._used = 0;
- }
-
- @SuppressWarnings("unchecked")
- public IMapIterator<V> entries() {
- return new MapIterator();
- }
-
- @Override
- public String toString() {
- int len = size() * 10 + 2;
- StringBuilder buf = new StringBuilder(len);
- buf.append('{');
- IMapIterator<V> i = entries();
- while (i.next() != -1) {
- buf.append(i.getKey());
- buf.append('=');
- buf.append(i.getValue());
- if (i.hasNext()) {
- buf.append(',');
- }
- }
- buf.append('}');
- return buf.toString();
- }
-
- private void ensureCapacity(int newCapacity) {
- int prime = Primes.findLeastPrimeNumber(newCapacity);
- rehash(prime);
- this._threshold = Math.round(prime * _loadFactor);
- }
-
- @SuppressWarnings("unchecked")
- protected void rehash(int newCapacity) {
- int oldCapacity = _keys.length;
- if (newCapacity <= oldCapacity) {
- throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
- }
- final int[] oldKeys = _keys;
- final V[] oldValues = _values;
- final byte[] oldStates = _states;
- int[] newkeys = new int[newCapacity];
- V[] newValues = (V[]) new Object[newCapacity];
- byte[] newStates = new byte[newCapacity];
- int used = 0;
- for (int i = 0; i < oldCapacity; i++) {
- if (oldStates[i] == FULL) {
- used++;
- int k = oldKeys[i];
- V v = oldValues[i];
- int hash = keyHash(k);
- int keyIdx = hash % newCapacity;
- if (newStates[keyIdx] == FULL) {// second hashing
- int decr = 1 + (hash % (newCapacity - 2));
- while (newStates[keyIdx] != FREE) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += newCapacity;
- }
- }
- }
- newkeys[keyIdx] = k;
- newValues[keyIdx] = v;
- newStates[keyIdx] = FULL;
- }
- }
- this._keys = newkeys;
- this._values = newValues;
- this._states = newStates;
- this._used = used;
- }
-
- private static int keyHash(int key) {
- return key & 0x7fffffff;
- }
-
- protected void recordAccess(int idx) {}
-
- protected void recordRemoval(int idx) {}
-
- public void writeExternal(ObjectOutput out) throws IOException {
- out.writeInt(_threshold);
- out.writeInt(_used);
-
- out.writeInt(_keys.length);
- IMapIterator<V> i = entries();
- while (i.next() != -1) {
- out.writeInt(i.getKey());
- out.writeObject(i.getValue());
- }
- }
-
- @SuppressWarnings("unchecked")
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- this._threshold = in.readInt();
- this._used = in.readInt();
-
- int keylen = in.readInt();
- int[] keys = new int[keylen];
- V[] values = (V[]) new Object[keylen];
- byte[] states = new byte[keylen];
- for (int i = 0; i < _used; i++) {
- int k = in.readInt();
- V v = (V) in.readObject();
- int hash = keyHash(k);
- int keyIdx = hash % keylen;
- if (states[keyIdx] != FREE) {// second hash
- int decr = 1 + (hash % (keylen - 2));
- for (;;) {
- keyIdx -= decr;
- if (keyIdx < 0) {
- keyIdx += keylen;
- }
- if (states[keyIdx] == FREE) {
- break;
- }
- }
- }
- states[keyIdx] = FULL;
- keys[keyIdx] = k;
- values[keyIdx] = v;
- }
- this._keys = keys;
- this._values = values;
- this._states = states;
- }
-
- public interface IMapIterator<V> {
-
- public boolean hasNext();
-
- public int next();
-
- public int getKey();
-
- public V getValue();
-
- }
-
- @SuppressWarnings("rawtypes")
- private final class MapIterator implements IMapIterator {
-
- int nextEntry;
- int lastEntry = -1;
-
- MapIterator() {
- this.nextEntry = nextEntry(0);
- }
-
- /** find the index of next full entry */
- int nextEntry(int index) {
- while (index < _keys.length && _states[index] != FULL) {
- index++;
- }
- return index;
- }
-
- public boolean hasNext() {
- return nextEntry < _keys.length;
- }
-
- public int next() {
- if (!hasNext()) {
- return -1;
- }
- int curEntry = nextEntry;
- this.lastEntry = curEntry;
- this.nextEntry = nextEntry(curEntry + 1);
- return curEntry;
- }
-
- public int getKey() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _keys[lastEntry];
- }
-
- public V getValue() {
- if (lastEntry == -1) {
- throw new IllegalStateException();
- }
- return _values[lastEntry];
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
index dcb64d1..dbade74 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
@@ -25,54 +25,68 @@ import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
-import java.util.HashMap;
import javax.annotation.Nonnull;
/**
- * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}.
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod(m-2))
+ * </pre>
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
*/
public final class IntOpenHashTable<V> implements Externalizable {
+ private static final long serialVersionUID = -8162355845665353513L;
- public static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ public static final float DEFAULT_LOAD_FACTOR = 0.75f;
public static final float DEFAULT_GROW_FACTOR = 2.0f;
- public static final byte FREE = 0;
- public static final byte FULL = 1;
- public static final byte REMOVED = 2;
+ protected static final byte FREE = 0;
+ protected static final byte FULL = 1;
+ protected static final byte REMOVED = 2;
protected/* final */float _loadFactor;
protected/* final */float _growFactor;
- protected int _used = 0;
+ protected int _used;
protected int _threshold;
protected int[] _keys;
protected V[] _values;
protected byte[] _states;
- public IntOpenHashTable() {} // for Externalizable
+ public IntOpenHashTable() {} // for Externalizable
public IntOpenHashTable(int size) {
- this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR);
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
- @SuppressWarnings("unchecked")
public IntOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ @SuppressWarnings("unchecked")
+ protected IntOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
if (size < 1) {
throw new IllegalArgumentException();
}
this._loadFactor = loadFactor;
this._growFactor = growFactor;
- int actualSize = Primes.findLeastPrimeNumber(size);
+ this._used = 0;
+ int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+ this._threshold = Math.round(actualSize * _loadFactor);
this._keys = new int[actualSize];
this._values = (V[]) new Object[actualSize];
this._states = new byte[actualSize];
- this._threshold = Math.round(actualSize * _loadFactor);
}
public IntOpenHashTable(@Nonnull int[] keys, @Nonnull V[] values, @Nonnull byte[] states,
int used) {
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
this._used = used;
this._threshold = keys.length;
this._keys = keys;
@@ -80,14 +94,17 @@ public final class IntOpenHashTable<V> implements Externalizable {
this._states = states;
}
+ @Nonnull
public int[] getKeys() {
return _keys;
}
+ @Nonnull
public Object[] getValues() {
return _values;
}
+ @Nonnull
public byte[] getStates() {
return _states;
}
@@ -109,7 +126,7 @@ public final class IntOpenHashTable<V> implements Externalizable {
int keyLength = _keys.length;
int keyIdx = hash % keyLength;
- boolean expanded = preAddEntry(keyIdx);
+ final boolean expanded = preAddEntry(keyIdx);
if (expanded) {
keyLength = _keys.length;
keyIdx = hash % keyLength;
@@ -119,14 +136,14 @@ public final class IntOpenHashTable<V> implements Externalizable {
final V[] values = _values;
final byte[] states = _states;
- if (states[keyIdx] == FULL) {
+ if (states[keyIdx] == FULL) {// double hashing
if (keys[keyIdx] == key) {
V old = values[keyIdx];
values[keyIdx] = value;
return old;
}
// try second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -149,10 +166,50 @@ public final class IntOpenHashTable<V> implements Externalizable {
return null;
}
+    /**
+     * Associates {@code value} with {@code key} only when {@code key} is not already present.
+     *
+     * @return the value currently mapped to {@code key}, or {@code null} if the key was absent and
+     *         {@code value} has been stored
+     */
+    public V putIfAbsent(final int key, final V value) {
+        final int hash = keyHash(key);
+        int keyLength = _keys.length;
+        int keyIdx = hash % keyLength;
+
+        // same growth policy as put(): the table may expand even if the key turns out to exist
+        final boolean expanded = preAddEntry(keyIdx);
+        if (expanded) {
+            keyLength = _keys.length;
+            keyIdx = hash % keyLength;
+        }
+
+        final int[] keys = _keys;
+        final V[] values = _values;
+        final byte[] states = _states;
+
+        // Walk the double-hashing probe sequence up to the first FREE slot. A REMOVED slot must
+        // not end the search: the key may have been stored further along the chain before that
+        // slot was vacated, and inserting at the REMOVED slot would create a duplicate entry that
+        // shadows the existing one. The first REMOVED/FREE slot seen is reused for insertion.
+        int insertIdx = -1;
+        int decr = 0; // secondary hash, computed only when a second probe is needed
+        for (int probes = 0; probes < keyLength; probes++) {
+            final byte state = states[keyIdx];
+            if (state == FULL) {
+                if (keys[keyIdx] == key) {
+                    return values[keyIdx];
+                }
+            } else if (state == FREE) {
+                if (insertIdx == -1) {
+                    insertIdx = keyIdx;
+                }
+                break;
+            } else if (insertIdx == -1) {// REMOVED: remember the first reusable slot
+                insertIdx = keyIdx;
+            }
+            if (decr == 0) {
+                // guard against a division by zero on degenerate (length <= 2) tables
+                decr = (keyLength > 2) ? 1 + (hash % (keyLength - 2)) : 1;
+            }
+            keyIdx -= decr;
+            if (keyIdx < 0) {
+                keyIdx += keyLength;
+            }
+        }
+        if (insertIdx == -1) {
+            // every probed slot holds another key; cannot happen while _used < _threshold
+            throw new IllegalStateException("No free slot is found for key: " + key);
+        }
+        keys[insertIdx] = key;
+        values[insertIdx] = value;
+        states[insertIdx] = FULL;
+        _used++;
+        return null;
+    }
/** Return weather the required slot is free for new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
+ protected boolean isFree(final int index, final int key) {
+ final byte stat = _states[index];
if (stat == FREE) {
return true;
}
@@ -163,8 +220,8 @@ public final class IntOpenHashTable<V> implements Externalizable {
}
/** @return expanded or not */
- protected boolean preAddEntry(int index) {
- if ((_used + 1) >= _threshold) {// filled enough
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// too filled
int newCapacity = Math.round(_keys.length * _growFactor);
ensureCapacity(newCapacity);
return true;
@@ -172,7 +229,7 @@ public final class IntOpenHashTable<V> implements Externalizable {
return false;
}
- protected int findKey(final int key) {
+ private int findKey(final int key) {
final int[] keys = _keys;
final byte[] states = _states;
final int keyLength = keys.length;
@@ -184,7 +241,7 @@ public final class IntOpenHashTable<V> implements Externalizable {
return keyIdx;
}
// try second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -217,7 +274,7 @@ public final class IntOpenHashTable<V> implements Externalizable {
return old;
}
// second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -255,28 +312,49 @@ public final class IntOpenHashTable<V> implements Externalizable {
this._used = 0;
}
- protected void ensureCapacity(int newCapacity) {
+    /**
+     * Renders the entries as {@code {k1=v1,k2=v2,...}} in table iteration order.
+     */
+    @Override
+    public String toString() {
+        // rough pre-size: ~10 chars per entry plus the surrounding braces
+        final StringBuilder sb = new StringBuilder(size() * 10 + 2).append('{');
+        for (final IMapIterator<V> itor = entries(); itor.next() != -1;) {
+            sb.append(itor.getKey()).append('=').append(itor.getValue());
+            if (itor.hasNext()) {
+                sb.append(',');
+            }
+        }
+        return sb.append('}').toString();
+    }
+
+ private void ensureCapacity(final int newCapacity) {
int prime = Primes.findLeastPrimeNumber(newCapacity);
rehash(prime);
this._threshold = Math.round(prime * _loadFactor);
}
@SuppressWarnings("unchecked")
- private void rehash(int newCapacity) {
+ private void rehash(final int newCapacity) {
int oldCapacity = _keys.length;
if (newCapacity <= oldCapacity) {
throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
}
+ final int[] oldKeys = _keys;
+ final V[] oldValues = _values;
+ final byte[] oldStates = _states;
final int[] newkeys = new int[newCapacity];
final V[] newValues = (V[]) new Object[newCapacity];
final byte[] newStates = new byte[newCapacity];
int used = 0;
for (int i = 0; i < oldCapacity; i++) {
- if (_states[i] == FULL) {
+ if (oldStates[i] == FULL) {
used++;
- int k = _keys[i];
- V v = _values[i];
- int hash = keyHash(k);
+ final int k = oldKeys[i];
+ final V v = oldValues[i];
+ final int hash = keyHash(k);
int keyIdx = hash % newCapacity;
if (newStates[keyIdx] == FULL) {// second hashing
int decr = 1 + (hash % (newCapacity - 2));
@@ -287,9 +365,9 @@ public final class IntOpenHashTable<V> implements Externalizable {
}
}
}
- newStates[keyIdx] = FULL;
newkeys[keyIdx] = k;
newValues[keyIdx] = v;
+ newStates[keyIdx] = FULL;
}
}
this._keys = newkeys;
@@ -303,7 +381,7 @@ public final class IntOpenHashTable<V> implements Externalizable {
}
@Override
- public void writeExternal(ObjectOutput out) throws IOException {
+ public void writeExternal(@Nonnull final ObjectOutput out) throws IOException {
out.writeFloat(_loadFactor);
out.writeFloat(_growFactor);
out.writeInt(_used);
@@ -319,8 +397,8 @@ public final class IntOpenHashTable<V> implements Externalizable {
}
@SuppressWarnings("unchecked")
- @Override
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ public void readExternal(@Nonnull final ObjectInput in) throws IOException,
+ ClassNotFoundException {
this._loadFactor = in.readFloat();
this._growFactor = in.readFloat();
this._used = in.readInt();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
index c758824..b4356ff 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
@@ -27,7 +27,12 @@ import java.io.ObjectOutput;
import java.util.Arrays;
/**
- * An open-addressing hash table with double hashing
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod(m-2))
+ * </pre>
*
* @see http://en.wikipedia.org/wiki/Double_hashing
*/
@@ -37,7 +42,7 @@ public final class Long2DoubleOpenHashTable implements Externalizable {
protected static final byte FULL = 1;
protected static final byte REMOVED = 2;
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_LOAD_FACTOR = 0.75f;
private static final float DEFAULT_GROW_FACTOR = 2.0f;
protected final transient float _loadFactor;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
index 6a7f39f..6b0ab59 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
@@ -27,9 +27,14 @@ import java.io.ObjectOutput;
import java.util.Arrays;
/**
- * An open-addressing hash table with float hashing
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod(m-2))
+ * </pre>
*
- * @see http://en.wikipedia.org/wiki/float_hashing
+ * @see http://en.wikipedia.org/wiki/Double_hashing
*/
public final class Long2FloatOpenHashTable implements Externalizable {
@@ -37,7 +42,7 @@ public final class Long2FloatOpenHashTable implements Externalizable {
protected static final byte FULL = 1;
protected static final byte REMOVED = 2;
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_LOAD_FACTOR = 0.75f;
private static final float DEFAULT_GROW_FACTOR = 2.0f;
protected final transient float _loadFactor;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
index 51b8f12..1ca4c40 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
@@ -27,7 +27,12 @@ import java.io.ObjectOutput;
import java.util.Arrays;
/**
- * An open-addressing hash table with double hashing
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod(m-2))
+ * </pre>
*
* @see http://en.wikipedia.org/wiki/Double_hashing
*/
@@ -37,7 +42,7 @@ public final class Long2IntOpenHashTable implements Externalizable {
protected static final byte FULL = 1;
protected static final byte REMOVED = 2;
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_LOAD_FACTOR = 0.75f;
private static final float DEFAULT_GROW_FACTOR = 2.0f;
protected final transient float _loadFactor;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
index 152447a..f5ee1e6 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java
@@ -48,16 +48,29 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Set;
+import javax.annotation.CheckForNull;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
/**
- * An optimized Hashed Map implementation.
- * <p/>
- * <p>
- * This Hashmap does not allow nulls to be used as keys or values.
- * <p/>
- * <p>
+ * A space efficient open-addressing HashMap implementation.
+ *
+ * Unlike {@link OpenHashTable}, it maintains single arrays for keys and object references.
+ *
* It uses single open hashing arrays sized to binary powers (256, 512 etc) rather than those
- * divisable by prime numbers. This allows the hash offset calculation to be a simple binary masking
+ * divisible by prime numbers. This allows the hash offset calculation to be a simple binary masking
* operation.
+ *
+ * The index into the arrays is determined by masking a portion of the key and shifting it to
+ * provide a series of small buckets within the array. To insert an entry, a sweep is searched
+ * until an empty key space is found. A sweep is 4 times the length of a bucket, to reduce the need
+ * to rehash. If no key space is found within a sweep, the table size is doubled.
+ *
+ * While performance is high, the slowest case is a lookup for an entry that does not exist, since
+ * the entire sweep area must be searched. However, this HashMap is more space efficient than
+ * other open-addressing HashMap implementations, such as those in fastutil.
+ *
+ * Note that this HashMap does not allow nulls to be used as keys.
*/
public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
private K[] keys;
@@ -80,21 +93,21 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
resize(MathUtils.bitsRequired(size < 256 ? 256 : size));
}
- public V put(K key, V value) {
+ @Nullable
+ public V put(@CheckForNull final K key, @Nullable final V value) {
if (key == null) {
throw new NullPointerException(this.getClass().getName() + " key");
}
for (;;) {
int off = getBucketOffset(key);
- int end = off + sweep;
+ final int end = off + sweep;
for (; off < end; off++) {
- K searchKey = keys[off];
+ final K searchKey = keys[off];
if (searchKey == null) {
// insert
keys[off] = key;
size++;
-
V previous = values[off];
values[off] = value;
return previous;
@@ -109,9 +122,36 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
}
}
- public V get(Object key) {
+    /**
+     * Associates {@code value} with {@code key} only when {@code key} is not already present.
+     *
+     * @return the value currently mapped to {@code key}, or {@code null} if the key was absent and
+     *         {@code value} has been stored
+     * @throws NullPointerException if {@code key} is null
+     */
+    @Nullable
+    public V putIfAbsent(@CheckForNull final K key, @Nullable final V value) {
+        if (key == null) {
+            throw new NullPointerException(this.getClass().getName() + " key");
+        }
+
+        for (;;) {
+            final int start = getBucketOffset(key);
+            final int end = start + sweep;
+            // Scan the whole sweep before inserting: remove() leaves null holes, so the key may
+            // live past the first empty slot. Inserting at that hole would create a duplicate.
+            int emptySlot = -1;
+            for (int off = start; off < end; off++) {
+                final K searchKey = keys[off];
+                if (searchKey == null) {
+                    if (emptySlot == -1) {
+                        emptySlot = off;
+                    }
+                } else if (compare(searchKey, key)) {
+                    return values[off];
+                }
+            }
+            if (emptySlot != -1) {
+                keys[emptySlot] = key;
+                values[emptySlot] = value;
+                size++;
+                return null; // key was absent
+            }
+            // no room within the sweep: grow the table and retry
+            resize(this.bits + 1);
+        }
+    }
+
+ @Nullable
+ public V get(@Nonnull final Object key) {
int off = getBucketOffset(key);
- int end = sweep + off;
+ final int end = sweep + off;
for (; off < end; off++) {
if (keys[off] != null && compare(keys[off], key)) {
return values[off];
@@ -120,9 +160,10 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
return null;
}
- public V remove(Object key) {
+ @Nullable
+ public V remove(@Nonnull final Object key) {
int off = getBucketOffset(key);
- int end = sweep + off;
+ final int end = sweep + off;
for (; off < end; off++) {
if (keys[off] != null && compare(keys[off], key)) {
keys[off] = null;
@@ -139,7 +180,7 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
return size;
}
- public void putAll(Map<? extends K, ? extends V> m) {
+ public void putAll(@Nonnull final Map<? extends K, ? extends V> m) {
for (K key : m.keySet()) {
put(key, m.get(key));
}
@@ -149,11 +190,11 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
return size == 0;
}
- public boolean containsKey(Object key) {
+ public boolean containsKey(@Nonnull final Object key) {
return get(key) != null;
}
- public boolean containsValue(Object value) {
+ public boolean containsValue(@Nonnull final Object value) {
for (V v : values) {
if (v != null && compare(v, value)) {
return true;
@@ -165,11 +206,12 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
public void clear() {
Arrays.fill(keys, null);
Arrays.fill(values, null);
- size = 0;
+ this.size = 0;
}
+ @Nonnull
public Set<K> keySet() {
- Set<K> set = new HashSet<K>();
+ final Set<K> set = new HashSet<K>();
for (K key : keys) {
if (key != null) {
set.add(key);
@@ -178,8 +220,9 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
return set;
}
+ @Nonnull
public Collection<V> values() {
- Collection<V> list = new ArrayList<V>();
+ final Collection<V> list = new ArrayList<V>();
for (V value : values) {
if (value != null) {
list.add(value);
@@ -188,8 +231,9 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
return list;
}
+ @Nonnull
public Set<Entry<K, V>> entrySet() {
- Set<Entry<K, V>> set = new HashSet<Entry<K, V>>();
+ final Set<Entry<K, V>> set = new HashSet<Entry<K, V>>();
for (K key : keys) {
if (key != null) {
set.add(new MapEntry<K, V>(this, key));
@@ -207,19 +251,23 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
this.key = key;
}
+ @Override
public K getKey() {
return key;
}
+ @Override
public V getValue() {
return map.get(key);
}
+ @Override
public V setValue(V value) {
return map.put(key, value);
}
}
+ @Override
public void writeExternal(ObjectOutput out) throws IOException {
// remember the number of bits
out.writeInt(this.bits);
@@ -235,6 +283,7 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
}
@SuppressWarnings("unchecked")
+ @Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
// resize to old bit size
int bitSize = in.readInt();
@@ -250,19 +299,19 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
@Override
public String toString() {
- return this.getClass().getSimpleName() + ' ' + this.size;
+ return this.getClass().getSimpleName() + ' ' + size;
}
@SuppressWarnings("unchecked")
- private void resize(int bits) {
+ private void resize(final int bits) {
this.bits = bits;
this.sweepbits = bits / 4;
this.sweep = MathUtils.powerOf(2, sweepbits) * 4;
- this.sweepmask = MathUtils.bitMask(bits - this.sweepbits) << sweepbits;
+ this.sweepmask = MathUtils.bitMask(bits - sweepbits) << sweepbits;
// remember old values so we can recreate the entries
- K[] existingKeys = this.keys;
- V[] existingValues = this.values;
+ final K[] existingKeys = this.keys;
+ final V[] existingValues = this.values;
// create the arrays
this.values = (V[]) new Object[MathUtils.powerOf(2, bits) + sweep];
@@ -272,31 +321,38 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
// re-add the previous entries if resizing
if (existingKeys != null) {
for (int x = 0; x < existingKeys.length; x++) {
- if (existingKeys[x] != null) {
- put(existingKeys[x], existingValues[x]);
+ final K k = existingKeys[x];
+ if (k != null) {
+ put(k, existingValues[x]);
}
}
}
}
- private int getBucketOffset(Object key) {
- return (key.hashCode() << this.sweepbits) & this.sweepmask;
+ private int getBucketOffset(@Nonnull final Object key) {
+ return (key.hashCode() << sweepbits) & sweepmask;
}
- private static boolean compare(final Object v1, final Object v2) {
+ private static boolean compare(@Nonnull final Object v1, @Nonnull final Object v2) {
return v1 == v2 || v1.equals(v2);
}
public IMapIterator<K, V> entries() {
- return new MapIterator();
+ return new MapIterator(false);
+ }
+
+ public IMapIterator<K, V> entries(boolean releaseSeen) {
+ return new MapIterator(releaseSeen);
}
private final class MapIterator implements IMapIterator<K, V> {
+ final boolean releaseSeen;
int nextEntry;
int lastEntry = -1;
- MapIterator() {
+ MapIterator(boolean releaseSeen) {
+ this.releaseSeen = releaseSeen;
this.nextEntry = nextEntry(0);
}
@@ -315,7 +371,9 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
@Override
public int next() {
- free(lastEntry);
+ if (releaseSeen) {
+ free(lastEntry);
+ }
if (!hasNext()) {
return -1;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
index 7fec9b0..4599bfc 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
@@ -27,16 +27,22 @@ import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
-import java.util.HashMap;
import javax.annotation.Nonnull;
/**
- * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}.
+ * An open-addressing hash table using double-hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod(m-2))
+ * </pre>
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
*/
public final class OpenHashTable<K, V> implements Externalizable {
- public static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ public static final float DEFAULT_LOAD_FACTOR = 0.75f;
public static final float DEFAULT_GROW_FACTOR = 2.0f;
protected static final byte FREE = 0;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 0b68de8..db56b82 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -289,12 +289,21 @@ public final class HiveUtils {
}
}
- @Nonnull
public static boolean isListOI(@Nonnull final ObjectInspector oi) {
Category category = oi.getCategory();
return category == Category.LIST;
}
+    /**
+     * Tells whether {@code oi} inspects a list whose elements are strings.
+     *
+     * @throws UDFArgumentException if {@code oi} is not a list inspector
+     */
+    public static boolean isStringListOI(@Nonnull final ObjectInspector oi)
+            throws UDFArgumentException {
+        if (oi.getCategory() == Category.LIST) {
+            final ObjectInspector elemOI =
+                    ((ListObjectInspector) oi).getListElementObjectInspector();
+            return isStringOI(elemOI);
+        }
+        throw new UDFArgumentException("Expected List OI but was: " + oi);
+    }
+
public static boolean isMapOI(@Nonnull final ObjectInspector oi) {
return oi.getCategory() == Category.MAP;
}
@@ -670,6 +679,36 @@ public final class HiveUtils {
}
+    /**
+     * Converts a Hive list object into a {@code float[]}; null elements are left as 0.0f.
+     *
+     * @return null when {@code argObj} is null
+     */
@Nullable
+    public static float[] asFloatArray(@Nullable final Object argObj,
+            @Nonnull final ListObjectInspector listOI,
+            @Nonnull final PrimitiveObjectInspector elemOI) throws UDFArgumentException {
+        final boolean avoidNull = true; // skip null elements instead of rejecting them
+        return asFloatArray(argObj, listOI, elemOI, avoidNull);
+    }
+
+    /**
+     * Converts a Hive list object into a {@code float[]} of the same length.
+     *
+     * @param avoidNull when true, null elements are left as 0.0f; when false, a null element
+     *        raises {@link UDFArgumentException}
+     * @return null when {@code argObj} is null
+     */
+    @Nullable
+    public static float[] asFloatArray(@Nullable final Object argObj,
+            @Nonnull final ListObjectInspector listOI,
+            @Nonnull final PrimitiveObjectInspector elemOI, final boolean avoidNull)
+            throws UDFArgumentException {
+        if (argObj == null) {
+            return null;
+        }
+        final int size = listOI.getListLength(argObj);
+        final float[] result = new float[size];
+        for (int idx = 0; idx < size; idx++) {
+            final Object elem = listOI.getListElement(argObj, idx);
+            if (elem != null) {
+                result[idx] = PrimitiveObjectInspectorUtils.getFloat(elem, elemOI);
+            } else if (!avoidNull) {
+                throw new UDFArgumentException("Found null at index " + idx);
+            }
+        }
+        return result;
+    }
+
+ @Nullable
public static double[] asDoubleArray(@Nullable final Object argObj,
@Nonnull final ListObjectInspector listOI,
@Nonnull final PrimitiveObjectInspector elemOI) throws UDFArgumentException {
@@ -694,8 +733,7 @@ public final class HiveUtils {
}
throw new UDFArgumentException("Found null at index " + i);
}
- double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
- ary[i] = d;
+ ary[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
}
return ary;
}
@@ -721,8 +759,7 @@ public final class HiveUtils {
}
throw new UDFArgumentException("Found null at index " + i);
}
- double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
- out[i] = d;
+ out[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
}
return;
}
@@ -746,8 +783,7 @@ public final class HiveUtils {
out[i] = nullValue;
continue;
}
- double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
- out[i] = d;
+ out[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI);
}
return;
}
@@ -766,11 +802,11 @@ public final class HiveUtils {
int count = 0;
final int length = listOI.getListLength(argObj);
for (int i = 0; i < length; i++) {
- Object o = listOI.getListElement(argObj, i);
+ final Object o = listOI.getListElement(argObj, i);
if (o == null) {
continue;
}
- int index = PrimitiveObjectInspectorUtils.getInt(o, elemOI);
+ final int index = PrimitiveObjectInspectorUtils.getInt(o, elemOI);
if (index < 0) {
throw new UDFArgumentException("Negative index is not allowed: " + index);
}
@@ -955,6 +991,26 @@ public final class HiveUtils {
}
+    /**
+     * Returns {@code argOI} as a {@link PrimitiveObjectInspector} after checking that it inspects
+     * a floating point value ({@code float} or {@code double}).
+     *
+     * @throws UDFArgumentTypeException if {@code argOI} is not a primitive float/double inspector
+     */
@Nonnull
+    public static PrimitiveObjectInspector asFloatingPointOI(@Nonnull final ObjectInspector argOI)
+            throws UDFArgumentTypeException {
+        if (argOI.getCategory() != Category.PRIMITIVE) {
+            throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but "
+                    + argOI.getTypeName() + " is passed.");
+        }
+        final PrimitiveObjectInspector oi = (PrimitiveObjectInspector) argOI;
+        switch (oi.getPrimitiveCategory()) {
+            case FLOAT:
+            case DOUBLE:
+                break;
+            default:
+                // the message must name what is actually accepted: float or double only
+                throw new UDFArgumentTypeException(0,
+                    "Only float or double type arguments are accepted but " + argOI.getTypeName()
+                            + " is passed.");
+        }
+        return oi;
+    }
+
+ @Nonnull
public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector argOI)
throws UDFArgumentTypeException {
if (argOI.getCategory() != Category.PRIMITIVE) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/hashing/HashUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hashing/HashUtils.java b/core/src/main/java/hivemall/utils/hashing/HashUtils.java
new file mode 100644
index 0000000..710d8f6
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/hashing/HashUtils.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.hashing;
+
+/**
+ * 32-bit integer hash (bit mixing) functions.
+ *
+ * Most functions are ports of reference implementations that operate on unsigned 32-bit integers,
+ * so right shifts are written as logical shifts ({@code >>>}). Java's arithmetic shift
+ * ({@code >>}) sign-extends negative inputs and would diverge from the reference algorithms.
+ */
+public final class HashUtils {
+
+    private HashUtils() {}
+
+    /**
+     * Bob Jenkins' 6-shift 32-bit integer hash.
+     *
+     * http://burtleburtle.net/bob/hash/integer.html
+     */
+    public static int jenkins32(int k) {
+        k = (k + 0x7ed55d16) + (k << 12);
+        k = (k ^ 0xc761c23c) ^ (k >>> 19);
+        k = (k + 0x165667b1) + (k << 5);
+        k = (k + 0xd3a2646c) ^ (k << 9);
+        k = (k + 0xfd7046c5) + (k << 3);
+        k = (k ^ 0xb55a4f09) ^ (k >>> 16);
+        return k;
+    }
+
+    /**
+     * The 32-bit finalization mix ({@code fmix32}) of MurmurHash3.
+     */
+    public static int murmurHash3(int k) {
+        k ^= k >>> 16;
+        k *= 0x85ebca6b;
+        k ^= k >>> 13;
+        k *= 0xc2b2ae35;
+        k ^= k >>> 16;
+        return k;
+    }
+
+    /**
+     * 32-bit FNV-1a hash over the four octets of {@code k}, least significant octet first.
+     */
+    public static int fnv1a(final int k) {
+        int hash = 0x811c9dc5; // FNV 32-bit offset basis
+        for (int i = 0; i < 4; i++) {
+            hash ^= (k >>> (i * 8)) & 0xff; // FNV-1a folds in exactly one octet per round
+            hash *= 0x01000193; // FNV 32-bit prime
+        }
+        return hash;
+    }
+
+    /**
+     * Thomas Wang's 32-bit shift hash.
+     *
+     * https://gist.github.com/badboy/6267743
+     */
+    public static int hash32shift(int k) {
+        k = ~k + (k << 15); // key = (key << 15) - key - 1;
+        k = k ^ (k >>> 12);
+        k = k + (k << 2);
+        k = k ^ (k >>> 4);
+        k = k * 2057; // key = (key + (key << 3)) + (key << 11);
+        k = k ^ (k >>> 16);
+        return k;
+    }
+
+    /**
+     * Thomas Wang's 32-bit shift-multiply hash.
+     *
+     * https://gist.github.com/badboy/6267743
+     */
+    public static int hash32shiftmult(int k) {
+        k = (k ^ 61) ^ (k >>> 16);
+        k = k + (k << 3);
+        k = k ^ (k >>> 4);
+        k = k * 0x27d4eb2d;
+        k = k ^ (k >>> 15);
+        return k;
+    }
+
+    /**
+     * Bob Jenkins' 7-shift 32-bit integer hash.
+     *
+     * http://burtleburtle.net/bob/hash/integer.html
+     */
+    public static int hash7shifts(int k) {
+        k -= (k << 6);
+        k ^= (k >>> 17);
+        k -= (k << 9);
+        k ^= (k << 4);
+        k -= (k << 3);
+        k ^= (k << 10);
+        k ^= (k >>> 15);
+        return k;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/lang/NumberUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/NumberUtils.java b/core/src/main/java/hivemall/utils/lang/NumberUtils.java
index 0d3f895..4b04f04 100644
--- a/core/src/main/java/hivemall/utils/lang/NumberUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/NumberUtils.java
@@ -107,4 +107,72 @@ public final class NumberUtils {
return true;
}
+    /**
+     * Narrows a {@code long} to an {@code int}, failing instead of silently truncating.
+     *
+     * @throws ArithmeticException if {@code value} lies outside the {@code int} range
+     */
+    public static int castToInt(final long value) {
+        if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
+            throw new ArithmeticException("Out of range: " + value);
+        }
+        return (int) value;
+    }
+
+    /**
+     * Narrows an {@code int} to a {@code short}, failing instead of silently truncating.
+     *
+     * @throws ArithmeticException if {@code value} lies outside the {@code short} range
+     */
+    public static short castToShort(final int value) {
+        if (value < Short.MIN_VALUE || value > Short.MAX_VALUE) {
+            throw new ArithmeticException("Out of range: " + value);
+        }
+        return (short) value;
+    }
+
+    /**
+     * Casts a {@code double} to {@code float}, rejecting values the float range cannot hold.
+     *
+     * Note that {@link Float#MIN_VALUE} is the smallest positive float, not the most negative
+     * one; the representable range is [{@code -Float.MAX_VALUE}, {@code Float.MAX_VALUE}].
+     * NaN passes through unchanged.
+     *
+     * @throws ArithmeticException if {@code v} is outside the float range (including infinities)
+     */
+    public static float castToFloat(final double v) {
+        if ((v < -Float.MAX_VALUE) || (v > Float.MAX_VALUE)) {
+            throw new ArithmeticException("Double value is out of Float range: " + v);
+        }
+        return (float) v;
+    }
+
+    /**
+     * Casts a {@code double} to {@code float}, clamping out-of-range values.
+     *
+     * @return v if v is in Float range; -Float.MAX_VALUE or Float.MAX_VALUE otherwise
+     */
+    public static float safeCast(final double v) {
+        if (v < -Float.MAX_VALUE) {
+            return -Float.MAX_VALUE;
+        } else if (v > Float.MAX_VALUE) {
+            return Float.MAX_VALUE;
+        }
+        return (float) v;
+    }
+
+    /**
+     * Casts a {@code double} to {@code float}, substituting a default for out-of-range values.
+     *
+     * @return v if v is in Float range; defaultValue otherwise
+     */
+    public static float safeCast(final double v, final float defaultValue) {
+        if ((v < -Float.MAX_VALUE) || (v > Float.MAX_VALUE)) {
+            return defaultValue;
+        }
+        return (float) v;
+    }
+
+    /** Widens {@code v} to an int in [0, 65535], reading its 16 bits as unsigned. */
+    public static int toUnsignedShort(final short v) {
+        final int widened = v; // sign-extended
+        return widened & 0xFFFF;
+    }
+
+    /** Widens {@code x} to an int in [0, 255], reading its 8 bits as unsigned. */
+    public static int toUnsignedInt(final byte x) {
+        return 0xff & x;
+    }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/lang/Primitives.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java
index 2ec012c..7d43da1 100644
--- a/core/src/main/java/hivemall/utils/lang/Primitives.java
+++ b/core/src/main/java/hivemall/utils/lang/Primitives.java
@@ -26,14 +26,6 @@ public final class Primitives {
private Primitives() {}
- public static int toUnsignedShort(final short v) {
- return v & 0xFFFF; // convert to range 0-65535 from -32768-32767.
- }
-
- public static int toUnsignedInt(final byte x) {
- return ((int) x) & 0xff;
- }
-
public static short parseShort(final String s, final short defaultValue) {
if (s == null) {
return defaultValue;
@@ -92,22 +84,6 @@ public final class Primitives {
b[off] = (byte) (val >>> 8);
}
- public static int toIntExact(final long longValue) {
- final int casted = (int) longValue;
- if (casted != longValue) {
- throw new ArithmeticException("integer overflow: " + longValue);
- }
- return casted;
- }
-
- public static int castToInt(final long value) {
- final int result = (int) value;
- if (result != value) {
- throw new IllegalArgumentException("Out of range: " + value);
- }
- return result;
- }
-
public static long toLong(final int high, final int low) {
return ((long) high << 32) | ((long) low & 0xffffffffL);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
# NOTE(review): this section makes three changes to MathUtils:
#   - It removes @Nonnull from primitive parameters of equals, almostEquals and closeToZero.
#     The annotation carries no meaning on a primitive.
#   - It adds closeToZero(float, tol) and closeToZero(double, tol) overloads.
#   - The one-argument closeToZero forms now delegate with a tolerance of 1E-15.
# The new overloads return true for an exact zero before consulting tol. An exact 0 is
#   therefore reported as "close" even when tol is negative or NaN. For any other value
#   the result is Math.abs(value) <= tol.
# The float equals(...) overload checks Double.isNaN(value) on a float. This works through
#   widening, but Float.isNaN would state the intent directly. That line is unchanged
#   context in this patch.
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 3f41b6f..6162adb 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -264,7 +264,7 @@ public final class MathUtils {
return r;
}
- public static boolean equals(@Nonnull final float value, final float expected, final float delta) {
+ public static boolean equals(final float value, final float expected, final float delta) {
if (Double.isNaN(value)) {
return false;
}
@@ -274,8 +274,7 @@ public final class MathUtils {
return true;
}
- public static boolean equals(@Nonnull final double value, final double expected,
- final double delta) {
+ public static boolean equals(final double value, final double expected, final double delta) {
if (Double.isNaN(value)) {
return false;
}
@@ -285,26 +284,34 @@ public final class MathUtils {
return true;
}
- public static boolean almostEquals(@Nonnull final float value, final float expected) {
+ public static boolean almostEquals(final float value, final float expected) {
return equals(value, expected, 1E-15f);
}
- public static boolean almostEquals(@Nonnull final double value, final double expected) {
+ public static boolean almostEquals(final double value, final double expected) {
return equals(value, expected, 1E-15d);
}
- public static boolean closeToZero(@Nonnull final float value) {
- if (Math.abs(value) > 1E-15f) {
- return false;
+ public static boolean closeToZero(final float value) {
+ return closeToZero(value, 1E-15f);
+ }
+
+ public static boolean closeToZero(final float value, @Nonnegative final float tol) {
+ if (value == 0.f) {
+ return true;
}
- return true;
+ return Math.abs(value) <= tol;
}
- public static boolean closeToZero(@Nonnull final double value) {
- if (Math.abs(value) > 1E-15d) {
- return false;
+ public static boolean closeToZero(final double value) {
+ return closeToZero(value, 1E-15d);
+ }
+
+ public static boolean closeToZero(final double value, @Nonnegative final double tol) {
+ if (value == 0.d) {
+ return true;
}
- return true;
+ return Math.abs(value) <= tol;
}
public static double sign(final double x) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
----------------------------------------------------------------------
# NOTE(review): this section deletes FFMPredictionModelTest. That test built an
#   FFMPredictionModel over an Int2LongOpenHashTable and a HeapBuffer and round-tripped it
#   through serialize()/deserialize(). It then checked the factor, feature and field counts.
#   Presumably the model type was replaced by the UDAF-based ffm_predict in this commit.
#   Confirm that a serialization round-trip is still covered for whatever replaced it.
diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
deleted file mode 100644
index 076387f..0000000
--- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.collections.maps.Int2LongOpenHashTable;
-
-import java.io.IOException;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class FFMPredictionModelTest {
-
- @Test
- public void testSerialize() throws IOException, ClassNotFoundException {
- final int factors = 3;
- final int entrySize = Entry.sizeOf(factors);
-
- HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
- Int2LongOpenHashTable map = Int2LongOpenHashTable.newInstance();
-
- Entry e1 = new Entry(buf, factors, buf.allocate(entrySize));
- e1.setW(1f);
- e1.setV(new float[] {1f, -1f, -1f});
-
- Entry e2 = new Entry(buf, factors, buf.allocate(entrySize));
- e2.setW(2f);
- e2.setV(new float[] {1f, 2f, -1f});
-
- Entry e3 = new Entry(buf, factors, buf.allocate(entrySize));
- e3.setW(3f);
- e3.setV(new float[] {1f, 2f, 3f});
-
- map.put(1, e1.getOffset());
- map.put(2, e2.getOffset());
- map.put(3, e3.getOffset());
-
- FFMPredictionModel expected = new FFMPredictionModel(map, buf, 0.d, 3,
- Feature.DEFAULT_NUM_FEATURES, Feature.DEFAULT_NUM_FIELDS);
- byte[] b = expected.serialize();
-
- FFMPredictionModel actual = FFMPredictionModel.deserialize(b, b.length);
- Assert.assertEquals(3, actual.getNumFactors());
- Assert.assertEquals(Feature.DEFAULT_NUM_FEATURES, actual.getNumFeatures());
- Assert.assertEquals(Feature.DEFAULT_NUM_FIELDS, actual.getNumFields());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/test/java/hivemall/fm/FeatureTest.java
----------------------------------------------------------------------
# NOTE(review): testParseFFMFeature now calls the two-argument overload
#   parseFFMFeature(String, int) with -1.
# The new testParseFeatureZeroIndex expects the one-argument overload to throw
#   HiveException for "0:0.3652".
#   That input has two ':'-separated parts, whereas the other FFM input here has three
#   ("2:1163:0.3651").
#   Confirm that the exception comes from the zero-index check rather than from a generic
#   format error. Otherwise the test passes for the wrong reason.
diff --git a/core/src/test/java/hivemall/fm/FeatureTest.java b/core/src/test/java/hivemall/fm/FeatureTest.java
index 25e5671..911a4a5 100644
--- a/core/src/test/java/hivemall/fm/FeatureTest.java
+++ b/core/src/test/java/hivemall/fm/FeatureTest.java
@@ -34,7 +34,7 @@ public class FeatureTest {
@Test
public void testParseFFMFeature() throws HiveException {
- IntFeature f1 = Feature.parseFFMFeature("2:1163:0.3651");
+ IntFeature f1 = Feature.parseFFMFeature("2:1163:0.3651", -1);
Assert.assertEquals(2, f1.getField());
Assert.assertEquals(1163, f1.getFeatureIndex());
Assert.assertEquals("1163", f1.getFeature());
@@ -85,4 +85,9 @@ public class FeatureTest {
Feature.parseFeature("2:1163:0.3651", true);
}
+ @Test(expected = HiveException.class)
+ public void testParseFeatureZeroIndex() throws HiveException {
+ Feature.parseFFMFeature("0:0.3652");
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
----------------------------------------------------------------------
# NOTE(review): this section moves the FFM trainer tests to the "-opt sgd|adagrad|ftrl"
#   option style. It also replaces the hand-rolled tokenizer with input.split(" ").
# The removed loop never appended the token after the last space, so the final feature of
#   each line was dropped unless the line ended with a space. split(" ") fixes that.
# When rebuilding "field:index:value" strings, feature indices are shifted by +1. This is
#   presumably because the parser rejects index 0; confirm.
# split(" ") still assumes single-space separators. A leading or doubled space yields an
#   empty token, which then fails Double.parseDouble or the 3-part assertion.
# Fix: the testFTRLNoCoeff run label read "FTRL Coeff test". It is renamed to
#   "FTRL No Coeff test" to match the method name and the "AdaGrad No Coeff test" label.
diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
index 792ede1..3b219c6 100644
--- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
+++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
@@ -23,11 +23,11 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
+import java.util.List;
import java.util.zip.GZIPInputStream;
import javax.annotation.Nonnull;
-import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
@@ -44,32 +44,29 @@ public class FieldAwareFactorizationMachineUDTFTest {
@Test
public void testSGD() throws HiveException, IOException {
- runTest("Pure SGD test",
- "-classification -factors 10 -w0 -seed 43 -disable_adagrad -disable_ftrl", 0.60f);
+ runTest("Pure SGD test", "-opt sgd -classification -factors 10 -w0 -seed 43", 0.60f);
}
@Test
- public void testSGDWithFTRL() throws HiveException, IOException {
- runTest("SGD w/ FTRL test", "-classification -factors 10 -w0 -seed 43 -disable_adagrad",
- 0.60f);
+ public void testAdaGrad() throws HiveException, IOException {
+ runTest("AdaGrad test", "-opt adagrad -classification -factors 10 -w0 -seed 43", 0.30f);
}
@Test
public void testAdaGradNoCoeff() throws HiveException, IOException {
- runTest("AdaGrad No Coeff test", "-classification -factors 10 -w0 -seed 43 -no_coeff",
- 0.30f);
+ runTest("AdaGrad No Coeff test",
+ "-opt adagrad -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f);
}
@Test
- public void testAdaGradNoFTRL() throws HiveException, IOException {
- runTest("AdaGrad w/o FTRL test", "-classification -factors 10 -w0 -seed 43 -disable_ftrl",
- 0.30f);
+ public void testFTRL() throws HiveException, IOException {
+ runTest("FTRL test", "-opt ftrl -classification -factors 10 -w0 -seed 43", 0.30f);
}
@Test
- public void testAdaGradDefault() throws HiveException, IOException {
- runTest("AdaGrad DEFAULT (adagrad for V + FTRL for W)",
- "-classification -factors 10 -w0 -seed 43", 0.30f);
+ public void testFTRLNoCoeff() throws HiveException, IOException {
+ runTest("FTRL No Coeff test", "-opt ftrl -no_coeff -classification -factors 10 -w0 -seed 43",
+ 0.30f);
}
private static void runTest(String testName, String testOptions, float lossThreshold)
@@ -100,30 +97,22 @@ public class FieldAwareFactorizationMachineUDTFTest {
if (input == null) {
break;
}
- ArrayList<String> featureStrings = new ArrayList<String>();
- ArrayList<StringFeature> features = new ArrayList<StringFeature>();
-
- //make StringFeature for each word = data point
- String remaining = input;
- int wordCut = remaining.indexOf(' ');
- while (wordCut != -1) {
- featureStrings.add(remaining.substring(0, wordCut));
- remaining = remaining.substring(wordCut + 1);
- wordCut = remaining.indexOf(' ');
- }
- int end = featureStrings.size();
- double y = Double.parseDouble(featureStrings.get(0));
+ String[] featureStrings = input.split(" ");
+
+ double y = Double.parseDouble(featureStrings[0]);
if (y == 0) {
y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1}
}
- for (int wordNumber = 1; wordNumber < end; ++wordNumber) {
- String entireFeature = featureStrings.get(wordNumber);
- int featureCut = StringUtils.ordinalIndexOf(entireFeature, ":", 2);
- String feature = entireFeature.substring(0, featureCut);
- double value = Double.parseDouble(entireFeature.substring(featureCut + 1));
- features.add(new StringFeature(feature, value));
+
+ final List<String> features = new ArrayList<String>(featureStrings.length - 1);
+ for (int j = 1; j < featureStrings.length; ++j) {
+ String[] splitted = featureStrings[j].split(":");
+ Assert.assertEquals(3, splitted.length);
+ int index = Integer.parseInt(splitted[1]) + 1;
+ String f = splitted[0] + ':' + index + ':' + splitted[2];
+ features.add(f);
}
- udtf.process(new Object[] {toStringArray(features), y});
+ udtf.process(new Object[] {features, y});
}
cumul = udtf._cvState.getCumulativeLoss();
loss = (cumul - loss) / lines;
@@ -143,15 +132,6 @@ public class FieldAwareFactorizationMachineUDTFTest {
return new BufferedReader(new InputStreamReader(is));
}
- private static String[] toStringArray(ArrayList<StringFeature> x) {
- final int size = x.size();
- final String[] ret = new String[size];
- for (int i = 0; i < size; i++) {
- ret[i] = x.get(i).toString();
- }
- return ret;
- }
-
private static void println(String line) {
if (DEBUG) {
System.out.println(line);