You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/11 05:39:43 UTC
[4/4] incubator-hivemall git commit: Close #105,
Close #58: [HIVEMALL-24-2] Make ffm_predict function more scalable by
creating its UDAF implementation
Close #105, Close #58: [HIVEMALL-24-2] Make ffm_predict function more scalable by creating its UDAF implementation
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/38047891
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/38047891
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/38047891
Branch: refs/heads/master
Commit: 3804789168dab5c5d43aac1fd4000e07688c6a06
Parents: 7205de1
Author: Makoto Yui <my...@apache.org>
Authored: Mon Sep 11 14:36:12 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Mon Sep 11 14:38:12 2017 +0900
----------------------------------------------------------------------
.travis.yml | 2 +-
.../java/hivemall/common/ConversionState.java | 21 +-
core/src/main/java/hivemall/fm/Entry.java | 242 +++++++---
.../java/hivemall/fm/FFMPredictGenericUDAF.java | 262 +++++++++++
.../main/java/hivemall/fm/FFMPredictUDF.java | 187 --------
.../java/hivemall/fm/FFMPredictionModel.java | 349 --------------
.../hivemall/fm/FFMStringFeatureMapModel.java | 315 ++++++++-----
.../java/hivemall/fm/FMHyperParameters.java | 74 +--
.../java/hivemall/fm/FMIntFeatureMapModel.java | 6 +-
.../java/hivemall/fm/FMPredictGenericUDAF.java | 15 +
.../hivemall/fm/FMStringFeatureMapModel.java | 8 +-
.../hivemall/fm/FactorizationMachineUDTF.java | 6 +-
core/src/main/java/hivemall/fm/Feature.java | 76 ++-
.../fm/FieldAwareFactorizationMachineModel.java | 161 ++++++-
.../fm/FieldAwareFactorizationMachineUDTF.java | 158 ++++---
core/src/main/java/hivemall/fm/IntFeature.java | 6 +-
.../ftvec/pairing/FeaturePairsUDTF.java | 155 ++++--
.../ftvec/ranking/PositiveOnlyFeedback.java | 8 +-
.../ftvec/trans/AddFieldIndicesUDF.java | 89 ++++
.../ftvec/trans/CategoricalFeaturesUDF.java | 121 +++--
.../hivemall/ftvec/trans/FFMFeaturesUDF.java | 47 +-
.../ftvec/trans/QuantifiedFeaturesUDTF.java | 7 +-
.../ftvec/trans/QuantitativeFeaturesUDF.java | 101 +++-
.../ftvec/trans/VectorizeFeaturesUDF.java | 110 +++--
.../main/java/hivemall/mf/FactorizedModel.java | 18 +-
.../hivemall/model/AbstractPredictionModel.java | 6 +-
.../java/hivemall/model/NewSparseModel.java | 2 +-
.../main/java/hivemall/model/SparseModel.java | 2 +-
.../tools/array/ArrayAvgGenericUDAF.java | 17 +-
.../java/hivemall/utils/buffer/HeapBuffer.java | 37 +-
.../maps/Int2FloatOpenHashTable.java | 11 +-
.../collections/maps/Int2IntOpenHashTable.java | 9 +-
.../collections/maps/Int2LongOpenHashMap.java | 346 ++++++++++++++
.../collections/maps/Int2LongOpenHashTable.java | 114 +++--
.../utils/collections/maps/IntOpenHashMap.java | 467 -------------------
.../collections/maps/IntOpenHashTable.java | 142 ++++--
.../maps/Long2DoubleOpenHashTable.java | 9 +-
.../maps/Long2FloatOpenHashTable.java | 11 +-
.../collections/maps/Long2IntOpenHashTable.java | 9 +-
.../utils/collections/maps/OpenHashMap.java | 128 +++--
.../utils/collections/maps/OpenHashTable.java | 12 +-
.../java/hivemall/utils/hadoop/HiveUtils.java | 74 ++-
.../java/hivemall/utils/hashing/HashUtils.java | 89 ++++
.../java/hivemall/utils/lang/NumberUtils.java | 68 +++
.../java/hivemall/utils/lang/Primitives.java | 24 -
.../java/hivemall/utils/math/MathUtils.java | 33 +-
.../hivemall/fm/FFMPredictionModelTest.java | 65 ---
core/src/test/java/hivemall/fm/FeatureTest.java | 7 +-
.../FieldAwareFactorizationMachineUDTFTest.java | 66 +--
.../smile/tools/TreePredictUDFv1Test.java | 1 +
.../maps/Int2FloatOpenHashMapTest.java | 98 ----
.../maps/Int2FloatOpenHashTableTest.java | 98 ++++
.../maps/Int2LongOpenHashMapTest.java | 66 +--
.../maps/Int2LongOpenHashTableTest.java | 130 ++++++
.../collections/maps/IntOpenHashMapTest.java | 75 ---
.../collections/maps/IntOpenHashTableTest.java | 23 +
.../maps/Long2IntOpenHashMapTest.java | 115 -----
.../maps/Long2IntOpenHashTableTest.java | 115 +++++
docs/gitbook/getting_started/input-format.md | 31 +-
pom.xml | 18 +
resources/ddl/define-all-as-permanent.hive | 5 +-
resources/ddl/define-all.hive | 5 +-
resources/ddl/define-all.spark | 5 +-
resources/ddl/define-udfs.td.hql | 3 +
spark/spark-2.0/pom.xml | 10 +-
spark/spark-2.1/pom.xml | 10 +-
spark/spark-common/pom.xml | 8 +-
67 files changed, 2962 insertions(+), 2146 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/.travis.yml
----------------------------------------------------------------------
diff --git a/.travis.yml b/.travis.yml
index 323e36a..96f8f4e 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -19,7 +19,7 @@ env:
language: java
jdk:
- openjdk7
- - oraclejdk7
+# - oraclejdk7
- oraclejdk8
branches:
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/common/ConversionState.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/common/ConversionState.java b/core/src/main/java/hivemall/common/ConversionState.java
index 7b5923f..435bf75 100644
--- a/core/src/main/java/hivemall/common/ConversionState.java
+++ b/core/src/main/java/hivemall/common/ConversionState.java
@@ -99,18 +99,25 @@ public final class ConversionState {
if (changeRate < convergenceRate) {
if (readyToFinishIterations) {
// NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY
- logger.info("Training converged at " + curIter + "-th iteration. [curLosses="
- + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate
- + ']');
+ if (logger.isInfoEnabled()) {
+ logger.info("Training converged at " + curIter + "-th iteration. [curLosses="
+ + currLosses + ", prevLosses=" + prevLosses + ", changeRate="
+ + changeRate + ']');
+ }
return true;
} else {
+ if (logger.isInfoEnabled()) {
+ logger.info("Iteration #" + curIter + " [curLosses=" + currLosses
+ + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate
+ + ", #trainingExamples=" + observedTrainingExamples + ']');
+ }
this.readyToFinishIterations = true;
}
} else {
- if (logger.isDebugEnabled()) {
- logger.debug("Iteration #" + curIter + " [curLosses=" + currLosses
- + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate
- + ", #trainingExamples=" + observedTrainingExamples + ']');
+ if (logger.isInfoEnabled()) {
+ logger.info("Iteration #" + curIter + " [curLosses=" + currLosses + ", prevLosses="
+ + prevLosses + ", changeRate=" + changeRate + ", #trainingExamples="
+ + observedTrainingExamples + ']');
}
this.readyToFinishIterations = false;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/Entry.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/Entry.java b/core/src/main/java/hivemall/fm/Entry.java
index 1882f85..974ab5b 100644
--- a/core/src/main/java/hivemall/fm/Entry.java
+++ b/core/src/main/java/hivemall/fm/Entry.java
@@ -20,17 +20,27 @@ package hivemall.fm;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.SizeOf;
+import hivemall.utils.math.MathUtils;
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
class Entry {
@Nonnull
protected final HeapBuffer _buf;
+ @Nonnegative
protected final int _size;
+ @Nonnegative
protected final int _factors;
+ // temporary variables used only in training phase
+ protected int _key;
+ @Nonnegative
protected long _offset;
Entry(@Nonnull HeapBuffer buf, int factors) {
@@ -39,128 +49,210 @@ class Entry {
this._factors = factors;
}
- Entry(@Nonnull HeapBuffer buf, int factors, long offset) {
- this(buf, factors, Entry.sizeOf(factors), offset);
+ Entry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) {
+ this(buf, 1, key, offset);
+ }
+
+ Entry(@Nonnull HeapBuffer buf, int factors, int key, @Nonnegative long offset) {
+ this(buf, factors, Entry.sizeOf(factors), key, offset);
}
- private Entry(@Nonnull HeapBuffer buf, int factors, int size, long offset) {
+ private Entry(@Nonnull HeapBuffer buf, int factors, int size, int key, @Nonnegative long offset) {
this._buf = buf;
this._size = size;
this._factors = factors;
- setOffset(offset);
+ this._key = key;
+ this._offset = offset;
}
- int getSize() {
+ final int getSize() {
return _size;
}
- long getOffset() {
+ final int getKey() {
+ return _key;
+ }
+
+ final long getOffset() {
return _offset;
}
- void setOffset(long offset) {
+ final void setOffset(final long offset) {
this._offset = offset;
}
- float getW() {
+ final float getW() {
return _buf.getFloat(_offset);
}
- void setW(final float value) {
+ final void setW(final float value) {
_buf.putFloat(_offset, value);
}
- void getV(@Nonnull final float[] Vf) {
- final long offset = _offset + SizeOf.FLOAT;
+ final void getV(@Nonnull final float[] Vf) {
+ final long offset = _offset;
final int len = Vf.length;
- for (int i = 0; i < len; i++) {
- Vf[i] = _buf.getFloat(offset + SizeOf.FLOAT * i);
+ for (int f = 0; f < len; f++) {
+ long index = offset + SizeOf.FLOAT * f;
+ Vf[f] = _buf.getFloat(index);
}
}
- void setV(@Nonnull final float[] Vf) {
- final long offset = _offset + SizeOf.FLOAT;
+ final void setV(@Nonnull final float[] Vf) {
+ final long offset = _offset;
final int len = Vf.length;
- for (int i = 0; i < len; i++) {
- _buf.putFloat(offset + SizeOf.FLOAT * i, Vf[i]);
+ for (int f = 0; f < len; f++) {
+ long index = offset + SizeOf.FLOAT * f;
+ _buf.putFloat(index, Vf[f]);
}
}
- float getV(final int f) {
- return _buf.getFloat(_offset + SizeOf.FLOAT + SizeOf.FLOAT * f);
+ final float getV(final int f) {
+ long index = _offset + SizeOf.FLOAT * f;
+ return _buf.getFloat(index);
}
- void setV(final int f, final float value) {
- long index = _offset + SizeOf.FLOAT + SizeOf.FLOAT * f;
+ final void setV(final int f, final float value) {
+ long index = _offset + SizeOf.FLOAT * f;
_buf.putFloat(index, value);
}
- double getSumOfSquaredGradientsV() {
+ double getSumOfSquaredGradients(@Nonnegative int f) {
throw new UnsupportedOperationException();
}
- void addGradientV(float grad) {
+ void addGradient(@Nonnegative int f, float grad) {
throw new UnsupportedOperationException();
}
- float updateZ(float gradW, float alpha) {
+ final float updateZ(final float gradW, final float alpha) {
+ float w = getW();
+ return updateZ(0, w, gradW, alpha);
+ }
+
+ float updateZ(@Nonnegative int f, float W, float gradW, float alpha) {
throw new UnsupportedOperationException();
}
- double updateN(float gradW) {
+ final double updateN(final float gradW) {
+ return updateN(0, gradW);
+ }
+
+ double updateN(@Nonnegative int f, float gradW) {
throw new UnsupportedOperationException();
}
- static int sizeOf(int factors) {
- return SizeOf.FLOAT + SizeOf.FLOAT * factors;
+ boolean removable() {
+ if (!isEntryW(_key)) {// entry for V
+ final long offset = _offset;
+ for (int f = 0; f < _factors; f++) {
+ final float Vf = _buf.getFloat(offset + SizeOf.FLOAT * f);
+ if (!MathUtils.closeToZero(Vf, 1E-9f)) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ void clear() {};
+
+ static int sizeOf(@Nonnegative final int factors) {
+ Preconditions.checkArgument(factors >= 1, "Factors must be greather than 0: " + factors);
+ return SizeOf.FLOAT * factors;
+ }
+
+ static boolean isEntryW(final int i) {
+ return i < 0;
+ }
+
+ @Override
+ public String toString() {
+ if (Entry.isEntryW(_key)) {
+ return "W=" + getW();
+ } else {
+ float[] Vf = new float[_factors];
+ getV(Vf);
+ return "V=" + Arrays.toString(Vf);
+ }
}
- static class AdaGradEntry extends Entry {
+ static final class AdaGradEntry extends Entry {
final long _gg_offset;
- AdaGradEntry(@Nonnull HeapBuffer buf, int factors, long offset) {
- super(buf, factors, AdaGradEntry.sizeOf(factors), offset);
- this._gg_offset = _offset + SizeOf.FLOAT + SizeOf.FLOAT * _factors;
+ AdaGradEntry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) {
+ this(buf, 1, key, offset);
}
- private AdaGradEntry(@Nonnull HeapBuffer buf, int factors, int size, long offset) {
- super(buf, factors, size, offset);
- this._gg_offset = _offset + SizeOf.FLOAT + SizeOf.FLOAT * _factors;
+ AdaGradEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key,
+ @Nonnegative long offset) {
+ super(buf, factors, AdaGradEntry.sizeOf(factors), key, offset);
+ this._gg_offset = _offset + Entry.sizeOf(factors);
}
@Override
- double getSumOfSquaredGradientsV() {
- return _buf.getDouble(_gg_offset);
+ double getSumOfSquaredGradients(@Nonnegative final int f) {
+ Preconditions.checkArgument(f >= 0);
+
+ long offset = _gg_offset + SizeOf.DOUBLE * f;
+ return _buf.getDouble(offset);
}
@Override
- void addGradientV(float grad) {
- double v = _buf.getDouble(_gg_offset);
+ void addGradient(@Nonnegative final int f, final float grad) {
+ Preconditions.checkArgument(f >= 0);
+
+ long offset = _gg_offset + SizeOf.DOUBLE * f;
+ double v = _buf.getDouble(offset);
v += grad * grad;
- _buf.putDouble(_gg_offset, v);
+ _buf.putDouble(offset, v);
}
- static int sizeOf(int factors) {
- return Entry.sizeOf(factors) + SizeOf.DOUBLE;
+ @Override
+ void clear() {
+ for (int f = 0; f < _factors; f++) {
+ long offset = _gg_offset + SizeOf.DOUBLE * f;
+ _buf.putDouble(offset, 0.d);
+ }
+ }
+
+ static int sizeOf(@Nonnegative final int factors) {
+ return Entry.sizeOf(factors) + SizeOf.DOUBLE * factors;
+ }
+
+ @Override
+ public String toString() {
+ final double[] gg = new double[_factors];
+ for (int f = 0; f < _factors; f++) {
+ gg[f] = getSumOfSquaredGradients(f);
+ }
+ return super.toString() + ", gg=" + Arrays.toString(gg);
}
}
- static final class FTRLEntry extends AdaGradEntry {
+ static final class FTRLEntry extends Entry {
final long _z_offset;
- FTRLEntry(@Nonnull HeapBuffer buf, int factors, long offset) {
- super(buf, factors, FTRLEntry.sizeOf(factors), offset);
- this._z_offset = _gg_offset + SizeOf.DOUBLE;
+ FTRLEntry(@Nonnull HeapBuffer buf, int key, long offset) {
+ this(buf, 1, key, offset);
+ }
+
+ FTRLEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, long offset) {
+ super(buf, factors, FTRLEntry.sizeOf(factors), key, offset);
+ this._z_offset = _offset + Entry.sizeOf(factors);
}
@Override
- float updateZ(float gradW, float alpha) {
- final float W = getW();
- final float z = getZ();
- final double n = getN();
+ float updateZ(final int f, final float W, final float gradW, final float alpha) {
+ Preconditions.checkArgument(f >= 0);
+
+ final long zOffset = offsetZ(f);
+
+ final float z = _buf.getFloat(zOffset);
+ final double n = _buf.getFloat(offsetN(f)); // implicit cast to float
double gg = gradW * gradW;
float sigma = (float) ((Math.sqrt(n + gg) - Math.sqrt(n)) / alpha);
@@ -171,44 +263,56 @@ class Entry {
+ gradW + ", sigma=" + sigma + ", W=" + W + ", n=" + n + ", gg=" + gg
+ ", alpha=" + alpha);
}
- setZ(newZ);
+ _buf.putFloat(zOffset, newZ);
return newZ;
}
- private float getZ() {
- return _buf.getFloat(_z_offset);
- }
-
- private void setZ(final float value) {
- _buf.putFloat(_z_offset, value);
- }
-
@Override
- double updateN(final float gradW) {
- final double n = getN();
+ double updateN(final int f, final float gradW) {
+ Preconditions.checkArgument(f >= 0);
+
+ final long nOffset = offsetN(f);
+ final double n = _buf.getFloat(nOffset);
final double newN = n + gradW * gradW;
if (!NumberUtils.isFinite(newN)) {
throw new IllegalStateException("Got newN " + newN + " where n=" + n + ", gradW="
+ gradW);
}
- setN(newN);
+ _buf.putFloat(nOffset, NumberUtils.castToFloat(newN)); // cast may throw ArithmeticException
return newN;
}
- private double getN() {
- long index = _z_offset + SizeOf.FLOAT;
- return _buf.getDouble(index);
+ private long offsetZ(@Nonnegative final int f) {
+ return _z_offset + SizeOf.FLOAT * f;
}
- private void setN(final double value) {
- long index = _z_offset + SizeOf.FLOAT;
- _buf.putDouble(index, value);
+ private long offsetN(@Nonnegative final int f) {
+ return _z_offset + SizeOf.FLOAT * (_factors + f);
}
- static int sizeOf(int factors) {
- return AdaGradEntry.sizeOf(factors) + SizeOf.FLOAT + SizeOf.DOUBLE;
+ @Override
+ void clear() {
+ for (int f = 0; f < _factors; f++) {
+ _buf.putFloat(offsetZ(f), 0.f);
+ _buf.putFloat(offsetN(f), 0.f);
+ }
}
+ static int sizeOf(@Nonnegative final int factors) {
+ return Entry.sizeOf(factors) + (SizeOf.FLOAT + SizeOf.FLOAT) * factors;
+ }
+
+ @Override
+ public String toString() {
+ final float[] Z = new float[_factors];
+ final float[] N = new float[_factors];
+ for (int f = 0; f < _factors; f++) {
+ Z[f] = _buf.getFloat(offsetZ(f));
+ N[f] = _buf.getFloat(offsetN(f));
+ }
+ return super.toString() + ", Z=" + Arrays.toString(Z) + ", N=" + Arrays.toString(N);
+ }
}
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
new file mode 100644
index 0000000..7cbd688
--- /dev/null
+++ b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.fm;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.SizeOf;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+
+/**
+ * Resolver of the {@code ffm_predict} aggregate function.
+ *
+ * <p>
+ * Every input row carries one term of the prediction and the evaluator sums the terms:
+ * <ul>
+ * <li>{@code Wi} is given while {@code Xi} and {@code Xj} are NULL: adds {@code Wi}</li>
+ * <li>{@code Wi} and {@code Xi} are given while {@code Xj} is NULL: adds {@code Wi * Xi}</li>
+ * <li>{@code Wi} is NULL and the other four arguments are given: adds
+ * {@code <Vifj, Vjfi> * Xi * Xj}</li>
+ * </ul>
+ * Rows with any other combination of NULLs do not contribute to the sum.
+ */
+@Description(name = "ffm_predict",
+        value = "_FUNC_(float Wi, array<float> Vifj, array<float> Vjfi, float Xi, float Xj)"
+                + " - Returns a prediction value in Double")
+public final class FFMPredictGenericUDAF extends AbstractGenericUDAFResolver {
+
+    private FFMPredictGenericUDAF() {}
+
+    /**
+     * Validates the argument types and returns the evaluator.
+     *
+     * @param typeInfo types of the five arguments (Wi, Vifj, Vjfi, Xi, Xj)
+     * @return a new {@link Evaluator}
+     * @throws UDFArgumentLengthException if the number of arguments is not 5
+     * @throws UDFArgumentTypeException if Wi, Xi or Xj is not a number type, or Vifj or Vjfi is
+     *         not a list of floating point values
+     */
+    @Override
+    public Evaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
+        if (typeInfo.length != 5) {
+            throw new UDFArgumentLengthException(
+                "Expected argument length is 5 but given argument length was " + typeInfo.length);
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Number type is expected for the first argument Wi: " + typeInfo[0].getTypeName());
+        }
+        if (typeInfo[1].getCategory() != Category.LIST) {
+            throw new UDFArgumentTypeException(1,
+                "List type is expected for the second argument Vifj: " + typeInfo[1].getTypeName());
+        }
+        if (typeInfo[2].getCategory() != Category.LIST) {
+            throw new UDFArgumentTypeException(2,
+                "List type is expected for the third argument Vjfi: " + typeInfo[2].getTypeName());
+        }
+        ListTypeInfo typeInfo1 = (ListTypeInfo) typeInfo[1];
+        if (!HiveUtils.isFloatingPointTypeInfo(typeInfo1.getListElementTypeInfo())) {
+            throw new UDFArgumentTypeException(1,
+                "Double or Float type is expected for the element type of list Vifj: "
+                        + typeInfo1.getTypeName());
+        }
+        ListTypeInfo typeInfo2 = (ListTypeInfo) typeInfo[2];
+        if (!HiveUtils.isFloatingPointTypeInfo(typeInfo2.getListElementTypeInfo())) {
+            // report the type of the offending argument (Vjfi), not the one of Vifj
+            throw new UDFArgumentTypeException(2,
+                "Double or Float type is expected for the element type of list Vjfi: "
+                        + typeInfo2.getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
+            throw new UDFArgumentTypeException(3,
+                "Number type is expected for the fourth argument Xi: " + typeInfo[3].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[4])) {
+            throw new UDFArgumentTypeException(4,
+                "Number type is expected for the fifth argument Xj: " + typeInfo[4].getTypeName());
+        }
+        return new Evaluator();
+    }
+
+    /**
+     * Evaluator that accumulates the prediction terms into a single double value. The partial
+     * and the final results are both emitted as a {@link DoubleWritable}.
+     */
+    public static final class Evaluator extends GenericUDAFEvaluator {
+
+        // input OI
+        private PrimitiveObjectInspector wiOI;
+        private ListObjectInspector vijOI, vjiOI;
+        private PrimitiveObjectInspector vijElemOI, vjiElemOI;
+        private PrimitiveObjectInspector xiOI, xjOI;
+
+        // merge input OI
+        private DoubleObjectInspector mergeInputOI;
+
+        public Evaluator() {}
+
+        /**
+         * Captures the object inspectors of the inputs for the given mode.
+         *
+         * @param mode aggregation mode
+         * @param parameters inspectors of the five original arguments in PARTIAL1/COMPLETE,
+         *        otherwise the inspector of the partial aggregation result
+         * @return the writable double inspector used for both partial and final results
+         */
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+            super.init(mode, parameters);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+                // The five arguments are passed only in these modes. In the other modes
+                // `parameters` holds the partial aggregation result, so the check must not run
+                // there.
+                assert (parameters.length == 5);
+                this.wiOI = HiveUtils.asDoubleCompatibleOI(parameters[0]);
+                this.vijOI = HiveUtils.asListOI(parameters[1]);
+                this.vijElemOI = HiveUtils.asFloatingPointOI(vijOI.getListElementObjectInspector());
+                this.vjiOI = HiveUtils.asListOI(parameters[2]);
+                this.vjiElemOI = HiveUtils.asFloatingPointOI(vjiOI.getListElementObjectInspector());
+                this.xiOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
+                this.xjOI = HiveUtils.asDoubleCompatibleOI(parameters[4]);
+            } else {// from partial aggregation
+                this.mergeInputOI = HiveUtils.asDoubleOI(parameters[0]);
+            }
+
+            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+        }
+
+        /** Creates an aggregation buffer whose sum is reset to zero. */
+        @Override
+        public FFMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException {
+            FFMPredictAggregationBuffer myAggr = new FFMPredictAggregationBuffer();
+            reset(myAggr);
+            return myAggr;
+        }
+
+        /** Resets the sum held by the given buffer to zero. */
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            myAggr.reset();
+        }
+
+        /**
+         * Adds the term carried by one input row to the buffer.
+         *
+         * @param agg the aggregation buffer to update
+         * @param parameters the row values in the order (Wi, Vifj, Vjfi, Xi, Xj)
+         */
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            final FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+
+            if (parameters[0] == null) {// Wi is null => interaction term <Vifj, Vjfi> Xi Xj
+                if (parameters[3] == null || parameters[4] == null) {
+                    // the interaction term requires both Xi and Xj
+                    return;
+                }
+                if (parameters[1] == null || parameters[2] == null) {
+                    // vi, vj can be null where feature index does not exist in the prediction model
+                    return;
+                }
+
+                // (i, j, xi, xj) => (wi, vi, vj, xi, xj)
+                float[] vij = HiveUtils.asFloatArray(parameters[1], vijOI, vijElemOI, false);
+                float[] vji = HiveUtils.asFloatArray(parameters[2], vjiOI, vjiElemOI, false);
+                double xi = PrimitiveObjectInspectorUtils.getDouble(parameters[3], xiOI);
+                double xj = PrimitiveObjectInspectorUtils.getDouble(parameters[4], xjOI);
+
+                myAggr.addViVjXiXj(vij, vji, xi, xj);
+            } else {
+                final double wi = PrimitiveObjectInspectorUtils.getDouble(parameters[0], wiOI);
+
+                if (parameters[3] == null && parameters[4] == null) {// Xi and Xj are null => global bias `w0`
+                    // (i=0, j=null, xi=null, xj=null) => (wi, vi=?, vj=null, xi=null, xj=null)
+                    myAggr.addW0(wi);
+                } else if (parameters[4] == null) {// Only Xi is nonnull => linear combination `wi` * `xi`
+                    // (i, j=null, xi, xj=null) => (wi, vi, vj=null, xi, xj=null)
+                    double xi = PrimitiveObjectInspectorUtils.getDouble(parameters[3], xiOI);
+                    myAggr.addWiXi(wi, xi);
+                }
+                // rows having both Wi and Xj are ignored
+            }
+        }
+
+        /** Returns the sum accumulated so far as the partial aggregation result. */
+        @Override
+        public DoubleWritable terminatePartial(
+                @SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            double sum = myAggr.get();
+            return new DoubleWritable(sum);
+        }
+
+        /** Adds a partial sum to the buffer; a NULL partial result is ignored. */
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
+                throws HiveException {
+            if (partial == null) {
+                return;
+            }
+
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            double sum = mergeInputOI.get(partial);
+            myAggr.merge(sum);
+        }
+
+        /** Returns the accumulated sum as the final result. */
+        @Override
+        public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            double result = myAggr.get();
+            return new DoubleWritable(result);
+        }
+
+    }
+
+    /** Aggregation buffer holding the running sum of the prediction terms. */
+    @AggregationType(estimable = true)
+    public static final class FFMPredictAggregationBuffer extends AbstractAggregationBuffer {
+
+        private double sum;
+
+        FFMPredictAggregationBuffer() {
+            super();
+        }
+
+        void reset() {
+            this.sum = 0.d;
+        }
+
+        void merge(double o_sum) {
+            this.sum += o_sum;
+        }
+
+        double get() {
+            return sum;
+        }
+
+        void addW0(final double W0) {
+            this.sum += W0;
+        }
+
+        void addWiXi(final double Wi, final double Xi) {
+            this.sum += (Wi * Xi);
+        }
+
+        /**
+         * Adds {@code <Vij, Vji> * Xi * Xj} to the sum.
+         *
+         * @throws UDFArgumentException if the two factor arrays differ in length
+         */
+        void addViVjXiXj(@Nonnull final float[] Vij, @Nonnull final float[] Vji, final double Xi,
+                final double Xj) throws UDFArgumentException {
+            if (Vij.length != Vji.length) {
+                throw new UDFArgumentException("Mismatch in the number of factors");
+            }
+
+            final int factors = Vij.length;
+
+            // compute inner product <Vifj, Vjfi>
+            double prod = 0.d;
+            for (int f = 0; f < factors; f++) {
+                prod += (Vij[f] * Vji[f]);
+            }
+
+            this.sum += (prod * Xi * Xj);
+        }
+
+        /** Estimated size of this buffer: the single double field. */
+        @Override
+        public int estimate() {
+            return SizeOf.DOUBLE;
+        }
+
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FFMPredictUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictUDF.java b/core/src/main/java/hivemall/fm/FFMPredictUDF.java
deleted file mode 100644
index 48745d9..0000000
--- a/core/src/main/java/hivemall/fm/FFMPredictUDF.java
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.annotations.Experimental;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.lang.NumberUtils;
-
-import java.io.IOException;
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
-import org.apache.hadoop.io.Text;
-
-/**
- * @since v0.5-rc.1
- */
-@Description(name = "ffm_predict",
- value = "_FUNC_(string modelId, string model, array<string> features)"
- + " returns a prediction result in double from a Field-aware Factorization Machine")
-@UDFType(deterministic = true, stateful = false)
-@Experimental
-public final class FFMPredictUDF extends GenericUDF {
-
- private StringObjectInspector _modelIdOI;
- private StringObjectInspector _modelOI;
- private ListObjectInspector _featureListOI;
-
- private DoubleWritable _result;
- @Nullable
- private String _cachedModeId;
- @Nullable
- private FFMPredictionModel _cachedModel;
- @Nullable
- private Feature[] _probes;
-
- public FFMPredictUDF() {}
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- if (argOIs.length != 3) {
- throw new UDFArgumentException("_FUNC_ takes 3 arguments");
- }
- this._modelIdOI = HiveUtils.asStringOI(argOIs[0]);
- this._modelOI = HiveUtils.asStringOI(argOIs[1]);
- this._featureListOI = HiveUtils.asListOI(argOIs[2]);
-
- this._result = new DoubleWritable();
- return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
- }
-
- @Override
- public Object evaluate(DeferredObject[] args) throws HiveException {
- String modelId = _modelIdOI.getPrimitiveJavaObject(args[0].get());
- if (modelId == null) {
- throw new HiveException("modelId is not set");
- }
-
- final FFMPredictionModel model;
- if (modelId.equals(_cachedModeId)) {
- model = this._cachedModel;
- } else {
- Text serModel = _modelOI.getPrimitiveWritableObject(args[1].get());
- if (serModel == null) {
- throw new HiveException("Model is null for model ID: " + modelId);
- }
- byte[] b = serModel.getBytes();
- final int length = serModel.getLength();
- try {
- model = FFMPredictionModel.deserialize(b, length);
- b = null;
- } catch (ClassNotFoundException e) {
- throw new HiveException(e);
- } catch (IOException e) {
- throw new HiveException(e);
- }
- this._cachedModeId = modelId;
- this._cachedModel = model;
- }
-
- int numFeatures = model.getNumFeatures();
- int numFields = model.getNumFields();
-
- Object arg2 = args[2].get();
- // [workaround]
- // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray
- // cannot be cast to [Ljava.lang.Object;
- if (arg2 instanceof LazyBinaryArray) {
- arg2 = ((LazyBinaryArray) arg2).getList();
- }
- Feature[] x = Feature.parseFFMFeatures(arg2, _featureListOI, _probes, numFeatures,
- numFields);
- if (x == null || x.length == 0) {
- return null; // return NULL if there are no features
- }
- this._probes = x;
-
- double predicted = predict(x, model);
- _result.set(predicted);
- return _result;
- }
-
- private static double predict(@Nonnull final Feature[] x,
- @Nonnull final FFMPredictionModel model) throws HiveException {
- // w0
- double ret = model.getW0();
- // W
- for (Feature e : x) {
- double xi = e.getValue();
- float wi = model.getW(e);
- double wx = wi * xi;
- ret += wx;
- }
- // V
- final int factors = model.getNumFactors();
- final float[] vij = new float[factors];
- final float[] vji = new float[factors];
- for (int i = 0; i < x.length; ++i) {
- final Feature ei = x[i];
- final double xi = ei.getValue();
- final int iField = ei.getField();
- for (int j = i + 1; j < x.length; ++j) {
- final Feature ej = x[j];
- final double xj = ej.getValue();
- final int jField = ej.getField();
- if (!model.getV(ei, jField, vij)) {
- continue;
- }
- if (!model.getV(ej, iField, vji)) {
- continue;
- }
- for (int f = 0; f < factors; f++) {
- float vijf = vij[f];
- float vjif = vji[f];
- ret += vijf * vjif * xi * xj;
- }
- }
- }
- if (!NumberUtils.isFinite(ret)) {
- throw new HiveException("Detected " + ret + " in ffm_predict");
- }
- return ret;
- }
-
- @Override
- public void close() throws IOException {
- super.close();
- // clean up to help GC
- this._cachedModel = null;
- this._probes = null;
- }
-
- @Override
- public String getDisplayString(String[] args) {
- return "ffm_predict(" + Arrays.toString(args) + ")";
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FFMPredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictionModel.java b/core/src/main/java/hivemall/fm/FFMPredictionModel.java
deleted file mode 100644
index befbec9..0000000
--- a/core/src/main/java/hivemall/fm/FFMPredictionModel.java
+++ /dev/null
@@ -1,349 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.codec.VariableByteCodec;
-import hivemall.utils.codec.ZigZagLEB128Codec;
-import hivemall.utils.collections.maps.Int2LongOpenHashTable;
-import hivemall.utils.collections.maps.IntOpenHashTable;
-import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm;
-import hivemall.utils.io.IOUtils;
-import hivemall.utils.lang.ArrayUtils;
-import hivemall.utils.lang.HalfFloat;
-import hivemall.utils.lang.ObjectUtils;
-
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-public final class FFMPredictionModel implements Externalizable {
- private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class);
-
- private static final byte HALF_FLOAT_ENTRY = 1;
- private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2;
- private static final byte FLOAT_ENTRY = 3;
- private static final byte W_ONLY_FLOAT_ENTRY = 4;
-
- /**
- * maps feature to feature weight pointer
- */
- private Int2LongOpenHashTable _map;
- private HeapBuffer _buf;
-
- private double _w0;
- private int _factors;
- private int _numFeatures;
- private int _numFields;
-
- public FFMPredictionModel() {}// for Externalizable
-
- public FFMPredictionModel(@Nonnull Int2LongOpenHashTable map, @Nonnull HeapBuffer buf,
- double w0, int factor, int numFeatures, int numFields) {
- this._map = map;
- this._buf = buf;
- this._w0 = w0;
- this._factors = factor;
- this._numFeatures = numFeatures;
- this._numFields = numFields;
- }
-
- public int getNumFactors() {
- return _factors;
- }
-
- public double getW0() {
- return _w0;
- }
-
- public int getNumFeatures() {
- return _numFeatures;
- }
-
- public int getNumFields() {
- return _numFields;
- }
-
- public int getActualNumFeatures() {
- return _map.size();
- }
-
- public long approxBytesConsumed() {
- int size = _map.size();
-
- // [map] size * (|state| + |key| + |entry|)
- long bytes = size * (1L + 4L + 4L + (4L * _factors));
- int rest = _map.capacity() - size;
- if (rest > 0) {
- bytes += rest * 1L;
- }
- // w0, factors, numFeatures, numFields, used, size
- bytes += (8 + 4 + 4 + 4 + 4 + 4);
- return bytes;
- }
-
- @Nullable
- private Entry getEntry(final int key) {
- final long ptr = _map.get(key);
- if (ptr == -1L) {
- return null;
- }
- return new Entry(_buf, _factors, ptr);
- }
-
- public float getW(@Nonnull final Feature x) {
- int j = x.getFeatureIndex();
-
- Entry entry = getEntry(j);
- if (entry == null) {
- return 0.f;
- }
- return entry.getW();
- }
-
- /**
- * @return true if V exists
- */
- public boolean getV(@Nonnull final Feature x, @Nonnull final int yField, @Nonnull float[] dst) {
- int j = Feature.toIntFeature(x, yField, _numFields);
-
- Entry entry = getEntry(j);
- if (entry == null) {
- return false;
- }
-
- entry.getV(dst);
- if (ArrayUtils.equals(dst, 0.f)) {
- return false; // treat as null
- }
- return true;
- }
-
- @Override
- public void writeExternal(@Nonnull ObjectOutput out) throws IOException {
- out.writeDouble(_w0);
- final int factors = _factors;
- out.writeInt(factors);
- out.writeInt(_numFeatures);
- out.writeInt(_numFields);
-
- int used = _map.size();
- out.writeInt(used);
-
- final int[] keys = _map.getKeys();
- final int size = keys.length;
- out.writeInt(size);
-
- final byte[] states = _map.getStates();
- writeStates(states, out);
-
- final long[] values = _map.getValues();
-
- final HeapBuffer buf = _buf;
- final Entry e = new Entry(buf, factors);
- final float[] Vf = new float[factors];
- for (int i = 0; i < size; i++) {
- if (states[i] != IntOpenHashTable.FULL) {
- continue;
- }
- ZigZagLEB128Codec.writeSignedInt(keys[i], out);
- e.setOffset(values[i]);
- writeEntry(e, factors, Vf, out);
- }
-
- // help GC
- this._map = null;
- this._buf = null;
- }
-
- private static void writeEntry(@Nonnull final Entry e, final int factors,
- @Nonnull final float[] Vf, @Nonnull final DataOutput out) throws IOException {
- final float W = e.getW();
- e.getV(Vf);
-
- if (ArrayUtils.almostEquals(Vf, 0.f)) {
- if (HalfFloat.isRepresentable(W)) {
- out.writeByte(W_ONLY_HALF_FLOAT_ENTRY);
- out.writeShort(HalfFloat.floatToHalfFloat(W));
- } else {
- out.writeByte(W_ONLY_FLOAT_ENTRY);
- out.writeFloat(W);
- }
- } else if (isRepresentableAsHalfFloat(W, Vf)) {
- out.writeByte(HALF_FLOAT_ENTRY);
- out.writeShort(HalfFloat.floatToHalfFloat(W));
- for (int i = 0; i < factors; i++) {
- out.writeShort(HalfFloat.floatToHalfFloat(Vf[i]));
- }
- } else {
- out.writeByte(FLOAT_ENTRY);
- out.writeFloat(W);
- IOUtils.writeFloats(Vf, factors, out);
- }
- }
-
- private static boolean isRepresentableAsHalfFloat(final float W, @Nonnull final float[] Vf) {
- if (!HalfFloat.isRepresentable(W)) {
- return false;
- }
- for (float V : Vf) {
- if (!HalfFloat.isRepresentable(V)) {
- return false;
- }
- }
- return true;
- }
-
- @Nonnull
- static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out)
- throws IOException {
- // write empty states's indexes differentially
- final int size = status.length;
- int cardinarity = 0;
- for (int i = 0; i < size; i++) {
- if (status[i] != IntOpenHashTable.FULL) {
- cardinarity++;
- }
- }
- out.writeInt(cardinarity);
- if (cardinarity == 0) {
- return;
- }
- int prev = 0;
- for (int i = 0; i < size; i++) {
- if (status[i] != IntOpenHashTable.FULL) {
- int diff = i - prev;
- assert (diff >= 0);
- VariableByteCodec.encodeUnsignedInt(diff, out);
- prev = i;
- }
- }
- }
-
- @Override
- public void readExternal(@Nonnull final ObjectInput in) throws IOException,
- ClassNotFoundException {
- this._w0 = in.readDouble();
- final int factors = in.readInt();
- this._factors = factors;
- this._numFeatures = in.readInt();
- this._numFields = in.readInt();
-
- final int used = in.readInt();
- final int size = in.readInt();
- final int[] keys = new int[size];
- final long[] values = new long[size];
- final byte[] states = new byte[size];
- readStates(in, states);
-
- final int entrySize = Entry.sizeOf(factors);
- int numChunks = (entrySize * used) / HeapBuffer.DEFAULT_CHUNK_BYTES + 1;
- final HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE, numChunks);
- final Entry e = new Entry(buf, factors);
- final float[] Vf = new float[factors];
- for (int i = 0; i < size; i++) {
- if (states[i] != IntOpenHashTable.FULL) {
- continue;
- }
- keys[i] = ZigZagLEB128Codec.readSignedInt(in);
- long ptr = buf.allocate(entrySize);
- e.setOffset(ptr);
- readEntry(in, factors, Vf, e);
- values[i] = ptr;
- }
-
- this._map = new Int2LongOpenHashTable(keys, values, states, used);
- this._buf = buf;
- }
-
- @Nonnull
- private static void readEntry(@Nonnull final DataInput in, final int factors,
- @Nonnull final float[] Vf, @Nonnull Entry dst) throws IOException {
- final byte type = in.readByte();
- switch (type) {
- case HALF_FLOAT_ENTRY: {
- float W = HalfFloat.halfFloatToFloat(in.readShort());
- dst.setW(W);
- for (int i = 0; i < factors; i++) {
- Vf[i] = HalfFloat.halfFloatToFloat(in.readShort());
- }
- dst.setV(Vf);
- break;
- }
- case W_ONLY_HALF_FLOAT_ENTRY: {
- float W = HalfFloat.halfFloatToFloat(in.readShort());
- dst.setW(W);
- break;
- }
- case FLOAT_ENTRY: {
- float W = in.readFloat();
- dst.setW(W);
- IOUtils.readFloats(in, Vf);
- dst.setV(Vf);
- break;
- }
- case W_ONLY_FLOAT_ENTRY: {
- float W = in.readFloat();
- dst.setW(W);
- break;
- }
- default:
- throw new IOException("Unexpected Entry type: " + type);
- }
- }
-
- @Nonnull
- static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status)
- throws IOException {
- // read non-empty states differentially
- final int cardinarity = in.readInt();
- Arrays.fill(status, IntOpenHashTable.FULL);
- int prev = 0;
- for (int j = 0; j < cardinarity; j++) {
- int i = VariableByteCodec.decodeUnsignedInt(in) + prev;
- status[i] = IntOpenHashTable.FREE;
- prev = i;
- }
- }
-
- public byte[] serialize() throws IOException {
- LOG.info("FFMPredictionModel#serialize(): " + _buf.toString());
- return ObjectUtils.toCompressedBytes(this, CompressionAlgorithm.lzma2, true);
- }
-
- public static FFMPredictionModel deserialize(@Nonnull final byte[] serializedObj, final int len)
- throws ClassNotFoundException, IOException {
- FFMPredictionModel model = new FFMPredictionModel();
- ObjectUtils.readCompressedObject(serializedObj, len, model, CompressionAlgorithm.lzma2,
- true);
- LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString());
- return model;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
index 4f445fa..22b0541 100644
--- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
@@ -22,13 +22,20 @@ import hivemall.fm.Entry.AdaGradEntry;
import hivemall.fm.Entry.FTRLEntry;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.buffer.HeapBuffer;
+import hivemall.utils.collections.lists.LongArrayList;
import hivemall.utils.collections.maps.Int2LongOpenHashTable;
+import hivemall.utils.collections.maps.Int2LongOpenHashTable.MapIterator;
import hivemall.utils.lang.NumberUtils;
-import hivemall.utils.math.MathUtils;
+import java.text.NumberFormat;
+import java.util.Locale;
+
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import org.roaringbitmap.RoaringBitmap;
+
public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachineModel {
private static final int DEFAULT_MAPSIZE = 65536;
@@ -36,37 +43,55 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
private float _w0;
@Nonnull
private final Int2LongOpenHashTable _map;
+ @Nonnull
private final HeapBuffer _buf;
+ @Nonnull
+ private final LongArrayList _freelistW;
+ @Nonnull
+ private final LongArrayList _freelistV;
+
+ private boolean _initV;
+ @Nonnull
+ private RoaringBitmap _removedV;
+
// hyperparams
- private final int _numFeatures;
private final int _numFields;
- // FTEL
- private final float _alpha;
- private final float _beta;
- private final float _lambda1;
- private final float _lamdda2;
+ private final int _entrySizeW;
+ private final int _entrySizeV;
- private final int _entrySize;
+ // statistics
+ private long _bytesAllocated, _bytesUsed;
+ private int _numAllocatedW, _numReusedW, _numRemovedW;
+ private int _numAllocatedV, _numReusedV, _numRemovedV;
public FFMStringFeatureMapModel(@Nonnull FFMHyperParameters params) {
super(params);
this._w0 = 0.f;
this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE);
this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
- this._numFeatures = params.numFeatures;
+ this._freelistW = new LongArrayList();
+ this._freelistV = new LongArrayList();
+ this._initV = true;
+ this._removedV = new RoaringBitmap();
this._numFields = params.numFields;
- this._alpha = params.alphaFTRL;
- this._beta = params.betaFTRL;
- this._lambda1 = params.lambda1;
- this._lamdda2 = params.lamdda2;
- this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad);
+ this._entrySizeW = entrySize(1, _useFTRL, _useAdaGrad);
+ this._entrySizeV = entrySize(_factor, _useFTRL, _useAdaGrad);
}
- @Nonnull
- FFMPredictionModel toPredictionModel() {
- return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, _numFields);
+ private static int entrySize(@Nonnegative int factors, boolean ftrl, boolean adagrad) {
+ if (ftrl) {
+ return FTRLEntry.sizeOf(factors);
+ } else if (adagrad) {
+ return AdaGradEntry.sizeOf(factors);
+ } else {
+ return Entry.sizeOf(factors);
+ }
+ }
+
+ void disableInitV() {
+ this._initV = false;
}
@Override
@@ -86,7 +111,7 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
@Override
public float getW(@Nonnull final Feature x) {
- int j = x.getFeatureIndex();
+ int j = Feature.toIntFeature(x);
Entry entry = getEntry(j);
if (entry == null) {
@@ -97,12 +122,11 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
@Override
protected void setW(@Nonnull final Feature x, final float nextWi) {
- final int j = x.getFeatureIndex();
+ final int j = Feature.toIntFeature(x);
Entry entry = getEntry(j);
if (entry == null) {
- float[] V = initV();
- entry = newEntry(nextWi, V);
+ entry = newEntry(j, nextWi);
long ptr = entry.getOffset();
_map.put(j, ptr);
} else {
@@ -110,53 +134,6 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
}
}
- @Override
- void updateWi(final double dloss, @Nonnull final Feature x, final float eta) {
- final double Xi = x.getValue();
- float gradWi = (float) (dloss * Xi);
-
- final Entry theta = getEntry(x);
- float wi = theta.getW();
-
- float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
- if (!NumberUtils.isFinite(nextWi)) {
- throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
- + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
- + ", eta=" + eta);
- }
- theta.setW(nextWi);
- }
-
- /**
- * Update Wi using Follow-the-Regularized-Leader
- */
- boolean updateWiFTRL(final double dloss, @Nonnull final Feature x, final float eta) {
- final double Xi = x.getValue();
- float gradWi = (float) (dloss * Xi);
-
- final Entry theta = getEntry(x);
- float wi = theta.getW();
-
- final float z = theta.updateZ(gradWi, _alpha);
- final double n = theta.updateN(gradWi);
-
- if (Math.abs(z) <= _lambda1) {
- removeEntry(x);
- return wi != 0;
- }
-
- final float nextWi = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n))
- / _alpha + _lamdda2));
- if (!NumberUtils.isFinite(nextWi)) {
- throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
- + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
- + ", eta=" + eta + ", n=" + n + ", z=" + z);
- }
- theta.setW(nextWi);
- return (nextWi != 0) || (wi != 0);
- }
-
-
/**
* @return V_x,yField,f
*/
@@ -166,10 +143,16 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
Entry entry = getEntry(j);
if (entry == null) {
+ if (_initV == false) {
+ return 0.f;
+ } else if (_removedV.contains(j)) {
+ return 0.f;
+ }
float[] V = initV();
- entry = newEntry(V);
+ entry = newEntry(j, V);
long ptr = entry.getOffset();
_map.put(j, ptr);
+ return V[f];
}
return entry.getV(f);
}
@@ -181,8 +164,13 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
Entry entry = getEntry(j);
if (entry == null) {
+ if (_initV == false) {
+ return;
+ } else if (_removedV.contains(j)) {
+ return;
+ }
float[] V = initV();
- entry = newEntry(V);
+ entry = newEntry(j, V);
long ptr = entry.getOffset();
_map.put(j, ptr);
}
@@ -190,13 +178,12 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
}
@Override
- protected Entry getEntry(@Nonnull final Feature x) {
- final int j = x.getFeatureIndex();
+ protected Entry getEntryW(@Nonnull final Feature x) {
+ final int j = Feature.toIntFeature(x);
Entry entry = getEntry(j);
if (entry == null) {
- float[] V = initV();
- entry = newEntry(V);
+ entry = newEntry(j, 0.f);
long ptr = entry.getOffset();
_map.put(j, ptr);
}
@@ -204,51 +191,92 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
}
@Override
- protected Entry getEntry(@Nonnull final Feature x, @Nonnull final int yField) {
+ protected Entry getEntryV(@Nonnull final Feature x, @Nonnull final int yField) {
final int j = Feature.toIntFeature(x, yField, _numFields);
Entry entry = getEntry(j);
if (entry == null) {
+ if (_initV == false) {
+ return null;
+ } else if (_removedV.contains(j)) {
+ return null;
+ }
float[] V = initV();
- entry = newEntry(V);
+ entry = newEntry(j, V);
long ptr = entry.getOffset();
_map.put(j, ptr);
}
return entry;
}
- protected void removeEntry(@Nonnull final Feature x) {
- int j = x.getFeatureIndex();
- _map.remove(j);
+ @Override
+ protected void removeEntry(@Nonnull final Entry entry) {
+ final int j = entry.getKey();
+ final long ptr = _map.remove(j);
+ if (ptr == -1L) {
+ return; // should never happen.
+ }
+ entry.clear();
+ if (Entry.isEntryW(j)) {
+ _freelistW.add(ptr);
+ this._numRemovedW++;
+ this._bytesUsed -= _entrySizeW;
+ } else {
+ _removedV.add(j);
+ _freelistV.add(ptr);
+ this._numRemovedV++;
+ this._bytesUsed -= _entrySizeV;
+ }
}
@Nonnull
- protected final Entry newEntry(final float W, @Nonnull final float[] V) {
- Entry entry = newEntry();
- entry.setW(W);
- entry.setV(V);
- return entry;
- }
+ protected final Entry newEntry(final int key, final float W) {
+ final long ptr;
+ if (_freelistW.isEmpty()) {
+ ptr = _buf.allocate(_entrySizeW);
+ this._numAllocatedW++;
+ this._bytesAllocated += _entrySizeW;
+ this._bytesUsed += _entrySizeW;
+ } else {// reuse removed entry
+ ptr = _freelistW.remove();
+ this._numReusedW++;
+ }
+ final Entry entry;
+ if (_useFTRL) {
+ entry = new FTRLEntry(_buf, key, ptr);
+ } else if (_useAdaGrad) {
+ entry = new AdaGradEntry(_buf, key, ptr);
+ } else {
+ entry = new Entry(_buf, key, ptr);
+ }
- @Nonnull
- protected final Entry newEntry(@Nonnull final float[] V) {
- Entry entry = newEntry();
- entry.setV(V);
+ entry.setW(W);
return entry;
}
@Nonnull
- private Entry newEntry() {
+ protected final Entry newEntry(final int key, @Nonnull final float[] V) {
+ final long ptr;
+ if (_freelistV.isEmpty()) {
+ ptr = _buf.allocate(_entrySizeV);
+ this._numAllocatedV++;
+ this._bytesAllocated += _entrySizeV;
+ this._bytesUsed += _entrySizeV;
+ } else {// reuse removed entry
+ ptr = _freelistV.remove();
+ this._numReusedV++;
+ }
+ final Entry entry;
if (_useFTRL) {
- long ptr = _buf.allocate(_entrySize);
- return new FTRLEntry(_buf, _factor, ptr);
+ entry = new FTRLEntry(_buf, _factor, key, ptr);
} else if (_useAdaGrad) {
- long ptr = _buf.allocate(_entrySize);
- return new AdaGradEntry(_buf, _factor, ptr);
+ entry = new AdaGradEntry(_buf, _factor, key, ptr);
} else {
- long ptr = _buf.allocate(_entrySize);
- return new Entry(_buf, _factor, ptr);
+ entry = new Entry(_buf, _factor, key, ptr);
}
+
+ entry.setV(V);
+ return entry;
}
@Nullable
@@ -257,28 +285,95 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
if (ptr == -1L) {
return null;
}
- return getEntry(ptr);
+ return getEntry(key, ptr);
}
@Nonnull
- private Entry getEntry(long ptr) {
- if (_useFTRL) {
- return new FTRLEntry(_buf, _factor, ptr);
- } else if (_useAdaGrad) {
- return new AdaGradEntry(_buf, _factor, ptr);
+ private Entry getEntry(final int key, @Nonnegative final long ptr) {
+ if (Entry.isEntryW(key)) {
+ if (_useFTRL) {
+ return new FTRLEntry(_buf, key, ptr);
+ } else if (_useAdaGrad) {
+ return new AdaGradEntry(_buf, key, ptr);
+ } else {
+ return new Entry(_buf, key, ptr);
+ }
} else {
- return new Entry(_buf, _factor, ptr);
+ if (_useFTRL) {
+ return new FTRLEntry(_buf, _factor, key, ptr);
+ } else if (_useAdaGrad) {
+ return new AdaGradEntry(_buf, _factor, key, ptr);
+ } else {
+ return new Entry(_buf, _factor, key, ptr);
+ }
}
}
- private static int entrySize(int factors, boolean ftrl, boolean adagrad) {
- if (ftrl) {
- return FTRLEntry.sizeOf(factors);
- } else if (adagrad) {
- return AdaGradEntry.sizeOf(factors);
- } else {
- return Entry.sizeOf(factors);
+ @Nonnull
+ String getStatistics() {
+ final NumberFormat fmt = NumberFormat.getIntegerInstance(Locale.US);
+ return "FFMStringFeatureMapModel [bytesAllocated="
+ + NumberUtils.prettySize(_bytesAllocated) + ", bytesUsed="
+ + NumberUtils.prettySize(_bytesUsed) + ", numAllocatedW="
+ + fmt.format(_numAllocatedW) + ", numReusedW=" + fmt.format(_numReusedW)
+ + ", numRemovedW=" + fmt.format(_numRemovedW) + ", numAllocatedV="
+ + fmt.format(_numAllocatedV) + ", numReusedV=" + fmt.format(_numReusedV)
+ + ", numRemovedV=" + fmt.format(_numRemovedV) + "]";
+ }
+
+ @Override
+ public String toString() {
+ return getStatistics();
+ }
+
+ @Nonnull
+ EntryIterator entries() {
+ return new EntryIterator(this);
+ }
+
+ static final class EntryIterator {
+
+ @Nonnull
+ private final MapIterator dictItor;
+ @Nonnull
+ private final Entry entryProbeW;
+ @Nonnull
+ private final Entry entryProbeV;
+
+ EntryIterator(@Nonnull FFMStringFeatureMapModel model) {
+ this.dictItor = model._map.entries();
+ this.entryProbeW = new Entry(model._buf, 1);
+ this.entryProbeV = new Entry(model._buf, model._factor);
+ }
+
+ @Nonnull
+ Entry getEntryProbeW() {
+ return entryProbeW;
}
+
+ @Nonnull
+ Entry getEntryProbeV() {
+ return entryProbeV;
+ }
+
+ boolean hasNext() {
+ return dictItor.hasNext();
+ }
+
+ boolean next() {
+ return dictItor.next() != -1;
+ }
+
+ int getEntryIndex() {
+ return dictItor.getKey();
+ }
+
+ @Nonnull
+ void getEntry(@Nonnull final Entry probe) {
+ long offset = dictItor.getValue();
+ probe.setOffset(offset);
+ }
+
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FMHyperParameters.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java
index accb99a..15c1c56 100644
--- a/core/src/main/java/hivemall/fm/FMHyperParameters.java
+++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java
@@ -143,16 +143,15 @@ class FMHyperParameters {
int numFields = Feature.DEFAULT_NUM_FIELDS;
// adagrad
- boolean useAdaGrad = true;
- float eta0_V = 1.f;
+ boolean useAdaGrad = false;
float eps = 1.f;
// FTRL
- boolean useFTRL = true;
- float alphaFTRL = 0.1f; // Learning Rate
+ boolean useFTRL = false;
+ float alphaFTRL = 0.2f; // Learning Rate
float betaFTRL = 1.f; // Smoothing parameter for AdaGrad
- float lambda1 = 0.1f; // L1 Regularization
- float lamdda2 = 0.01f; // L2 Regularization
+ float lambda1 = 0.001f; // L1 Regularization
+ float lamdda2 = 0.0001f; // L2 Regularization
FFMHyperParameters() {
super();
@@ -171,42 +170,59 @@ class FMHyperParameters {
// feature hashing
if (numFeatures == -1) {
- int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"),
- Feature.DEFAULT_FEATURE_BITS);
- if (hashbits < 18 || hashbits > 31) {
- throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: "
- + hashbits);
+ int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), -1);
+ if (hashbits != -1) {
+ if (hashbits < 18 || hashbits > 31) {
+ throw new UDFArgumentException(
+ "-feature_hashing MUST be in range [18,31]: " + hashbits);
+ }
+ this.numFeatures = 1 << hashbits;
}
- this.numFeatures = 1 << hashbits;
}
this.numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), numFields);
if (numFields <= 1) {
throw new UDFArgumentException("-num_fields MUST be greater than 1: " + numFields);
}
- // adagrad
- this.useAdaGrad = !cl.hasOption("disable_adagrad");
- this.eta0_V = Primitives.parseFloat(cl.getOptionValue("eta0_V"), eta0_V);
- this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), eps);
-
- // FTRL
- this.useFTRL = !cl.hasOption("disable_ftrl");
- this.alphaFTRL = Primitives.parseFloat(cl.getOptionValue("alphaFTRL"), alphaFTRL);
- if (alphaFTRL == 0.f) {
- throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0");
+ // optimizer
+ final String optimizer = cl.getOptionValue("optimizer", "ftrl").toLowerCase();
+ switch (optimizer) {
+ case "ftrl": {
+ this.useFTRL = true;
+ this.useAdaGrad = false;
+ this.alphaFTRL = Primitives.parseFloat(cl.getOptionValue("alphaFTRL"),
+ alphaFTRL);
+ if (alphaFTRL == 0.f) {
+ throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0");
+ }
+ this.betaFTRL = Primitives.parseFloat(cl.getOptionValue("betaFTRL"), betaFTRL);
+ this.lambda1 = Primitives.parseFloat(cl.getOptionValue("lambda1"), lambda1);
+ this.lamdda2 = Primitives.parseFloat(cl.getOptionValue("lamdda2"), lamdda2);
+ break;
+ }
+ case "adagrad": {
+ this.useAdaGrad = true;
+ this.useFTRL = false;
+ this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), eps);
+ break;
+ }
+ case "sgd":
+ // fall through
+ default: {
+ this.useFTRL = false;
+ this.useAdaGrad = false;
+ break;
+ }
}
- this.betaFTRL = Primitives.parseFloat(cl.getOptionValue("betaFTRL"), betaFTRL);
- this.lambda1 = Primitives.parseFloat(cl.getOptionValue("lambda1"), lambda1);
- this.lamdda2 = Primitives.parseFloat(cl.getOptionValue("lamdda2"), lamdda2);
}
@Override
public String toString() {
return "FFMHyperParameters [globalBias=" + globalBias + ", linearCoeff=" + linearCoeff
- + ", numFields=" + numFields + ", useAdaGrad=" + useAdaGrad + ", eta0_V="
- + eta0_V + ", eps=" + eps + ", useFTRL=" + useFTRL + ", alphaFTRL=" + alphaFTRL
- + ", betaFTRL=" + betaFTRL + ", lambda1=" + lambda1 + ", lamdda2=" + lamdda2
- + "], " + super.toString();
+ + ", numFields=" + numFields + ", useAdaGrad=" + useAdaGrad + ", eps=" + eps
+ + ", useFTRL=" + useFTRL + ", alphaFTRL=" + alphaFTRL + ", betaFTRL="
+ + betaFTRL + ", lambda1=" + lambda1 + ", lamdda2=" + lamdda2 + "], "
+ + super.toString();
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
index 19ac287..be39b0b 100644
--- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
@@ -19,7 +19,7 @@
package hivemall.fm;
import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
-import hivemall.utils.collections.maps.IntOpenHashMap;
+import hivemall.utils.collections.maps.IntOpenHashTable;
import java.util.Arrays;
@@ -33,7 +33,7 @@ public final class FMIntFeatureMapModel extends FactorizationMachineModel {
// LEARNING PARAMS
private float _w0;
private final Int2FloatOpenHashTable _w;
- private final IntOpenHashMap<float[]> _V;
+ private final IntOpenHashTable<float[]> _V;
private int _minIndex, _maxIndex;
@@ -42,7 +42,7 @@ public final class FMIntFeatureMapModel extends FactorizationMachineModel {
this._w0 = 0.f;
this._w = new Int2FloatOpenHashTable(DEFAULT_MAPSIZE);
_w.defaultReturnValue(0.f);
- this._V = new IntOpenHashMap<float[]>(DEFAULT_MAPSIZE);
+ this._V = new IntOpenHashTable<float[]>(DEFAULT_MAPSIZE);
this._minIndex = 0;
this._maxIndex = 0;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java
index 667befb..730cc49 100644
--- a/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java
+++ b/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java
@@ -18,6 +18,9 @@
*/
package hivemall.fm;
+import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_ARRAY_META;
+import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_REF;
+import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES2;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
@@ -35,6 +38,7 @@ import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
@@ -234,6 +238,7 @@ public final class FMPredictGenericUDAF extends AbstractGenericUDAFResolver {
}
+ @AggregationType(estimable = true)
public static class FMPredictAggregationBuffer extends AbstractAggregationBuffer {
private double ret;
@@ -328,6 +333,16 @@ public final class FMPredictGenericUDAF extends AbstractGenericUDAFResolver {
}
return predict;
}
+
+ @Override
+ public int estimate() {
+ if (sumVjXj == null) {
+ return PRIMITIVES2 + 2 * JAVA64_REF;
+ } else {
+ // model.array() = JAVA64_ARRAY_META + JAVA64_REF
+ return PRIMITIVES2 + 2 * (JAVA64_ARRAY_META + PRIMITIVES2 * sumVjXj.length);
+ }
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
index cd99046..4eec280 100644
--- a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
@@ -19,7 +19,7 @@
package hivemall.fm;
import hivemall.utils.collections.IMapIterator;
-import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.collections.maps.OpenHashMap;
import javax.annotation.Nonnull;
@@ -28,12 +28,12 @@ public final class FMStringFeatureMapModel extends FactorizationMachineModel {
// LEARNING PARAMS
private float _w0;
- private final OpenHashTable<String, Entry> _map;
+ private final OpenHashMap<String, Entry> _map;
public FMStringFeatureMapModel(@Nonnull FMHyperParameters params) {
super(params);
this._w0 = 0.f;
- this._map = new OpenHashTable<String, FMStringFeatureMapModel.Entry>(DEFAULT_MAPSIZE);
+ this._map = new OpenHashMap<String, FMStringFeatureMapModel.Entry>(DEFAULT_MAPSIZE);
}
@Override
@@ -42,7 +42,7 @@ public final class FMStringFeatureMapModel extends FactorizationMachineModel {
}
IMapIterator<String, Entry> entries() {
- return _map.entries();
+ return _map.entries(true);
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
index 65b6ba7..24210a8 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
@@ -117,8 +117,8 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
opts.addOption("c", "classification", false, "Act as classification");
opts.addOption("seed", true, "Seed value [default: -1 (random)]");
opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]");
- opts.addOption("p", "num_features", true, "The size of feature dimensions");
- opts.addOption("factor", "factors", true, "The number of the latent variables [default: 5]");
+ opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]");
+ opts.addOption("f", "factors", true, "The number of the latent variables [default: 5]");
opts.addOption("sigma", true, "The standard deviation for initializing V [default: 0.1]");
opts.addOption("lambda0", "lambda", true,
"The initial lambda value for regularization [default: 0.01]");
@@ -376,7 +376,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
double loss = _lossFunction.loss(p, y);
_cvState.incrLoss(loss);
- if (MathUtils.closeToZero(lossGrad)) {
+ if (MathUtils.closeToZero(lossGrad, 1E-9d)) {
return;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/38047891/core/src/main/java/hivemall/fm/Feature.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/Feature.java b/core/src/main/java/hivemall/fm/Feature.java
index 2966a02..8ae6f20 100644
--- a/core/src/main/java/hivemall/fm/Feature.java
+++ b/core/src/main/java/hivemall/fm/Feature.java
@@ -23,6 +23,7 @@ import hivemall.utils.lang.NumberUtils;
import java.nio.ByteBuffer;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -30,7 +31,7 @@ import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
public abstract class Feature {
- public static final int DEFAULT_NUM_FIELDS = 1024;
+ public static final int DEFAULT_NUM_FIELDS = 256;
public static final int DEFAULT_FEATURE_BITS = 21;
public static final int DEFAULT_NUM_FEATURES = 1 << 21; // 2^21
@@ -51,10 +52,11 @@ public abstract class Feature {
throw new UnsupportedOperationException();
}
- public void setFeatureIndex(int i) {
+ public void setFeatureIndex(@Nonnegative int i) {
throw new UnsupportedOperationException();
}
+ @Nonnegative
public int getFeatureIndex() {
throw new UnsupportedOperationException();
}
@@ -127,6 +129,7 @@ public abstract class Feature {
}
}
+ @Nullable
public static Feature[] parseFFMFeatures(@Nonnull final Object arg,
@Nonnull final ListObjectInspector listOI, @Nullable final Feature[] probes,
final int numFeatures, final int numFields) throws HiveException {
@@ -176,6 +179,9 @@ public abstract class Feature {
int index = parseFeatureIndex(fv);
return new IntFeature(index, 1.d);
} else {
+ if ("0".equals(fv)) {
+ throw new HiveException("Index value should not be 0: " + fv);
+ }
return new StringFeature(/* index */fv, 1.d);
}
} else {
@@ -187,6 +193,9 @@ public abstract class Feature {
return new IntFeature(index, value);
} else {
double value = parseFeatureValue(valueStr);
+ if ("0".equals(indexStr)) {
+ throw new HiveException("Index value should not be 0: " + fv);
+ }
return new StringFeature(/* index */indexStr, value);
}
}
@@ -198,6 +207,12 @@ public abstract class Feature {
}
@Nonnull
+ static IntFeature parseFFMFeature(@Nonnull final String fv, final int numFeatures)
+ throws HiveException {
+ return parseFFMFeature(fv, -1, DEFAULT_NUM_FIELDS);
+ }
+
+ @Nonnull
static IntFeature parseFFMFeature(@Nonnull final String fv, final int numFeatures,
final int numFields) throws HiveException {
final int pos1 = fv.indexOf(':');
@@ -219,25 +234,26 @@ public abstract class Feature {
} else {
index = MurmurHash3.murmurhash3(lead, numFields);
}
- short field = (short) index;
+ short field = NumberUtils.castToShort(index);
double value = parseFeatureValue(rest);
return new IntFeature(index, field, value);
}
- final String indexStr = rest.substring(0, pos2);
- final int index;
+
final short field;
- if (NumberUtils.isDigits(indexStr) && NumberUtils.isDigits(lead)) {
- index = parseFeatureIndex(indexStr);
- if (index >= (numFeatures + numFields)) {
- throw new HiveException("Feature index MUST be less than "
- + (numFeatures + numFields) + " but was " + index);
- }
+ if (NumberUtils.isDigits(lead)) {
field = parseField(lead, numFields);
} else {
+ field = NumberUtils.castToShort(MurmurHash3.murmurhash3(lead, numFields));
+ }
+
+ final int index;
+ final String indexStr = rest.substring(0, pos2);
+ if (numFeatures == -1 && NumberUtils.isDigits(indexStr)) {
+ index = parseFeatureIndex(indexStr);
+ } else {
// +NUM_FIELD to avoid conflict to quantitative features
index = MurmurHash3.murmurhash3(indexStr, numFeatures) + numFields;
- field = (short) MurmurHash3.murmurhash3(lead, numFields);
}
String valueStr = rest.substring(pos2 + 1);
double value = parseFeatureValue(valueStr);
@@ -253,6 +269,9 @@ public abstract class Feature {
int index = parseFeatureIndex(fv);
probe.setFeatureIndex(index);
} else {
+ if ("0".equals(fv)) {
+ throw new HiveException("Index value should not be 0: " + fv);
+ }
probe.setFeature(fv);
}
probe.value = 1.d;
@@ -264,6 +283,9 @@ public abstract class Feature {
probe.setFeatureIndex(index);
probe.value = parseFeatureValue(valueStr);
} else {
+ if ("0".equals(indexStr)) {
+ throw new HiveException("Index value should not be 0: " + fv);
+ }
probe.setFeature(indexStr);
probe.value = parseFeatureValue(valueStr);
}
@@ -296,27 +318,26 @@ public abstract class Feature {
} else {
index = MurmurHash3.murmurhash3(lead, numFields);
}
- short field = (short) index;
+ short field = NumberUtils.castToShort(index);
probe.setField(field);
probe.setFeatureIndex(index);
probe.value = parseFeatureValue(rest);
return;
}
- String indexStr = rest.substring(0, pos2);
- final int index;
final short field;
- if (NumberUtils.isDigits(indexStr) && NumberUtils.isDigits(lead)) {
- index = parseFeatureIndex(indexStr);
- if (index >= (numFeatures + numFields)) {
- throw new HiveException("Feature index MUST be less than "
- + (numFeatures + numFields) + " but was " + index);
- }
+ if (NumberUtils.isDigits(lead)) {
field = parseField(lead, numFields);
} else {
+ field = NumberUtils.castToShort(MurmurHash3.murmurhash3(lead, numFields));
+ }
+ final int index;
+ final String indexStr = rest.substring(0, pos2);
+ if (numFeatures == -1 && NumberUtils.isDigits(indexStr)) {
+ index = parseFeatureIndex(indexStr);
+ } else {
// +NUM_FIELD to avoid conflict to quantitative features
index = MurmurHash3.murmurhash3(indexStr, numFeatures) + numFields;
- field = (short) MurmurHash3.murmurhash3(lead, numFields);
}
probe.setField(field);
probe.setFeatureIndex(index);
@@ -325,7 +346,6 @@ public abstract class Feature {
probe.value = parseFeatureValue(valueStr);
}
-
private static int parseFeatureIndex(@Nonnull final String indexStr) throws HiveException {
final int index;
try {
@@ -333,7 +353,7 @@ public abstract class Feature {
} catch (NumberFormatException e) {
throw new HiveException("Invalid index value: " + indexStr, e);
}
- if (index < 0) {
+ if (index <= 0) {
throw new HiveException("Feature index MUST be greater than 0: " + indexStr);
}
return index;
@@ -361,7 +381,13 @@ public abstract class Feature {
return field;
}
- public static int toIntFeature(@Nonnull final Feature x, final int yField, final int numFields) {
+ public static int toIntFeature(@Nonnull final Feature x) {
+ int index = x.getFeatureIndex();
+ return -index;
+ }
+
+ public static int toIntFeature(@Nonnull final Feature x, @Nonnegative final int yField,
+ @Nonnegative final int numFields) {
int index = x.getFeatureIndex();
return index * numFields + yField;
}