You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/11 05:36:46 UTC
[3/4] incubator-hivemall git commit: Close #105: [HIVEMALL-24-2] Make
ffm_predict function more scalable by creating its UDAF implementation
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
index 76bead8..730d0f4 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
@@ -22,9 +22,11 @@ import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.collections.arrays.DoubleArray3D;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.math.MathUtils;
import java.util.Arrays;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -34,19 +36,33 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM
@Nonnull
protected final FFMHyperParameters _params;
- protected final float _eta0_V;
+ protected final float _eta0;
protected final float _eps;
protected final boolean _useAdaGrad;
protected final boolean _useFTRL;
+    // FTRL
+ private final float _alpha;
+ private final float _beta;
+ private final float _lambda1;
+ private final float _lamdda2;
+
public FieldAwareFactorizationMachineModel(@Nonnull FFMHyperParameters params) {
super(params);
this._params = params;
- this._eta0_V = params.eta0_V;
+ if (params.useAdaGrad) {
+ this._eta0 = 1.0f;
+ } else {
+ this._eta0 = params.eta.eta0();
+ }
this._eps = params.eps;
this._useAdaGrad = params.useAdaGrad;
this._useFTRL = params.useFTRL;
+ this._alpha = params.alphaFTRL;
+ this._beta = params.betaFTRL;
+ this._lambda1 = params.lambda1;
+ this._lamdda2 = params.lamdda2;
}
public abstract float getV(@Nonnull Feature x, @Nonnull int yField, int f);
@@ -100,31 +116,152 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM
return ret;
}
+ void updateWi(final double dloss, @Nonnull final Feature x, final long t) {
+ if (_useFTRL) {
+ updateWi_FTRL(dloss, x);
+ return;
+ }
+
+ final double Xi = x.getValue();
+ float gradWi = (float) (dloss * Xi);
+
+ final Entry theta = getEntryW(x);
+ float wi = theta.getW();
+
+ final float eta = eta(theta, t, gradWi);
+ float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
+ if (!NumberUtils.isFinite(nextWi)) {
+ throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
+ + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
+ + ", eta=" + eta + ", t=" + t);
+ }
+ if (MathUtils.closeToZero(nextWi, 1E-9f)) {
+ removeEntry(theta);
+ return;
+ }
+ theta.setW(nextWi);
+ }
+
+ /**
+ * Update Wi using Follow-the-Regularized-Leader
+ */
+ private void updateWi_FTRL(final double dloss, @Nonnull final Feature x) {
+ final double Xi = x.getValue();
+ float gradWi = (float) (dloss * Xi);
+
+ final Entry theta = getEntryW(x);
+
+ final float z = theta.updateZ(gradWi, _alpha);
+ final double n = theta.updateN(gradWi);
+
+ if (Math.abs(z) <= _lambda1) {
+ removeEntry(theta);
+ return;
+ }
+
+ final float nextWi = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n))
+ / _alpha + _lamdda2));
+ if (!NumberUtils.isFinite(nextWi)) {
+ throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
+ + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + theta.getW()
+ + ", dloss=" + dloss + ", n=" + n + ", z=" + z);
+ }
+ if (MathUtils.closeToZero(nextWi, 1E-9f)) {
+ removeEntry(theta);
+ return;
+ }
+ theta.setW(nextWi);
+ }
+
+ protected abstract void removeEntry(@Nonnull final Entry entry);
+
void updateV(final double dloss, @Nonnull final Feature x, @Nonnull final int yField,
final int f, final double sumViX, long t) {
+ if (_useFTRL) {
+ updateV_FTRL(dloss, x, yField, f, sumViX);
+ return;
+ }
+
+ final Entry theta = getEntryV(x, yField);
+ if (theta == null) {
+ return;
+ }
+
final double Xi = x.getValue();
final double h = Xi * sumViX;
final float gradV = (float) (dloss * h);
final float lambdaVf = getLambdaV(f);
- final Entry theta = getEntry(x, yField);
final float currentV = theta.getV(f);
- final float eta = etaV(theta, t, gradV);
+ final float eta = eta(theta, f, t, gradV);
final float nextV = currentV - eta * (gradV + 2.f * lambdaVf * currentV);
if (!NumberUtils.isFinite(nextV)) {
throw new IllegalStateException("Got " + nextV + " for next V" + f + '['
+ x.getFeatureIndex() + "]\n" + "Xi=" + Xi + ", Vif=" + currentV + ", h=" + h
+ ", gradV=" + gradV + ", lambdaVf=" + lambdaVf + ", dloss=" + dloss
- + ", sumViX=" + sumViX);
+ + ", sumViX=" + sumViX + ", t=" + t);
+ }
+ if (MathUtils.closeToZero(nextV, 1E-9f)) {
+ theta.setV(f, 0.f);
+            if (theta.removable()) { // remove the entry if all the other factors are zero-filled as well
+ removeEntry(theta);
+ }
+ return;
+ }
+ theta.setV(f, nextV);
+ }
+
+ private void updateV_FTRL(final double dloss, @Nonnull final Feature x,
+ @Nonnull final int yField, final int f, final double sumViX) {
+ final Entry theta = getEntryV(x, yField);
+ if (theta == null) {
+ return;
+ }
+
+ final double Xi = x.getValue();
+ final double h = Xi * sumViX;
+ final float gradV = (float) (dloss * h);
+
+ float oldV = theta.getV(f);
+ final float z = theta.updateZ(f, oldV, gradV, _alpha);
+ final double n = theta.updateN(f, gradV);
+
+ if (Math.abs(z) <= _lambda1) {
+ theta.setV(f, 0.f);
+            if (theta.removable()) { // remove the entry if all the other factors are zero-filled as well
+ removeEntry(theta);
+ }
+ return;
+ }
+
+ final float nextV = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n))
+ / _alpha + _lamdda2));
+ if (!NumberUtils.isFinite(nextV)) {
+ throw new IllegalStateException("Got " + nextV + " for next V" + f + '['
+ + x.getFeatureIndex() + "]\n" + "Xi=" + Xi + ", Vif=" + theta.getV(f) + ", h="
+ + h + ", gradV=" + gradV + ", dloss=" + dloss + ", sumViX=" + sumViX + ", n="
+ + n + ", z=" + z);
+ }
+ if (MathUtils.closeToZero(nextV, 1E-9f)) {
+ theta.setV(f, 0.f);
+            if (theta.removable()) { // remove the entry if all the other factors are zero-filled as well
+ removeEntry(theta);
+ }
+ return;
}
theta.setV(f, nextV);
}
- protected final float etaV(@Nonnull final Entry theta, final long t, final float grad) {
+ protected final float eta(@Nonnull final Entry theta, final long t, final float grad) {
+ return eta(theta, 0, t, grad);
+ }
+
+ protected final float eta(@Nonnull final Entry theta, @Nonnegative final int f, final long t,
+ final float grad) {
if (_useAdaGrad) {
- double gg = theta.getSumOfSquaredGradientsV();
- theta.addGradientV(grad);
- return (float) (_eta0_V / Math.sqrt(_eps + gg));
+ double gg = theta.getSumOfSquaredGradients(f);
+ theta.addGradient(f, grad);
+ return (float) (_eta0 / Math.sqrt(_eps + gg));
} else {
return _eta.eta(t);
}
@@ -187,10 +324,10 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM
}
@Nonnull
- protected abstract Entry getEntry(@Nonnull Feature x);
+ protected abstract Entry getEntryW(@Nonnull Feature x);
- @Nonnull
- protected abstract Entry getEntry(@Nonnull Feature x, @Nonnull int yField);
+ @Nullable
+ protected abstract Entry getEntryV(@Nonnull Feature x, @Nonnull int yField);
@Override
protected final String varDump(@Nonnull final Feature[] x) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
index 67dbf87..56d9dc2 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
@@ -18,17 +18,18 @@
*/
package hivemall.fm;
+import hivemall.fm.FFMStringFeatureMapModel.EntryIterator;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.collections.arrays.DoubleArray3D;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HadoopUtils;
-import hivemall.utils.hadoop.Text3;
-import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.math.MathUtils;
-import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -44,6 +45,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
/**
@@ -60,8 +63,6 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
// ----------------------------------------
// Learning hyper-parameters/options
- private boolean _FTRL;
-
private boolean _globalBias;
private boolean _linearCoeff;
@@ -87,26 +88,25 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
opts.addOption("disable_wi", "no_coeff", false, "Not to include linear term [default: OFF]");
// feature hashing
opts.addOption("feature_hashing", true,
- "The number of bits for feature hashing in range [18,31] [default:21]");
- opts.addOption("num_fields", true, "The number of fields [default:1024]");
+ "The number of bits for feature hashing in range [18,31] [default: -1]. No feature hashing for -1.");
+ opts.addOption("num_fields", true, "The number of fields [default: 256]");
+ // optimizer
+ opts.addOption("opt", "optimizer", true,
+ "Gradient Descent optimizer [default: ftrl, adagrad, sgd]");
// adagrad
- opts.addOption("disable_adagrad", false,
- "Whether to use AdaGrad for tuning learning rate [default: ON]");
- opts.addOption("eta0_V", true, "The initial learning rate for V [default 1.0]");
- opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]");
+ opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default: 1.0]");
// FTRL
- opts.addOption("disable_ftrl", false,
- "Whether not to use Follow-The-Regularized-Reader [default: OFF]");
opts.addOption("alpha", "alphaFTRL", true,
- "Alpha value (learning rate) of Follow-The-Regularized-Reader [default 0.1]");
+            "Alpha value (learning rate) of Follow-The-Regularized-Leader [default: 0.2]");
opts.addOption("beta", "betaFTRL", true,
- "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Reader [default 1.0]");
+            "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Leader [default: 1.0]");
opts.addOption(
+ "l1",
"lambda1",
true,
- "L1 regularization value of Follow-The-Regularized-Reader that controls model Sparseness [default 0.1]");
- opts.addOption("lambda2", true,
- "L2 regularization value of Follow-The-Regularized-Reader [default 0.01]");
+            "L1 regularization value of Follow-The-Regularized-Leader that controls model Sparseness [default: 0.001]");
+ opts.addOption("l2", "lambda2", true,
+            "L2 regularization value of Follow-The-Regularized-Leader [default: 0.0001]");
return opts;
}
@@ -125,7 +125,6 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
CommandLine cl = super.processOptions(argOIs);
FFMHyperParameters params = (FFMHyperParameters) _params;
- this._FTRL = params.useFTRL;
this._globalBias = params.globalBias;
this._linearCoeff = params.linearCoeff;
this._numFeatures = params.numFeatures;
@@ -150,8 +149,14 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
fieldNames.add("model_id");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
- fieldNames.add("model");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ fieldNames.add("i");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+ fieldNames.add("Wi");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ fieldNames.add("Vi");
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
@@ -184,20 +189,19 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
@Override
protected void trainTheta(@Nonnull final Feature[] x, final double y) throws HiveException {
- final float eta_t = _etaEstimator.eta(_t);
-
final double p = _ffmModel.predict(x);
final double lossGrad = _ffmModel.dloss(p, y);
double loss = _lossFunction.loss(p, y);
_cvState.incrLoss(loss);
- if (MathUtils.closeToZero(lossGrad)) {
+ if (MathUtils.closeToZero(lossGrad, 1E-9d)) {
return;
}
// w0 update
if (_globalBias) {
+ float eta_t = _etaEstimator.eta(_t);
_ffmModel.updateW0(lossGrad, eta_t);
}
@@ -210,14 +214,16 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
if (x_i.value == 0.f) {
continue;
}
- boolean useV = updateWi(lossGrad, x_i, eta_t); // wi update
- if (useV == false) {
- continue;
+ if (_linearCoeff) {
+ _ffmModel.updateWi(lossGrad, x_i, _t);// wi update
}
for (int fieldIndex = 0, size = fieldList.size(); fieldIndex < size; fieldIndex++) {
final int yField = fieldList.get(fieldIndex);
for (int f = 0, k = _factors; f < k; f++) {
- double sumViX = sumVfX.get(i, fieldIndex, f);
+ final double sumViX = sumVfX.get(i, fieldIndex, f);
+ if (MathUtils.closeToZero(sumViX)) {// grad will be 0 => skip it
+ continue;
+ }
_ffmModel.updateV(lossGrad, x_i, yField, f, sumViX, _t);
}
}
@@ -229,18 +235,6 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
fieldList.clear();
}
- private boolean updateWi(double lossGrad, @Nonnull Feature xi, float eta) {
- if (!_linearCoeff) {
- return true;
- }
- if (_FTRL) {
- return _ffmModel.updateWiFTRL(lossGrad, xi, eta);
- } else {
- _ffmModel.updateWi(lossGrad, xi, eta);
- return true;
- }
- }
-
@Nonnull
private IntArrayList getFieldList(@Nonnull final Feature[] x) {
for (Feature e : x) {
@@ -257,7 +251,16 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
@Override
public void close() throws HiveException {
+ if (LOG.isInfoEnabled()) {
+ LOG.info(_ffmModel.getStatistics());
+ }
+
+ _ffmModel.disableInitV(); // trick to avoid re-instantiating removed (zero-filled) entry of V
super.close();
+
+ if (LOG.isInfoEnabled()) {
+ LOG.info(_ffmModel.getStatistics());
+ }
this._ffmModel = null;
}
@@ -267,39 +270,54 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
this._fieldList = null;
this._sumVfX = null;
- Text modelId = new Text();
- String taskId = HadoopUtils.getUniqueTaskIdString();
- modelId.set(taskId);
-
- FFMPredictionModel predModel = _ffmModel.toPredictionModel();
- this._ffmModel = null; // help GC
-
- if (LOG.isInfoEnabled()) {
- LOG.info("Serializing a model '" + modelId + "'... Configured # features: "
- + _numFeatures + ", Configured # fields: " + _numFields
- + ", Actual # features: " + predModel.getActualNumFeatures()
- + ", Estimated uncompressed bytes: "
- + NumberUtils.prettySize(predModel.approxBytesConsumed()));
- }
+ final int factors = _factors;
+ final IntWritable idx = new IntWritable();
+ final FloatWritable Wi = new FloatWritable(0.f);
+ final FloatWritable[] Vi = HiveUtils.newFloatArray(factors, 0.f);
+ final List<FloatWritable> ViObj = Arrays.asList(Vi);
+
+ final Object[] forwardObjs = new Object[4];
+ String modelId = HadoopUtils.getUniqueTaskIdString();
+ forwardObjs[0] = new Text(modelId);
+ forwardObjs[1] = idx;
+ forwardObjs[2] = Wi;
+ forwardObjs[3] = null; // Vi
+
+ // W0
+ idx.set(0);
+ Wi.set(_ffmModel.getW0());
+ forward(forwardObjs);
- byte[] serialized;
- try {
- serialized = predModel.serialize();
- predModel = null;
- } catch (IOException e) {
- throw new HiveException("Failed to serialize a model", e);
- }
+ final EntryIterator itor = _ffmModel.entries();
+ final Entry entryW = itor.getEntryProbeW();
+ final Entry entryV = itor.getEntryProbeV();
+ final float[] Vf = new float[factors];
+ while (itor.next()) {
+ // set i
+ int i = itor.getEntryIndex();
+ idx.set(i);
+
+ if (Entry.isEntryW(i)) {// set Wi
+                itor.getEntry(entryW);
+                float w = entryW.getW();
+ if (w == 0.f) {
+ continue; // skip w_i=0
+ }
+ Wi.set(w);
+ forwardObjs[2] = Wi;
+ forwardObjs[3] = null;
+ } else {// set Vif
+ itor.getEntry(entryV);
+ entryV.getV(Vf);
+ for (int f = 0; f < factors; f++) {
+ Vi[f].set(Vf[f]);
+ }
+ forwardObjs[2] = null;
+ forwardObjs[3] = ViObj;
+ }
- if (LOG.isInfoEnabled()) {
- LOG.info("Forwarding a serialized/compressed model '" + modelId + "' of size: "
- + NumberUtils.prettySize(serialized.length));
+ forward(forwardObjs);
}
-
- Text modelObj = new Text3(serialized);
- serialized = null;
- Object[] forwardObjs = new Object[] {modelId, modelObj};
-
- forward(forwardObjs);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/IntFeature.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/IntFeature.java b/core/src/main/java/hivemall/fm/IntFeature.java
index 2052f7e..64a4daa 100644
--- a/core/src/main/java/hivemall/fm/IntFeature.java
+++ b/core/src/main/java/hivemall/fm/IntFeature.java
@@ -20,19 +20,21 @@ package hivemall.fm;
import java.nio.ByteBuffer;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
public final class IntFeature extends Feature {
+ @Nonnegative
private int index;
/** -1 if not defined */
private short field;
- public IntFeature(int index, double value) {
+ public IntFeature(@Nonnegative int index, double value) {
this(index, (short) -1, value);
}
- public IntFeature(int index, short field, double value) {
+ public IntFeature(@Nonnegative int index, short field, double value) {
super(value);
this.field = field;
this.index = index;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
index 6aebd64..3ec6ad7 100644
--- a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
+++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
@@ -19,15 +19,18 @@
package hivemall.ftvec.pairing;
import hivemall.UDTFWithOptions;
+import hivemall.fm.Feature;
import hivemall.model.FeatureValue;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
@@ -50,6 +53,8 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
private Type _type;
private RowProcessor _proc;
+ private int _numFields;
+ private int _numFeatures;
public FeaturePairsUDTF() {}
@@ -57,9 +62,14 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
protected Options getOptions() {
Options opts = new Options();
opts.addOption("kpa", false,
- "Generate feature pairs for Kernel-Expansion Passive Aggressive [default:true]");
+ "Generate feature pairs for Kernel-Expansion Passive Aggressive [default:false]");
opts.addOption("ffm", false,
"Generate feature pairs for Field-aware Factorization Machines [default:false]");
+ // feature hashing
+ opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]");
+ opts.addOption("feature_hashing", true,
+ "The number of bits for feature hashing in range [18,31]. [default: -1] No feature hashing for -1.");
+ opts.addOption("num_fields", true, "The number of fields [default:1024]");
return opts;
}
@@ -70,13 +80,30 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
String args = HiveUtils.getConstString(argOIs[1]);
cl = parseOptions(args);
- Preconditions.checkArgument(cl.getOptions().length == 1, UDFArgumentException.class,
- "Only one option can be specified: " + cl.getArgList());
+ Preconditions.checkArgument(cl.getOptions().length <= 3, UDFArgumentException.class,
+ "Too many options were specified: " + cl.getArgList());
if (cl.hasOption("kpa")) {
this._type = Type.kpa;
} else if (cl.hasOption("ffm")) {
this._type = Type.ffm;
+ this._numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), -1);
+ if (_numFeatures == -1) {
+ int featureBits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), -1);
+ if (featureBits != -1) {
+ if (featureBits < 18 || featureBits > 31) {
+ throw new UDFArgumentException(
+ "-feature_hashing MUST be in range [18,31]: " + featureBits);
+ }
+ this._numFeatures = 1 << featureBits;
+ }
+ }
+ this._numFields = Primitives.parseInt(cl.getOptionValue("num_fields"),
+ Feature.DEFAULT_NUM_FIELDS);
+ if (_numFields <= 1) {
+ throw new UDFArgumentException("-num_fields MUST be greater than 1: "
+ + _numFields);
+ }
} else {
throw new UDFArgumentException("Unsupported option: " + cl.getArgList().get(0));
}
@@ -113,8 +140,16 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
break;
}
case ffm: {
- throw new UDFArgumentException("-ffm is not supported yet");
- //break;
+ this._proc = new FFMProcessor(fvOI);
+ fieldNames.add("i"); // <ei, jField> index
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("j"); // <ej, iField> index
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("xi");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("xj");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ break;
}
default:
throw new UDFArgumentException("Illegal condition: " + _type);
@@ -144,26 +179,7 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
this.fvOI = fvOI;
}
- void process(@Nonnull Object arg) throws HiveException {
- final int size = fvOI.getListLength(arg);
- if (size == 0) {
- return;
- }
-
- final List<FeatureValue> features = new ArrayList<FeatureValue>(size);
- for (int i = 0; i < size; i++) {
- Object f = fvOI.getListElement(arg, i);
- if (f == null) {
- continue;
- }
- FeatureValue fv = FeatureValue.parse(f, true);
- features.add(fv);
- }
-
- process(features);
- }
-
- abstract void process(@Nonnull List<FeatureValue> features) throws HiveException;
+ abstract void process(@Nonnull Object arg) throws HiveException;
}
@@ -186,7 +202,22 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
}
@Override
- void process(@Nonnull List<FeatureValue> features) throws HiveException {
+ void process(@Nonnull Object arg) throws HiveException {
+ final int size = fvOI.getListLength(arg);
+ if (size == 0) {
+ return;
+ }
+
+ final List<FeatureValue> features = new ArrayList<FeatureValue>(size);
+ for (int i = 0; i < size; i++) {
+ Object f = fvOI.getListElement(arg, i);
+ if (f == null) {
+ continue;
+ }
+ FeatureValue fv = FeatureValue.parse(f, true);
+ features.add(fv);
+ }
+
forward[0] = f0;
f0.set(0);
forward[1] = null;
@@ -222,6 +253,78 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
}
}
+ final class FFMProcessor extends RowProcessor {
+
+ @Nonnull
+ private final IntWritable f0, f1;
+ @Nonnull
+ private final DoubleWritable f2, f3;
+ @Nonnull
+ private final Writable[] forward;
+
+ @Nullable
+ private transient Feature[] _features;
+
+ FFMProcessor(@Nonnull ListObjectInspector fvOI) {
+ super(fvOI);
+ this.f0 = new IntWritable();
+ this.f1 = new IntWritable();
+ this.f2 = new DoubleWritable();
+ this.f3 = new DoubleWritable();
+ this.forward = new Writable[] {f0, null, null, null};
+ this._features = null;
+ }
+
+ @Override
+ void process(@Nonnull Object arg) throws HiveException {
+ final int size = fvOI.getListLength(arg);
+ if (size == 0) {
+ return;
+ }
+
+ this._features = Feature.parseFFMFeatures(arg, fvOI, _features, _numFeatures,
+ _numFields);
+
+ // W0
+ f0.set(0);
+ forward[1] = null;
+ forward[2] = null;
+ forward[3] = null;
+ forward(forward);
+
+ forward[2] = f2;
+ final Feature[] features = _features;
+ for (int i = 0, len = features.length; i < len; i++) {
+ Feature ei = features[i];
+
+ // Wi
+ f0.set(Feature.toIntFeature(ei));
+ forward[1] = null;
+ f2.set(ei.getValue());
+ forward[3] = null;
+ forward(forward);
+
+ forward[1] = f1;
+ forward[3] = f3;
+ final int iField = ei.getField();
+ for (int j = i + 1; j < len; j++) {
+ Feature ej = features[j];
+ double xj = ej.getValue();
+ int jField = ej.getField();
+
+ int ifj = Feature.toIntFeature(ei, jField, _numFields);
+ int jfi = Feature.toIntFeature(ej, iField, _numFields);
+
+ // Vifj, Vjfi
+ f0.set(ifj);
+ f1.set(jfi);
+ // `f2` is consistently set to `xi`
+ f3.set(xj);
+ forward(forward);
+ }
+ }
+ }
+ }
@Override
public void close() throws HiveException {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java
index 5e9f797..cdba00b 100644
--- a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java
+++ b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java
@@ -19,8 +19,8 @@
package hivemall.ftvec.ranking;
import hivemall.utils.collections.lists.IntArrayList;
-import hivemall.utils.collections.maps.IntOpenHashMap;
-import hivemall.utils.collections.maps.IntOpenHashMap.IMapIterator;
+import hivemall.utils.collections.maps.IntOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator;
import java.util.BitSet;
@@ -30,13 +30,13 @@ import javax.annotation.Nullable;
public class PositiveOnlyFeedback {
@Nonnull
- protected final IntOpenHashMap<IntArrayList> rows;
+ protected final IntOpenHashTable<IntArrayList> rows;
protected int maxItemId;
protected int totalFeedbacks;
public PositiveOnlyFeedback(int maxItemId) {
- this.rows = new IntOpenHashMap<IntArrayList>(1024);
+ this.rows = new IntOpenHashTable<IntArrayList>(1024);
this.maxItemId = maxItemId;
this.totalFeedbacks = 0;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
new file mode 100644
index 0000000..53b998c
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.trans;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+@Description(name = "add_field_indicies", value = "_FUNC_(array<string> features) "
+        + "- Returns arrays of string that field indices (<field>:<feature>)* are augmented")
+@UDFType(deterministic = true, stateful = false)
+public final class AddFieldIndicesUDF extends GenericUDF {
+
+ private ListObjectInspector listOI;
+
+ @Override
+ public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
+ throws UDFArgumentException {
+ if (argOIs.length != 1) {
+ throw new UDFArgumentException("Expected a single argument: " + argOIs.length);
+ }
+
+ this.listOI = HiveUtils.asListOI(argOIs[0]);
+ if (!HiveUtils.isStringOI(listOI.getListElementObjectInspector())) {
+ throw new UDFArgumentException("Expected array<string> but got " + argOIs[0]);
+ }
+
+ return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+ }
+
+ @Override
+ public List<String> evaluate(@Nonnull DeferredObject[] args) throws HiveException {
+ Preconditions.checkArgument(args.length == 1);
+
+ final String[] features = HiveUtils.asStringArray(args[0], listOI);
+ if (features == null) {
+ return null;
+ }
+
+ final List<String> argumented = new ArrayList<>(features.length);
+ for (int i = 0; i < features.length; i++) {
+ final String f = features[i];
+ if (f == null) {
+ continue;
+ }
+ argumented.add((i + 1) + ":" + f);
+ }
+
+ return argumented;
+ }
+
+ @Override
+ public String getDisplayString(String[] args) {
+ return "add_field_indicies( " + Arrays.toString(args) + " )";
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java
index 98617bd..4722efd 100644
--- a/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java
+++ b/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java
@@ -18,6 +18,7 @@
*/
package hivemall.ftvec.trans;
+import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
@@ -26,26 +27,55 @@ import java.util.List;
import javax.annotation.Nonnull;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import org.apache.hadoop.io.Text;
-@Description(name = "categorical_features",
- value = "_FUNC_(array<string> featureNames, ...) - Returns a feature vector array<string>")
+@Description(
+ name = "categorical_features",
+ value = "_FUNC_(array<string> featureNames, feature1, feature2, .. [, const string options])"
+ + " - Returns a feature vector array<string>")
@UDFType(deterministic = true, stateful = false)
-public final class CategoricalFeaturesUDF extends GenericUDF {
+public final class CategoricalFeaturesUDF extends UDFWithOptions {
- private String[] featureNames;
- private PrimitiveObjectInspector[] inputOIs;
- private List<Text> result;
+ private String[] _featureNames;
+ private PrimitiveObjectInspector[] _inputOIs;
+ private List<String> _result;
+
+ private boolean _emitNull = false;
+ private boolean _forceValue = false;
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("no_elim", "no_elimination", false,
+ "Wheather to emit NULL and value [default: false]");
+ opts.addOption("emit_null", false, "Wheather to emit NULL [default: false]");
+ opts.addOption("force_value", false, "Wheather to force emit value [default: false]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+ CommandLine cl = parseOptions(optionValue);
+ if (cl.hasOption("no_elim")) {
+ this._emitNull = true;
+ this._forceValue = true;
+ } else {
+ this._emitNull = cl.hasOption("emit_null");
+ this._forceValue = cl.hasOption("force_value");
+ }
+ return cl;
+ }
@Override
public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs)
@@ -55,54 +85,91 @@ public final class CategoricalFeaturesUDF extends GenericUDF {
throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: "
+ numArgOIs);
}
- this.featureNames = HiveUtils.getConstStringArray(argOIs[0]);
- if (featureNames == null) {
+
+ this._featureNames = HiveUtils.getConstStringArray(argOIs[0]);
+ if (_featureNames == null) {
throw new UDFArgumentException("#featureNames should not be null");
}
- int numFeatureNames = featureNames.length;
+ int numFeatureNames = _featureNames.length;
if (numFeatureNames < 1) {
throw new UDFArgumentException("#featureNames must be greater than or equals to 1: "
+ numFeatureNames);
}
- int numFeatures = numArgOIs - 1;
+ for (String featureName : _featureNames) {
+ if (featureName == null) {
+ throw new UDFArgumentException("featureName should not be null: "
+ + Arrays.toString(_featureNames));
+ } else if (featureName.indexOf(':') != -1) {
+ throw new UDFArgumentException("featureName should not include colon: "
+ + featureName);
+ }
+ }
+
+ final int numFeatures;
+ final int lastArgIndex = numArgOIs - 1;
+ if (lastArgIndex > numFeatureNames) {
+ if (lastArgIndex == (numFeatureNames + 1)
+ && HiveUtils.isConstString(argOIs[lastArgIndex])) {
+ String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]);
+ processOptions(optionValue);
+ numFeatures = numArgOIs - 2;
+ } else {
+ throw new UDFArgumentException(
+ "Unexpected arguments for _FUNC_"
+ + "(const array<string> featureNames, feature1, feature2, .. [, const string options])");
+ }
+ } else {
+ numFeatures = lastArgIndex;
+ }
if (numFeatureNames != numFeatures) {
- throw new UDFArgumentException("#featureNames '" + numFeatureNames
- + "' != #arguments '" + numFeatures + "'");
+ throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames
+ + "' != #features '" + numFeatures + "'");
}
- this.inputOIs = new PrimitiveObjectInspector[numFeatures];
+ this._inputOIs = new PrimitiveObjectInspector[numFeatures];
for (int i = 0; i < numFeatures; i++) {
ObjectInspector oi = argOIs[i + 1];
- inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi);
+ _inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi);
}
- this.result = new ArrayList<Text>(numFeatures);
+ this._result = new ArrayList<String>(numFeatures);
- return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
}
@Override
- public List<Text> evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException {
- result.clear();
+ public List<String> evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException {
+ _result.clear();
- final int size = arguments.length - 1;
+ final int size = _featureNames.length;
for (int i = 0; i < size; i++) {
Object argument = arguments[i + 1].get();
if (argument == null) {
+ if (_emitNull) {
+ _result.add(null);
+ }
continue;
}
- PrimitiveObjectInspector oi = inputOIs[i];
+ PrimitiveObjectInspector oi = _inputOIs[i];
String s = PrimitiveObjectInspectorUtils.getString(argument, oi);
if (s.isEmpty()) {
+ if (_emitNull) {
+ _result.add(null);
+ }
continue;
}
- // categorical feature representation
- String featureName = featureNames[i];
- Text f = new Text(featureName + '#' + s);
- result.add(f);
+ // categorical feature representation
+ final String f;
+ if (_forceValue) {
+ f = _featureNames[i] + '#' + s + ":1";
+ } else {
+ f = _featureNames[i] + '#' + s;
+ }
+ _result.add(f);
+
}
- return result;
+ return _result;
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java
index c98ffda..eead738 100644
--- a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java
+++ b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java
@@ -23,6 +23,7 @@ import hivemall.fm.Feature;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.MurmurHash3;
import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.StringUtils;
import java.util.ArrayList;
import java.util.Arrays;
@@ -59,6 +60,7 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
private boolean _mhash = true;
private int _numFeatures = Feature.DEFAULT_NUM_FEATURES;
private int _numFields = Feature.DEFAULT_NUM_FIELDS;
+ private boolean _emitIndicies = false;
@Override
protected Options getOptions() {
@@ -66,9 +68,11 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
opts.addOption("no_hash", "disable_feature_hashing", false,
"Wheather to disable feature hashing [default: false]");
// feature hashing
+ opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]");
opts.addOption("hash", "feature_hashing", true,
"The number of bits for feature hashing in range [18,31] [default:21]");
opts.addOption("fields", "num_fields", true, "The number of fields [default:1024]");
+ opts.addOption("emit_indicies", false, "Emit indicies for fields [default: false]");
return opts;
}
@@ -77,19 +81,27 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
CommandLine cl = parseOptions(optionValue);
// feature hashing
- int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"),
- Feature.DEFAULT_FEATURE_BITS);
- if (hashbits < 18 || hashbits > 31) {
- throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: " + hashbits);
+ int numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), -1);
+ if (numFeatures == -1) {
+ int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"),
+ Feature.DEFAULT_FEATURE_BITS);
+ if (hashbits < 18 || hashbits > 31) {
+ throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: "
+ + hashbits);
+ }
+ numFeatures = 1 << hashbits;
}
- int numFeatures = 1 << hashbits;
+ this._numFeatures = numFeatures;
+
int numFields = Primitives.parseInt(cl.getOptionValue("num_fields"),
Feature.DEFAULT_NUM_FIELDS);
if (numFields <= 1) {
throw new UDFArgumentException("-num_fields MUST be greater than 1: " + numFields);
}
- this._numFeatures = numFeatures;
this._numFields = numFields;
+
+ this._emitIndicies = cl.hasOption("emit_indicies");
+
return cl;
}
@@ -111,7 +123,10 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
+ numFeatureNames);
}
for (String featureName : _featureNames) {
- if (featureName.indexOf(':') != -1) {
+ if (featureName == null) {
+ throw new UDFArgumentException("featureName should not be null: "
+ + Arrays.toString(_featureNames));
+ } else if (featureName.indexOf(':') != -1) {
throw new UDFArgumentException("featureName should not include colon: "
+ featureName);
}
@@ -174,18 +189,20 @@ public final class FFMFeaturesUDF extends UDFWithOptions {
// categorical feature representation
final String fv;
if (_mhash) {
- int field = MurmurHash3.murmurhash3(_featureNames[i], _numFields);
+ int field = _emitIndicies ? i : MurmurHash3.murmurhash3(_featureNames[i],
+ _numFields);
// +NUM_FIELD to avoid conflict to quantitative features
int index = MurmurHash3.murmurhash3(feature, _numFeatures) + _numFields;
fv = builder.append(field).append(':').append(index).append(":1").toString();
- builder.setLength(0);
+ StringUtils.clear(builder);
} else {
- fv = builder.append(featureName)
- .append(':')
- .append(feature)
- .append(":1")
- .toString();
- builder.setLength(0);
+ if (_emitIndicies) {
+ builder.append(i);
+ } else {
+ builder.append(featureName);
+ }
+ fv = builder.append(':').append(feature).append(":1").toString();
+ StringUtils.clear(builder);
}
_result.add(new Text(fv));
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java b/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java
index 2886996..846be97 100644
--- a/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java
+++ b/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java
@@ -23,6 +23,7 @@ import hivemall.utils.lang.Identifier;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
@@ -39,7 +40,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
@Description(
name = "quantified_features",
- value = "_FUNC_(boolean output, col1, col2, ...) - Returns an identified features in a dence array<double>")
+ value = "_FUNC_(boolean output, col1, col2, ...) - Returns an identified features in a dense array<double>")
public final class QuantifiedFeaturesUDTF extends GenericUDTF {
private BooleanObjectInspector boolOI;
@@ -76,8 +77,8 @@ public final class QuantifiedFeaturesUDTF extends GenericUDTF {
}
}
- ArrayList<String> fieldNames = new ArrayList<String>(outputSize);
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(outputSize);
+ List<String> fieldNames = new ArrayList<String>(outputSize);
+ List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(outputSize);
fieldNames.add("features");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java
index 43f837f..38e35e2 100644
--- a/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java
+++ b/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java
@@ -18,6 +18,7 @@
*/
package hivemall.ftvec.trans;
+import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
@@ -26,11 +27,13 @@ import java.util.List;
import javax.annotation.Nonnull;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
@@ -39,14 +42,32 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.Text;
-@Description(name = "quantitative_features",
- value = "_FUNC_(array<string> featureNames, ...) - Returns a feature vector array<string>")
+@Description(
+ name = "quantitative_features",
+ value = "_FUNC_(array<string> featureNames, feature1, feature2, .. [, const string options])"
+ + " - Returns a feature vector array<string>")
@UDFType(deterministic = true, stateful = false)
-public final class QuantitativeFeaturesUDF extends GenericUDF {
+public final class QuantitativeFeaturesUDF extends UDFWithOptions {
- private String[] featureNames;
- private PrimitiveObjectInspector[] inputOIs;
- private List<Text> result;
+ private String[] _featureNames;
+ private PrimitiveObjectInspector[] _inputOIs;
+ private List<Text> _result;
+
+ private boolean _emitNull = false;
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("emit_null", false, "Wheather to emit NULL [default: false]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+ CommandLine cl = parseOptions(optionValue);
+ this._emitNull = cl.hasOption("emit_null");
+ return cl;
+ }
@Override
public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs)
@@ -56,58 +77,92 @@ public final class QuantitativeFeaturesUDF extends GenericUDF {
throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: "
+ numArgOIs);
}
- this.featureNames = HiveUtils.getConstStringArray(argOIs[0]);
- if (featureNames == null) {
+
+ this._featureNames = HiveUtils.getConstStringArray(argOIs[0]);
+ if (_featureNames == null) {
throw new UDFArgumentException("#featureNames should not be null");
}
- int numFeatureNames = featureNames.length;
+ int numFeatureNames = _featureNames.length;
if (numFeatureNames < 1) {
throw new UDFArgumentException("#featureNames must be greater than or equals to 1: "
+ numFeatureNames);
}
- int numFeatures = numArgOIs - 1;
+ for (String featureName : _featureNames) {
+ if (featureName == null) {
+ throw new UDFArgumentException("featureName should not be null: "
+ + Arrays.toString(_featureNames));
+ } else if (featureName.indexOf(':') != -1) {
+ throw new UDFArgumentException("featureName should not include colon: "
+ + featureName);
+ }
+ }
+
+ final int numFeatures;
+ final int lastArgIndex = numArgOIs - 1;
+ if (lastArgIndex > numFeatureNames) {
+ if (lastArgIndex == (numFeatureNames + 1)
+ && HiveUtils.isConstString(argOIs[lastArgIndex])) {
+ String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]);
+ processOptions(optionValue);
+ numFeatures = numArgOIs - 2;
+ } else {
+ throw new UDFArgumentException(
+ "Unexpected arguments for _FUNC_"
+ + "(const array<string> featureNames, feature1, feature2, .. [, const string options])");
+ }
+ } else {
+ numFeatures = lastArgIndex;
+ }
if (numFeatureNames != numFeatures) {
- throw new UDFArgumentException("#featureNames '" + numFeatureNames
- + "' != #arguments '" + numFeatures + "'");
+ throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames
+ + "' != #features '" + numFeatures + "'");
}
- this.inputOIs = new PrimitiveObjectInspector[numFeatures];
+ this._inputOIs = new PrimitiveObjectInspector[numFeatures];
for (int i = 0; i < numFeatures; i++) {
ObjectInspector oi = argOIs[i + 1];
- inputOIs[i] = HiveUtils.asDoubleCompatibleOI(oi);
+ _inputOIs[i] = HiveUtils.asDoubleCompatibleOI(oi);
}
- this.result = new ArrayList<Text>(numFeatures);
+ this._result = new ArrayList<Text>(numFeatures);
return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
}
@Override
public List<Text> evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException {
- result.clear();
+ _result.clear();
- final int size = arguments.length - 1;
+ final int size = _featureNames.length;
for (int i = 0; i < size; i++) {
Object argument = arguments[i + 1].get();
if (argument == null) {
+ if (_emitNull) {
+ _result.add(null);
+ }
continue;
}
- PrimitiveObjectInspector oi = inputOIs[i];
+ PrimitiveObjectInspector oi = _inputOIs[i];
if (oi.getPrimitiveCategory() == PrimitiveCategory.STRING) {
String s = argument.toString();
if (s.isEmpty()) {
+ if (_emitNull) {
+ _result.add(null);
+ }
continue;
}
}
final double v = PrimitiveObjectInspectorUtils.getDouble(argument, oi);
if (v != 0.d) {
- String featureName = featureNames[i];
- Text f = new Text(featureName + ':' + v);
- result.add(f);
+ Text f = new Text(_featureNames[i] + ':' + v);
+ _result.add(f);
+ } else if (_emitNull) {
+ Text f = new Text(_featureNames[i] + ":0");
+ _result.add(f);
}
}
- return result;
+ return _result;
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java
index 48bf126..f2ecbb6 100644
--- a/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java
+++ b/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java
@@ -18,6 +18,7 @@
*/
package hivemall.ftvec.trans;
+import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.StringUtils;
@@ -27,11 +28,13 @@ import java.util.List;
import javax.annotation.Nonnull;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
@@ -40,14 +43,32 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.Text;
-@Description(name = "vectorize_features",
- value = "_FUNC_(array<string> featureNames, ...) - Returns a feature vector array<string>")
+@Description(
+ name = "vectorize_features",
+ value = "_FUNC_(array<string> featureNames, feature1, feature2, .. [, const string options])"
+ + " - Returns a feature vector array<string>")
@UDFType(deterministic = true, stateful = false)
-public final class VectorizeFeaturesUDF extends GenericUDF {
+public final class VectorizeFeaturesUDF extends UDFWithOptions {
- private String[] featureNames;
- private PrimitiveObjectInspector[] inputOIs;
- private List<Text> result;
+ private String[] _featureNames;
+ private PrimitiveObjectInspector[] _inputOIs;
+ private List<Text> _result;
+
+ private boolean _emitNull = false;
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("emit_null", false, "Wheather to emit NULL [default: false]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+ CommandLine cl = parseOptions(optionValue);
+ this._emitNull = cl.hasOption("emit_null");
+ return cl;
+ }
@Override
public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs)
@@ -57,63 +78,96 @@ public final class VectorizeFeaturesUDF extends GenericUDF {
throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: "
+ numArgOIs);
}
- this.featureNames = HiveUtils.getConstStringArray(argOIs[0]);
- if (featureNames == null) {
+
+ this._featureNames = HiveUtils.getConstStringArray(argOIs[0]);
+ if (_featureNames == null) {
throw new UDFArgumentException("#featureNames should not be null");
}
- int numFeatureNames = featureNames.length;
+ int numFeatureNames = _featureNames.length;
if (numFeatureNames < 1) {
throw new UDFArgumentException("#featureNames must be greater than or equals to 1: "
+ numFeatureNames);
}
- int numFeatures = numArgOIs - 1;
+ for (String featureName : _featureNames) {
+ if (featureName == null) {
+ throw new UDFArgumentException("featureName should not be null: "
+ + Arrays.toString(_featureNames));
+ } else if (featureName.indexOf(':') != -1) {
+ throw new UDFArgumentException("featureName should not include colon: "
+ + featureName);
+ }
+ }
+
+ final int numFeatures;
+ final int lastArgIndex = numArgOIs - 1;
+ if (lastArgIndex > numFeatureNames) {
+ if (lastArgIndex == (numFeatureNames + 1)
+ && HiveUtils.isConstString(argOIs[lastArgIndex])) {
+ String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]);
+ processOptions(optionValue);
+ numFeatures = numArgOIs - 2;
+ } else {
+ throw new UDFArgumentException(
+ "Unexpected arguments for _FUNC_"
+ + "(const array<string> featureNames, feature1, feature2, .. [, const string options])");
+ }
+ } else {
+ numFeatures = lastArgIndex;
+ }
if (numFeatureNames != numFeatures) {
- throw new UDFArgumentException("#featureNames '" + numFeatureNames
- + "' != #arguments '" + numFeatures + "'");
+ throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames
+ + "' != #features '" + numFeatures + "'");
}
- this.inputOIs = new PrimitiveObjectInspector[numFeatures];
+ this._inputOIs = new PrimitiveObjectInspector[numFeatures];
for (int i = 0; i < numFeatures; i++) {
ObjectInspector oi = argOIs[i + 1];
- inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi);
+ _inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi);
}
- this.result = new ArrayList<Text>(numFeatures);
+ this._result = new ArrayList<Text>(numFeatures);
return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
}
@Override
public List<Text> evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException {
- result.clear();
+ _result.clear();
- final int size = arguments.length - 1;
+ final int size = _featureNames.length;
for (int i = 0; i < size; i++) {
Object argument = arguments[i + 1].get();
if (argument == null) {
+ if (_emitNull) {
+ _result.add(null);
+ }
continue;
}
- PrimitiveObjectInspector oi = inputOIs[i];
+ PrimitiveObjectInspector oi = _inputOIs[i];
if (oi.getPrimitiveCategory() == PrimitiveCategory.STRING) {
String s = PrimitiveObjectInspectorUtils.getString(argument, oi);
if (s.isEmpty()) {
+ if (_emitNull) {
+ _result.add(null);
+ }
continue;
}
- if (StringUtils.isNumber(s) == false) {// categorical feature representation
- String featureName = featureNames[i];
- Text f = new Text(featureName + '#' + s);
- result.add(f);
+ if (StringUtils.isNumber(s) == false) {// categorical feature representation
+ Text f = new Text(_featureNames[i] + '#' + s);
+ _result.add(f);
continue;
}
}
- float v = PrimitiveObjectInspectorUtils.getFloat(argument, oi);
+ final float v = PrimitiveObjectInspectorUtils.getFloat(argument, oi);
if (v != 0.f) {
- String featureName = featureNames[i];
- Text f = new Text(featureName + ':' + v);
- result.add(f);
+ Text f = new Text(_featureNames[i] + ':' + v);
+ _result.add(f);
+ } else if (_emitNull) {
+ Text f = new Text(_featureNames[i] + ":0");
+ _result.add(f);
}
}
- return result;
+ return _result;
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/mf/FactorizedModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/FactorizedModel.java b/core/src/main/java/hivemall/mf/FactorizedModel.java
index a4bea00..1b7140f 100644
--- a/core/src/main/java/hivemall/mf/FactorizedModel.java
+++ b/core/src/main/java/hivemall/mf/FactorizedModel.java
@@ -18,7 +18,7 @@
*/
package hivemall.mf;
-import hivemall.utils.collections.maps.IntOpenHashMap;
+import hivemall.utils.collections.maps.IntOpenHashTable;
import hivemall.utils.math.MathUtils;
import java.util.Random;
@@ -42,10 +42,10 @@ public final class FactorizedModel {
private int minIndex, maxIndex;
@Nonnull
private Rating meanRating;
- private IntOpenHashMap<Rating[]> users;
- private IntOpenHashMap<Rating[]> items;
- private IntOpenHashMap<Rating> userBias;
- private IntOpenHashMap<Rating> itemBias;
+ private IntOpenHashTable<Rating[]> users;
+ private IntOpenHashTable<Rating[]> items;
+ private IntOpenHashTable<Rating> userBias;
+ private IntOpenHashTable<Rating> itemBias;
private final Random[] randU, randI;
@@ -67,10 +67,10 @@ public final class FactorizedModel {
this.minIndex = 0;
this.maxIndex = 0;
this.meanRating = ratingInitializer.newRating(meanRating);
- this.users = new IntOpenHashMap<Rating[]>(expectedSize);
- this.items = new IntOpenHashMap<Rating[]>(expectedSize);
- this.userBias = new IntOpenHashMap<Rating>(expectedSize);
- this.itemBias = new IntOpenHashMap<Rating>(expectedSize);
+ this.users = new IntOpenHashTable<Rating[]>(expectedSize);
+ this.items = new IntOpenHashTable<Rating[]>(expectedSize);
+ this.userBias = new IntOpenHashTable<Rating>(expectedSize);
+ this.itemBias = new IntOpenHashTable<Rating>(expectedSize);
this.randU = newRandoms(factor, 31L);
this.randI = newRandoms(factor, 41L);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/model/AbstractPredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/AbstractPredictionModel.java b/core/src/main/java/hivemall/model/AbstractPredictionModel.java
index 95935d3..cd298a7 100644
--- a/core/src/main/java/hivemall/model/AbstractPredictionModel.java
+++ b/core/src/main/java/hivemall/model/AbstractPredictionModel.java
@@ -22,7 +22,7 @@ import hivemall.annotations.InternalAPI;
import hivemall.mix.MixedWeight;
import hivemall.mix.MixedWeight.WeightWithCovar;
import hivemall.mix.MixedWeight.WeightWithDelta;
-import hivemall.utils.collections.maps.IntOpenHashMap;
+import hivemall.utils.collections.maps.IntOpenHashTable;
import hivemall.utils.collections.maps.OpenHashMap;
import javax.annotation.Nonnull;
@@ -37,7 +37,7 @@ public abstract class AbstractPredictionModel implements PredictionModel {
private long numMixed;
private boolean cancelMixRequest;
- private IntOpenHashMap<MixedWeight> mixedRequests_i;
+ private IntOpenHashTable<MixedWeight> mixedRequests_i;
private OpenHashMap<Object, MixedWeight> mixedRequests_o;
public AbstractPredictionModel() {
@@ -58,7 +58,7 @@ public abstract class AbstractPredictionModel implements PredictionModel {
this.cancelMixRequest = cancelMixRequest;
if (cancelMixRequest) {
if (isDenseModel()) {
- this.mixedRequests_i = new IntOpenHashMap<MixedWeight>(327680);
+ this.mixedRequests_i = new IntOpenHashTable<MixedWeight>(327680);
} else {
this.mixedRequests_o = new OpenHashMap<Object, MixedWeight>(327680);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/model/NewSparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java
index 8326d22..5c0a6c7 100644
--- a/core/src/main/java/hivemall/model/NewSparseModel.java
+++ b/core/src/main/java/hivemall/model/NewSparseModel.java
@@ -194,7 +194,7 @@ public final class NewSparseModel extends AbstractPredictionModel {
@SuppressWarnings("unchecked")
@Override
public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
- return (IMapIterator<K, V>) weights.entries();
+ return (IMapIterator<K, V>) weights.entries(true);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/model/SparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java
index cb8ab9f..65e751d 100644
--- a/core/src/main/java/hivemall/model/SparseModel.java
+++ b/core/src/main/java/hivemall/model/SparseModel.java
@@ -183,7 +183,7 @@ public final class SparseModel extends AbstractPredictionModel {
@SuppressWarnings("unchecked")
@Override
public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
- return (IMapIterator<K, V>) weights.entries();
+ return (IMapIterator<K, V>) weights.entries(true);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java b/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java
index a2e3e55..6dbb7d5 100644
--- a/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java
+++ b/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java
@@ -18,6 +18,10 @@
*/
package hivemall.tools.array;
+import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_ARRAY_META;
+import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_REF;
+import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES1;
+import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES2;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
@@ -34,6 +38,7 @@ import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -220,7 +225,8 @@ public final class ArrayAvgGenericUDAF extends AbstractGenericUDAFResolver {
}
}
- public static class ArrayAvgAggregationBuffer extends AbstractAggregationBuffer {
+ @AggregationType(estimable = true)
+ public static final class ArrayAvgAggregationBuffer extends AbstractAggregationBuffer {
int _size;
// note that primitive array cannot be serialized by JDK serializer
@@ -289,6 +295,15 @@ public final class ArrayAvgGenericUDAF extends AbstractGenericUDAFResolver {
}
}
+ @Override
+ public int estimate() {
+ if (_size == -1) {
+ return JAVA64_REF;
+ } else {
+ return PRIMITIVES1 + 2 * (JAVA64_ARRAY_META + PRIMITIVES2 * _size);
+ }
+ }
+
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java b/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java
index e0a3c9e..10051a9 100644
--- a/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java
+++ b/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java
@@ -20,7 +20,6 @@ package hivemall.utils.buffer;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Preconditions;
-import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.lang.UnsafeUtils;
@@ -97,8 +96,8 @@ public final class HeapBuffer {
Preconditions.checkArgument(bytes <= _chunkBytes,
"Cannot allocate memory greater than %s bytes: %s", _chunkBytes, bytes);
- int i = Primitives.castToInt(_position / _chunkBytes);
- final int j = Primitives.castToInt(_position % _chunkBytes);
+ int i = NumberUtils.castToInt(_position / _chunkBytes);
+ final int j = NumberUtils.castToInt(_position % _chunkBytes);
if (bytes > (_chunkBytes - j)) {
// cannot allocate the object in the current chunk
// so, skip the current chunk
@@ -144,7 +143,7 @@ public final class HeapBuffer {
public byte getByte(final long ptr) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
return _UNSAFE.getByte(chunk, j);
@@ -152,7 +151,7 @@ public final class HeapBuffer {
public void putByte(final long ptr, final byte value) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
_UNSAFE.putByte(chunk, j, value);
@@ -160,7 +159,7 @@ public final class HeapBuffer {
public int getInt(final long ptr) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
return _UNSAFE.getInt(chunk, j);
@@ -168,7 +167,7 @@ public final class HeapBuffer {
public void putInt(final long ptr, final int value) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
_UNSAFE.putInt(chunk, j, value);
@@ -176,7 +175,7 @@ public final class HeapBuffer {
public short getShort(final long ptr) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
return _UNSAFE.getShort(chunk, j);
@@ -184,7 +183,7 @@ public final class HeapBuffer {
public void putShort(final long ptr, final short value) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
_UNSAFE.putShort(chunk, j, value);
@@ -192,7 +191,7 @@ public final class HeapBuffer {
public char getChar(final long ptr) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
return _UNSAFE.getChar(chunk, j);
@@ -200,14 +199,14 @@ public final class HeapBuffer {
public void putChar(final long ptr, final char value) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
_UNSAFE.putChar(chunk, j, value);
}
public long getLong(final long ptr) {
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
return _UNSAFE.getLong(chunk, j);
@@ -215,7 +214,7 @@ public final class HeapBuffer {
public void putLong(final long ptr, final long value) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
_UNSAFE.putLong(chunk, j, value);
@@ -223,7 +222,7 @@ public final class HeapBuffer {
public float getFloat(final long ptr) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
return _UNSAFE.getFloat(chunk, j);
@@ -231,7 +230,7 @@ public final class HeapBuffer {
public void putFloat(final long ptr, final float value) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
_UNSAFE.putFloat(chunk, j, value);
@@ -239,7 +238,7 @@ public final class HeapBuffer {
public double getDouble(final long ptr) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
return _UNSAFE.getDouble(chunk, j);
@@ -247,7 +246,7 @@ public final class HeapBuffer {
public void putDouble(final long ptr, final double value) {
validatePointer(ptr);
- int i = Primitives.castToInt(ptr / _chunkBytes);
+ int i = NumberUtils.castToInt(ptr / _chunkBytes);
int[] chunk = _chunks[i];
long j = offset(ptr);
_UNSAFE.putDouble(chunk, j, value);
@@ -260,7 +259,7 @@ public final class HeapBuffer {
throw new IllegalArgumentException("Cannot put empty array at " + ptr);
}
- int chunkIdx = Primitives.castToInt(ptr / _chunkBytes);
+ int chunkIdx = NumberUtils.castToInt(ptr / _chunkBytes);
final int[] chunk = _chunks[chunkIdx];
final long base = offset(ptr);
for (int i = 0; i < len; i++) {
@@ -277,7 +276,7 @@ public final class HeapBuffer {
throw new IllegalArgumentException("Cannot put empty array at " + ptr);
}
- int chunkIdx = Primitives.castToInt(ptr / _chunkBytes);
+ int chunkIdx = NumberUtils.castToInt(ptr / _chunkBytes);
final int[] chunk = _chunks[chunkIdx];
final long base = offset(ptr);
for (int i = 0; i < len; i++) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
index f847b15..e9b5c8a 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
@@ -27,8 +27,13 @@ import java.io.ObjectOutput;
import java.util.Arrays;
/**
- * An open-addressing hash table with double hashing
- *
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ * </pre>
+ *
* @see http://en.wikipedia.org/wiki/Double_hashing
*/
public class Int2FloatOpenHashTable implements Externalizable {
@@ -37,7 +42,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
protected static final byte FULL = 1;
protected static final byte REMOVED = 2;
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_LOAD_FACTOR = 0.75f;
private static final float DEFAULT_GROW_FACTOR = 2.0f;
protected final transient float _loadFactor;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
index 5e9e812..8e87fce 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
@@ -27,7 +27,12 @@ import java.io.ObjectOutput;
import java.util.Arrays;
/**
- * An open-addressing hash table with double hashing
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ * </pre>
*
* @see http://en.wikipedia.org/wiki/Double_hashing
*/
@@ -37,7 +42,7 @@ public final class Int2IntOpenHashTable implements Externalizable {
protected static final byte FULL = 1;
protected static final byte REMOVED = 2;
- private static final float DEFAULT_LOAD_FACTOR = 0.7f;
+ private static final float DEFAULT_LOAD_FACTOR = 0.75f;
private static final float DEFAULT_GROW_FACTOR = 2.0f;
protected final transient float _loadFactor;