You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2018/08/23 11:05:12 UTC
[2/2] incubator-hivemall git commit: [HIVEMALL-201] Evaluate,
fix and document FFM
[HIVEMALL-201] Evaluate, fix and document FFM
## What changes were proposed in this pull request?
Applied some refactoring to #149
This PR closes #149
## What type of PR is it?
Hot Fix, Refactoring
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-201
## How was this patch tested?
unit tests, manual tests
## How to use this feature?
Will be published at: http://hivemall.incubator.apache.org/userguide/binaryclass/criteo_ffm.html
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Takuya Kitazawa <k....@gmail.com>
Author: Makoto Yui <my...@apache.org>
Closes #155 from myui/HIVEMALL-201-2.
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/61711fbc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/61711fbc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/61711fbc
Branch: refs/heads/master
Commit: 61711fbc2be109a200f1773958b5a8c519f5a066
Parents: b88e9f5
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Thu Aug 23 20:05:04 2018 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Thu Aug 23 20:05:04 2018 +0900
----------------------------------------------------------------------
.../java/hivemall/fm/FMHyperParameters.java | 61 +++-
.../hivemall/fm/FactorizationMachineModel.java | 63 +++-
.../hivemall/fm/FactorizationMachineUDTF.java | 151 +++++---
.../fm/FieldAwareFactorizationMachineModel.java | 10 +-
.../fm/FieldAwareFactorizationMachineUDTF.java | 32 +-
.../ftvec/pairing/FeaturePairsUDTF.java | 10 +-
.../ftvec/scaling/L1NormalizationUDF.java | 5 +
.../ftvec/scaling/L2NormalizationUDF.java | 5 +
.../hivemall/mf/BPRMatrixFactorizationUDTF.java | 20 +-
.../mf/OnlineMatrixFactorizationUDTF.java | 20 +-
.../hivemall/tools/mapred/RowNumberUDF.java | 3 +-
.../java/hivemall/utils/lang/Primitives.java | 3 +
.../main/java/hivemall/utils/lang/SizeOf.java | 1 +
.../fm/FactorizationMachineUDTFTest.java | 135 +++++++
.../FieldAwareFactorizationMachineUDTFTest.java | 185 ++++++----
.../ftvec/scaling/L1NormalizationUDFTest.java | 6 +
.../ftvec/scaling/L2NormalizationUDFTest.java | 6 +
docs/gitbook/SUMMARY.md | 3 +
docs/gitbook/binaryclass/criteo.md | 20 ++
docs/gitbook/binaryclass/criteo_dataset.md | 97 +++++
docs/gitbook/binaryclass/criteo_ffm.md | 356 +++++++++++++++++++
21 files changed, 1011 insertions(+), 181 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FMHyperParameters.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java
index 0992325..edee14f 100644
--- a/core/src/main/java/hivemall/fm/FMHyperParameters.java
+++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java
@@ -28,7 +28,8 @@ import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
class FMHyperParameters {
- private static final float DEFAULT_ETA0 = 0.05f;
+ protected static final float DEFAULT_ETA0 = 0.1f;
+ protected static final float DEFAULT_LAMBDA = 0.0001f;
// -------------------------------------
// Model parameters
@@ -37,10 +38,10 @@ class FMHyperParameters {
int factors = 5;
// regularization
- float lambda = 0.01f;
- float lambdaW0 = 0.01f;
- float lambdaW = 0.01f;
- float lambdaV = 0.01f;
+ float lambda = DEFAULT_LAMBDA;
+ float lambdaW0;
+ float lambdaW;
+ float lambdaV;
// V initialization
double sigma = 0.1d;
@@ -62,10 +63,12 @@ class FMHyperParameters {
boolean l2norm; // enabled by default for FFM, disabled by default for FM.
- int iters = 1;
+ int iters = 10;
boolean conversionCheck = true;
double convergenceRate = 0.005d;
+ boolean earlyStopping = false;
+
// adaptive regularization
boolean adaptiveRegularization = false;
float validationRatio = 0.05f;
@@ -89,10 +92,14 @@ class FMHyperParameters {
void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException {
this.classification = cl.hasOption("classification");
- this.factors = Primitives.parseInt(cl.getOptionValue("factors"), factors);
+ if (cl.hasOption("factor")) {
+ this.factors = Primitives.parseInt(cl.getOptionValue("factor"), factors);
+ } else {
+ this.factors = Primitives.parseInt(cl.getOptionValue("factors"), factors);
+ }
this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), lambda);
this.lambdaW0 = Primitives.parseFloat(cl.getOptionValue("lambda_w0"), lambda);
- this.lambdaW = Primitives.parseFloat(cl.getOptionValue("lambda_w"), lambda);
+ this.lambdaW = Primitives.parseFloat(cl.getOptionValue("lambda_wi"), lambda);
this.lambdaV = Primitives.parseFloat(cl.getOptionValue("lambda_v"), lambda);
this.sigma = Primitives.parseDouble(cl.getOptionValue("sigma"), sigma);
this.seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
@@ -105,10 +112,15 @@ class FMHyperParameters {
this.eta = EtaEstimator.get(cl, DEFAULT_ETA0);
this.numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), numFeatures);
this.l2norm = cl.hasOption("enable_norm");
- this.iters = Primitives.parseInt(cl.getOptionValue("iterations"), iters);
+ if (cl.hasOption("iter")) {
+ this.iters = Primitives.parseInt(cl.getOptionValue("iter"), iters);
+ } else {
+ this.iters = Primitives.parseInt(cl.getOptionValue("iterations"), iters);
+ }
this.conversionCheck = !cl.hasOption("disable_cvtest");
this.convergenceRate =
Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
+ this.earlyStopping = cl.hasOption("early_stopping");
this.adaptiveRegularization = cl.hasOption("adaptive_regularization");
this.validationRatio =
Primitives.parseFloat(cl.getOptionValue("validation_ratio"), validationRatio);
@@ -122,14 +134,13 @@ class FMHyperParameters {
}
@Nonnull
- private static VInitScheme instantiateVInit(@Nonnull CommandLine cl, int factor, long seed,
+ private VInitScheme instantiateVInit(@Nonnull CommandLine cl, int factor, long seed,
final boolean classification) {
String vInitOpt = cl.getOptionValue("init_v");
float maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 0.5f);
double initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d);
- VInitScheme defaultInit = classification ? VInitScheme.gaussian : VInitScheme.random;
- VInitScheme vInit = VInitScheme.resolve(vInitOpt, defaultInit);
+ VInitScheme vInit = VInitScheme.resolve(vInitOpt, getDefaultVinitScheme());
vInit.setMaxInitValue(maxInitValue);
initStdDev = Math.max(initStdDev, 1.0d / factor);
vInit.setInitStdDev(initStdDev);
@@ -137,11 +148,16 @@ class FMHyperParameters {
return vInit;
}
+ @Nonnull
+ protected VInitScheme getDefaultVinitScheme() {
+ return classification ? VInitScheme.gaussian : VInitScheme.adjustedRandom;
+ }
+
public static final class FFMHyperParameters extends FMHyperParameters {
// FFM hyper parameters
boolean globalBias = false;
- boolean linearCoeff = true;
+ boolean linearCoeff = false;
// feature hashing
int numFields = Feature.DEFAULT_NUM_FIELDS;
@@ -152,15 +168,20 @@ class FMHyperParameters {
// FTRL
boolean useFTRL = false;
- float alphaFTRL = 0.2f; // Learning Rate
- float betaFTRL = 1.f; // Smoothing parameter for AdaGrad
- float lambda1 = 0.001f; // L1 Regularization
+ float alphaFTRL = 0.5f; // Learning Rate
+ float betaFTRL = 1.0f; // Smoothing parameter for AdaGrad
+ float lambda1 = 0.0002f; // L1 Regularization
float lambda2 = 0.0001f; // L2 Regularization
FFMHyperParameters() {
super();
}
+ @Nonnull
+ protected VInitScheme getDefaultVinitScheme() {
+ return VInitScheme.random;
+ }
+
@Override
void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException {
super.processOptions(cl);
@@ -170,7 +191,13 @@ class FMHyperParameters {
}
this.globalBias = cl.hasOption("global_bias");
- this.linearCoeff = !cl.hasOption("no_coeff");
+ this.linearCoeff = cl.hasOption("linear_term");
+
+ if (cl.hasOption("enable_norm") && cl.hasOption("disable_norm")) {
+ throw new UDFArgumentException(
+ "-enable_norm and -disable_norm MUST NOT be used simultaneously");
+ }
+ this.l2norm = !cl.hasOption("disable_norm");
// feature hashing
if (numFeatures == -1) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
index bb97bef..c654f32 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
@@ -271,7 +271,7 @@ public abstract class FactorizationMachineModel {
* sum_f_dash := \sum_{j} x_j * v'_lj, this is independent of the groups
* sum_f(g) := \sum_{j \in group(g)} x_j * v_jf
* sum_f_dash_f(g) := \sum_{j \in group(g)} x^2_j * v_jf * v'_jf
- * := \sum_{j \in group(g)} x_j * v'_jf * x_j * v_jf
+ * := \sum_{j \in group(g)} x_j * v'_jf * x_j * v_jf
* v_jf' := v_jf - alpha ( grad_v_jf + 2 * lambda_v_f * v_jf)
* </pre>
*/
@@ -336,7 +336,7 @@ public abstract class FactorizationMachineModel {
public void check(@Nonnull Feature[] x) throws HiveException {}
public enum VInitScheme {
- random /* default */, gaussian;
+ adjustedRandom /* default */, libffmRandom, random, gaussian;
@Nonnegative
float maxInitValue;
@@ -346,7 +346,7 @@ public abstract class FactorizationMachineModel {
@Nonnull
public static VInitScheme resolve(@Nullable String opt) {
- return resolve(opt, random);
+ return resolve(opt, adjustedRandom);
}
@Nonnull
@@ -354,10 +354,16 @@ public abstract class FactorizationMachineModel {
@Nonnull VInitScheme defaultScheme) {
if (opt == null) {
return defaultScheme;
- } else if ("gaussian".equalsIgnoreCase(opt)) {
- return gaussian;
+ } else if ("adjusted_random".equalsIgnoreCase(opt)
+ || "adjustedRandom".equalsIgnoreCase(opt)) {
+ return adjustedRandom;
+ } else if ("libffm_random".equalsIgnoreCase(opt) || "libffmRandom".equalsIgnoreCase(opt)
+ || "libffm".equalsIgnoreCase(opt)) {
+ return VInitScheme.libffmRandom;
} else if ("random".equalsIgnoreCase(opt)) {
return random;
+ } else if ("gaussian".equalsIgnoreCase(opt)) {
+ return gaussian;
}
return defaultScheme;
}
@@ -371,7 +377,7 @@ public abstract class FactorizationMachineModel {
}
public void initRandom(int factor, long seed) {
- int size = (this == random) ? 1 : factor;
+ final int size = (this != gaussian) ? 1 : factor;
this.rand = new Random[size];
for (int i = 0; i < size; i++) {
rand[i] = new Random(seed + i);
@@ -383,8 +389,14 @@ public abstract class FactorizationMachineModel {
protected final float[] initV() {
final float[] ret = new float[_factor];
switch (_initScheme) {
+ case adjustedRandom:
+ adjustedRandomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue);
+ break;
+ case libffmRandom:
+ libffmRandomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue);
+ break;
case random:
- uniformFill(ret, _initScheme.rand[0], _initScheme.maxInitValue);
+ randomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue);
break;
case gaussian:
gaussianFill(ret, _initScheme.rand, _initScheme.initStdDev);
@@ -396,19 +408,42 @@ public abstract class FactorizationMachineModel {
return ret;
}
- protected static final void uniformFill(final float[] a, final Random rand,
- final float maxInitValue) {
- final int len = a.length;
- final float basev = maxInitValue / len;
- for (int i = 0; i < len; i++) {
+ protected static final void adjustedRandomFill(@Nonnull final float[] a,
+ @Nonnull final Random rand, final float maxInitValue) {
+ final int k = a.length;
+ final float basev = maxInitValue / k;
+ for (int i = 0; i < k; i++) {
float v = rand.nextFloat() * basev;
a[i] = v;
}
}
- protected static final void gaussianFill(final float[] a, final Random[] rand,
+ // libffm's V initialization scheme: 1/sqrt(k)
+ // https://github.com/guestwalk/libffm/blob/master/ffm.cpp#L287
+ protected static final void libffmRandomFill(@Nonnull final float[] a,
+ @Nonnull final Random rand, final float maxInitValue) {
+ final int k = a.length;
+ final float basev = maxInitValue / (float) Math.sqrt(k);
+ for (int i = 0; i < k; i++) {
+ float v = rand.nextFloat() * basev;
+ a[i] = v;
+ }
+ }
+
+ protected static final void randomFill(@Nonnull final float[] a, @Nonnull final Random rand,
+ final float maxInitValue) {
+ final int k = a.length;
+ for (int i = 0; i < k; i++) {
+ float v = rand.nextFloat() * maxInitValue;
+ a[i] = v;
+ }
+ }
+
+ // libfm uses gaussian for initialization
+ // https://github.com/srendle/libfm/blob/30b9c799c41d043f31565cbf827bf41d0dc3e2ab/src/fm_core/fm_model.h#L96
+ protected static final void gaussianFill(@Nonnull final float[] a, @Nonnull final Random[] rand,
final double stddev) {
- for (int i = 0, len = a.length; i < len; i++) {
+ for (int i = 0, k = a.length; i < k; i++) {
float v = (float) MathUtils.gaussian(0.d, stddev, rand[i]);
a[i] = v;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
index eadd451..a253729 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
@@ -18,6 +18,11 @@
*/
package hivemall.fm;
+import static hivemall.fm.FMHyperParameters.DEFAULT_ETA0;
+import static hivemall.fm.FMHyperParameters.DEFAULT_LAMBDA;
+import static hivemall.utils.lang.Primitives.FALSE_BYTE;
+import static hivemall.utils.lang.Primitives.TRUE_BYTE;
+
import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
@@ -65,6 +70,8 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
+import com.google.common.base.Preconditions;
+
@Description(name = "train_fm",
value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model")
public class FactorizationMachineUDTF extends UDTFWithOptions {
@@ -89,7 +96,11 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
protected int _factors;
protected boolean _parseFeatureAsInt;
+ protected boolean _earlyStopping;
+ protected ConversionState _validationState;
+
// adaptive regularization
+ protected boolean _adaptiveRegularization;
@Nullable
protected Random _va_rand;
protected float _validationRatio;
@@ -107,6 +118,10 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
* The number of training examples processed
*/
protected long _t;
+ /**
+ * The number of validation examples
+ */
+ protected long _numValidations;
// file IO
private ByteBuffer _inputBuf;
@@ -117,24 +132,28 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
Options opts = new Options();
opts.addOption("c", "classification", false, "Act as classification");
opts.addOption("seed", true, "Seed value [default: -1 (random)]");
- opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]");
+ opts.addOption("iters", "iterations", true, "The number of iterations [default: 10]");
+ opts.addOption("iter", true, "The number of iterations [default: 10]."
+ + " Note this is alias of `iters` for backward compatibility");
opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]");
opts.addOption("f", "factors", true, "The number of the latent variables [default: 5]");
+ opts.addOption("k", "factor", true,
+ "The number of the latent variables [default: 5]" + " Alias of `-factors` option");
opts.addOption("sigma", true, "The standard deviation for initializing V [default: 0.1]");
opts.addOption("lambda0", "lambda", true,
- "The initial lambda value for regularization [default: 0.01]");
+ "The initial lambda value for regularization [default: " + DEFAULT_LAMBDA + "]");
opts.addOption("lambdaW0", "lambda_w0", true,
- "The initial lambda value for W0 regularization [default: 0.01]");
+ "The initial lambda value for W0 regularization [default: " + DEFAULT_LAMBDA + "]");
opts.addOption("lambdaWi", "lambda_wi", true,
- "The initial lambda value for Wi regularization [default: 0.01]");
+ "The initial lambda value for Wi regularization [default: " + DEFAULT_LAMBDA + "]");
opts.addOption("lambdaV", "lambda_v", true,
- "The initial lambda value for V regularization [default: 0.01]");
+ "The initial lambda value for V regularization [default: " + DEFAULT_LAMBDA + "]");
// regression
opts.addOption("min", "min_target", true, "The minimum value of target variable");
opts.addOption("max", "max_target", true, "The maximum value of target variable");
// learning rates
opts.addOption("eta", true, "The initial learning rate");
- opts.addOption("eta0", true, "The initial learning rate [default 0.05]");
+ opts.addOption("eta0", true, "The initial learning rate [default " + DEFAULT_ETA0 + "]");
opts.addOption("t", "total_steps", true, "The total number of training examples");
opts.addOption("power_t", true,
"The exponent for inverse scaling learning rate [default 0.1]");
@@ -143,19 +162,22 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
"Whether to disable convergence check [default: OFF]");
opts.addOption("cv_rate", "convergence_rate", true,
"Threshold to determine convergence [default: 0.005]");
- // adaptive regularization
+ // adaptive regularization and early stopping with randomly hold-out validation samples
+ opts.addOption("early_stopping", false,
+ "Stop at the iteration that achieves the best validation on partial samples [default: OFF]");
+ opts.addOption("va_ratio", "validation_ratio", true,
+ "Ratio of training data used for validation [default: 0.05f]");
+ opts.addOption("va_threshold", "validation_threshold", true,
+ "Threshold to start validation. "
+ + "At least N training examples are used before validation [default: 1000]");
if (isAdaptiveRegularizationSupported()) {
opts.addOption("adareg", "adaptive_regularization", false,
"Whether to enable adaptive regularization [default: OFF]");
- opts.addOption("va_ratio", "validation_ratio", true,
- "Ratio of training data used for validation [default: 0.05f]");
- opts.addOption("va_threshold", "validation_threshold", true,
- "Threshold to start validation. "
- + "At least N training examples are used before validation [default: 1000]");
}
// initialization of V
- opts.addOption("init_v", true, "Initialization strategy of matrix V [random, gaussian]"
- + "(default: 'random' for regression / 'gaussian' for classification)");
+ opts.addOption("init_v", true,
+ "Initialization strategy of matrix V [adjusted_random, libffm, random, gaussian]"
+ + "(FM default: 'adjusted_random' for regression, 'gaussian' for classification, FFM default: random)");
opts.addOption("maxval", "max_init_value", true,
"The maximum initial value in the matrix V [default: 0.5]");
opts.addOption("min_init_stddev", true,
@@ -188,9 +210,12 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
this._iterations = params.iters;
this._factors = params.factors;
this._parseFeatureAsInt = params.parseFeatureAsInt;
- if (params.adaptiveRegularization) {
+ this._earlyStopping = params.earlyStopping;
+ this._adaptiveRegularization = params.adaptiveRegularization;
+ if (_earlyStopping || _adaptiveRegularization) {
this._va_rand = new Random(params.seed + 31L);
}
+ this._validationState = new ConversionState();
this._validationRatio = params.validationRatio;
this._validationThreshold = params.validationThreshold;
this._lossFunction = params.classification ? LossFunctions.getLossFunction(LossType.LogLoss)
@@ -216,6 +241,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
this._model = null;
this._t = 0L;
+ this._numValidations = 0L;
if (LOG.isInfoEnabled()) {
LOG.info(_params);
@@ -276,16 +302,23 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
return;
}
this._probes = x;
+ _model.check(x); // mostly for FMIntFeatureMapModel
double y = PrimitiveObjectInspectorUtils.getDouble(args[1], _yOI);
if (_classification) {
y = (y > 0.d) ? 1.d : -1.d;
}
- ++_t;
- recordTrain(x, y);
- boolean adaptiveRegularization = (_va_rand != null) && _t >= _validationThreshold;
- train(x, y, adaptiveRegularization);
+ boolean validation = isValidationExample();
+ recordTrain(x, y, validation);
+ train(x, y, validation);
+ }
+
+ private boolean isValidationExample() {
+ if (_va_rand != null && _t >= _validationThreshold) {
+ return _va_rand.nextFloat() < _validationRatio;
+ }
+ return false;
}
@Nullable
@@ -297,7 +330,8 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
return features;
}
- protected void recordTrain(@Nonnull final Feature[] x, final double y) throws HiveException {
+ private void recordTrain(@Nonnull final Feature[] x, final double y, final boolean validation)
+ throws HiveException {
if (_iterations <= 1) {
return;
}
@@ -325,7 +359,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
}
int xBytes = Feature.requiredBytes(x);
- int recordBytes = SizeOf.INT + SizeOf.DOUBLE + xBytes;
+ int recordBytes = SizeOf.INT + SizeOf.DOUBLE + xBytes + SizeOf.BYTE;
int requiredBytes = SizeOf.INT + recordBytes;
int remain = inputBuf.remaining();
if (remain < requiredBytes) {
@@ -338,6 +372,12 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
f.writeTo(inputBuf);
}
inputBuf.putDouble(y);
+ if (validation) {
+ ++_numValidations;
+ inputBuf.put(TRUE_BYTE);
+ } else {
+ inputBuf.put(FALSE_BYTE);
+ }
}
private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst)
@@ -351,20 +391,13 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
srcBuf.clear();
}
- public void train(@Nonnull final Feature[] x, final double y,
- final boolean adaptiveRegularization) throws HiveException {
- _model.check(x);
-
+ private void train(@Nonnull final Feature[] x, final double y, final boolean validation)
+ throws HiveException {
try {
- if (adaptiveRegularization) {
- assert (_va_rand != null);
- final float rnd = _va_rand.nextFloat();
- if (rnd < _validationRatio) {
- trainLambda(x, y); // adaptive regularization
- } else {
- trainTheta(x, y);
- }
+ if (validation) {
+ processValidationSample(x, y);
} else {
+ ++_t;
trainTheta(x, y);
}
} catch (Exception ex) {
@@ -372,6 +405,18 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
}
}
+ protected void processValidationSample(@Nonnull final Feature[] x, final double y)
+ throws HiveException {
+ if (_earlyStopping) {
+ double p = _model.predict(x);
+ double loss = _lossFunction.loss(p, y);
+ _validationState.incrLoss(loss);
+ }
+ if (_adaptiveRegularization) {
+ trainLambda(x, y); // adaptive regularization
+ }
+ }
+
/**
* Update model parameters
*/
@@ -410,7 +455,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
* grad_lambdafg = (grad l(p,y)) * (-2 * alpha * (\sum_{l} x_l * v'_lf) * \sum_{l \in group(g)} x_l * v_lf) - \sum_{l \in group(g)} x^2_l * v_lf * v'_lf)
* </pre>
*/
- protected void trainLambda(final Feature[] x, final double y) throws HiveException {
+ private void trainLambda(final Feature[] x, final double y) throws HiveException {
final float eta = _etaEstimator.eta(_t);
final double p = _model.predict(x);
final double lossGrad = _model.dloss(p, y);
@@ -534,12 +579,11 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
}
protected void runTrainingIteration(int iterations) throws HiveException {
- final ByteBuffer inputBuf = this._inputBuf;
- final NioStatefulSegment fileIO = this._fileIO;
- assert (inputBuf != null);
- assert (fileIO != null);
+ final ByteBuffer inputBuf = Preconditions.checkNotNull(this._inputBuf);
+ final NioStatefulSegment fileIO = Preconditions.checkNotNull(this._fileIO);
+
final long numTrainingExamples = _t;
- final boolean adaregr = _va_rand != null;
+ boolean lossIncreasedLastIter = false;
final Reporter reporter = getReporter();
final Counter iterCounter = (reporter == null) ? null
@@ -553,6 +597,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
inputBuf.flip();
for (int iter = 2; iter <= iterations; iter++) {
+ _validationState.next();
_cvState.next();
reportProgress(reporter);
setCounterValue(iterCounter, iter);
@@ -566,19 +611,25 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
x[j] = instantiateFeature(inputBuf);
}
double y = inputBuf.getDouble();
+ boolean validation = (inputBuf.get() == TRUE_BYTE);
+
// invoke train
- ++_t;
- train(x, y, adaregr);
+ train(x, y, validation);
}
- if (_cvState.isConverged(numTrainingExamples)) {
+ // stop if validation loss is consecutively increased over recent 2 iterations
+ final boolean lossIncreased = _validationState.isLossIncreased();
+ if ((lossIncreasedLastIter && lossIncreased)
+ || _cvState.isConverged(numTrainingExamples)) {
break;
}
+ lossIncreasedLastIter = lossIncreased;
inputBuf.rewind();
}
LOG.info("Performed " + _cvState.getCurrentIteration() + " iterations of "
+ NumberUtils.formatNumber(numTrainingExamples)
+ " training examples on memory (thus " + NumberUtils.formatNumber(_t)
- + " training updates in total) ");
+ + " training updates in total), used " + _numValidations
+ + " validation examples");
} else {// read training examples in the temporary file and invoke train for each example
// write training examples in buffer to a temporary file
@@ -601,6 +652,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
// run iterations
for (int iter = 2; iter <= iterations; iter++) {
+ _validationState.next();
_cvState.next();
setCounterValue(iterCounter, iter);
@@ -643,23 +695,28 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
x[j] = instantiateFeature(inputBuf);
}
double y = inputBuf.getDouble();
+ boolean validation = (inputBuf.get() == TRUE_BYTE);
// invoke training
- ++_t;
- train(x, y, adaregr);
+ train(x, y, validation);
remain -= recordBytes;
}
inputBuf.compact();
}
- if (_cvState.isConverged(numTrainingExamples)) {
+ // stop if validation loss is consecutively increased over recent 2 iterations
+ final boolean lossIncreased = _validationState.isLossIncreased();
+ if ((lossIncreasedLastIter && lossIncreased)
+ || _cvState.isConverged(numTrainingExamples)) {
break;
}
+ lossIncreasedLastIter = lossIncreased;
}
LOG.info("Performed " + _cvState.getCurrentIteration() + " iterations of "
+ NumberUtils.formatNumber(numTrainingExamples)
+ " training examples on a secondary storage (thus "
- + NumberUtils.formatNumber(_t) + " training updates in total)");
+ + NumberUtils.formatNumber(_t) + " training updates in total), used "
+ + _numValidations + " validation examples");
}
} finally {
// delete the temporary file and release resources
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
index c6c0fd0..6cd8fe8 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
@@ -36,13 +36,12 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM
@Nonnull
protected final FFMHyperParameters _params;
- protected final float _eta0;
protected final float _eps;
protected final boolean _useAdaGrad;
protected final boolean _useFTRL;
- // FTEL
+ // FTRL
private final float _alpha;
private final float _beta;
private final float _lambda1;
@@ -51,11 +50,6 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM
public FieldAwareFactorizationMachineModel(@Nonnull FFMHyperParameters params) {
super(params);
this._params = params;
- if (params.useAdaGrad) {
- this._eta0 = 1.0f;
- } else {
- this._eta0 = params.eta.eta0();
- }
this._eps = params.eps;
this._useAdaGrad = params.useAdaGrad;
this._useFTRL = params.useFTRL;
@@ -261,7 +255,7 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM
if (_useAdaGrad) {
double gg = theta.getSumOfSquaredGradients(f);
theta.addGradient(f, grad);
- return (float) (_eta0 / Math.sqrt(_eps + gg));
+ return (float) (_eta.eta(t) / Math.sqrt(_eps + gg));
} else {
return _eta.eta(t);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
index 610fa3d..7987086 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
@@ -52,7 +52,7 @@ import org.apache.hadoop.io.Text;
/**
* Field-aware Factorization Machines.
- *
+ *
* @link https://www.csie.ntu.edu.tw/~cjlin/libffm/
* @since v0.5-rc.1
*/
@@ -70,7 +70,7 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
private int _numFields;
// ----------------------------------------
- private transient FFMStringFeatureMapModel _ffmModel;
+ protected transient FFMStringFeatureMapModel _ffmModel;
private transient IntArrayList _fieldList;
@Nullable
@@ -85,12 +85,13 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
Options opts = super.getOptions();
opts.addOption("w0", "global_bias", false,
"Whether to include global bias term w0 [default: OFF]");
- opts.addOption("disable_wi", "no_coeff", false,
- "Not to include linear term [default: OFF]");
+ opts.addOption("enable_wi", "linear_term", false, "Include linear term [default: OFF]");
+ opts.addOption("no_norm", "disable_norm", false, "Disable instance-wise L2 normalization");
// feature hashing
opts.addOption("feature_hashing", true,
"The number of bits for feature hashing in range [18,31] [default: -1]. No feature hashing for -1.");
- opts.addOption("num_fields", true, "The number of fields [default: 256]");
+ opts.addOption("num_fields", true,
+ "The number of fields [default: " + Feature.DEFAULT_NUM_FIELDS + "]");
// optimizer
opts.addOption("opt", "optimizer", true,
"Gradient Descent optimizer [default: ftrl, adagrad, sgd]");
@@ -98,11 +99,11 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default: 1.0]");
// FTRL
opts.addOption("alpha", "alphaFTRL", true,
- "Alpha value (learning rate) of Follow-The-Regularized-Leader [default: 0.2]");
+ "Alpha value (learning rate) of Follow-The-Regularized-Leader [default: 0.5]");
opts.addOption("beta", "betaFTRL", true,
"Beta value (a learning smoothing parameter) of Follow-The-Regularized-Leader [default: 1.0]");
opts.addOption("l1", "lambda1", true,
- "L1 regularization value of Follow-The-Regularized-Leader that controls model sparseness [default: 0.001]");
+ "L1 regularization value of Follow-The-Regularized-Leader that controls model sparseness [default: 0.0002]");
opts.addOption("l2", "lambda2", true,
"L2 regularization value of Follow-The-Regularized-Leader [default: 0.0001]");
return opts;
return opts;
@@ -180,13 +181,12 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
}
@Override
- public void train(@Nonnull final Feature[] x, final double y,
- final boolean adaptiveRegularization) throws HiveException {
- _ffmModel.check(x);
- try {
- trainTheta(x, y);
- } catch (Exception ex) {
- throw new HiveException("Exception caused in the " + _t + "-th call of train()", ex);
+ protected void processValidationSample(@Nonnull final Feature[] x, final double y)
+ throws HiveException {
+ if (_earlyStopping) {
+ double p = _model.predict(x);
+ double loss = _lossFunction.loss(p, y);
+ _validationState.incrLoss(loss);
}
}
@@ -292,7 +292,7 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
forward(forwardObjs);
final Entry entryW = new Entry(_ffmModel._buf, 1);
- final Entry entryV = new Entry(_ffmModel._buf, _ffmModel._factor);
+ final Entry entryV = new Entry(_ffmModel._buf, factors);
final float[] Vf = new float[factors];
for (Int2LongMap.Entry e : Fastutil.fastIterable(_ffmModel._map)) {
@@ -303,7 +303,7 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
final long offset = e.getLongValue();
if (Entry.isEntryW(i)) {// set Wi
entryW.setOffset(offset);
- float w = entryV.getW();
+ float w = entryW.getW();
if (w == 0.f) {
continue; // skip w_i=0
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
index 3f959e5..c46e470 100644
--- a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
+++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
@@ -55,6 +55,7 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
private RowProcessor _proc;
private int _numFields;
private int _numFeatures;
+ private boolean _l2norm;
public FeaturePairsUDTF() {}
@@ -69,7 +70,9 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]");
opts.addOption("feature_hashing", true,
"The number of bits for feature hashing in range [18,31]. [default: -1] No feature hashing for -1.");
- opts.addOption("num_fields", true, "The number of fields [default:1024]");
+ opts.addOption("num_fields", true,
+ "The number of fields [default: " + Feature.DEFAULT_NUM_FIELDS + "]");
+ opts.addOption("no_norm", "disable_norm", false, "Disable instance-wise L2 normalization");
return opts;
}
@@ -104,6 +107,7 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
throw new UDFArgumentException(
"-num_fields MUST be greater than 1: " + _numFields);
}
+ this._l2norm = !cl.hasOption("disable_norm");
} else {
throw new UDFArgumentException("Unsupported option: " + cl.getArgList().get(0));
}
@@ -285,6 +289,10 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
this._features =
Feature.parseFFMFeatures(arg, fvOI, _features, _numFeatures, _numFields);
+ if (_l2norm) {
+ Feature.l2normalize(_features);
+ }
+
// W0
f0.set(0);
forward[1] = null;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java b/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java
index 45ef97d..e5de329 100644
--- a/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java
+++ b/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java
@@ -56,6 +56,11 @@ public final class L1NormalizationUDF extends UDF {
float v = Float.parseFloat(ft[1]);
weights[i] = v;
absoluteSum += Math.abs(v);
+ } else if (ftlen == 3) {
+ features[i] = ft[0] + ':' + ft[1];
+ float v = Float.parseFloat(ft[2]);
+ weights[i] = v;
+ absoluteSum += Math.abs(v);
} else {
throw new HiveException("Invalid feature value representation: " + s);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java b/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java
index 9cf315c..fa70f10 100644
--- a/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java
+++ b/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java
@@ -59,6 +59,11 @@ public final class L2NormalizationUDF extends UDF {
float v = Float.parseFloat(ft[1]);
weights[i] = v;
squaredSum += (v * v);
+ } else if (ftlen == 3) {
+ features[i] = ft[0] + ':' + ft[1];
+ float v = Float.parseFloat(ft[2]);
+ weights[i] = v;
+ squaredSum += (v * v);
} else {
throw new HiveException("Invalid feature value representation: " + s);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
index 76d52ab..23d9b63 100644
--- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
@@ -137,8 +137,12 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements
@Override
protected Options getOptions() {
Options opts = new Options();
- opts.addOption("k", "factor", true, "The number of latent factor [default: 10]");
- opts.addOption("iter", "iterations", true, "The number of iterations [default: 30]");
+ opts.addOption("k", "factor", true,
+ "The number of latent factor [default: 10] Alias for `-factors`");
+ opts.addOption("f", "factors", true, "The number of latent factor [default: 10]");
+ opts.addOption("iters", "iterations", true, "The number of iterations [default: 30]");
+ opts.addOption("iter", true,
+ "The number of iterations [default: 30] Alias for `-iterations`");
opts.addOption("loss", "loss_function", true,
"Loss function [default: lnLogistic, logistic, sigmoid]");
// initialization
@@ -191,8 +195,16 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements
String rawArgs = HiveUtils.getConstString(argOIs[3]);
cl = parseOptions(rawArgs);
- this.factor = Primitives.parseInt(cl.getOptionValue("factor"), factor);
- this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations);
+ if (cl.hasOption("factor")) {
+ this.factor = Primitives.parseInt(cl.getOptionValue("factor"), factor);
+ } else {
+ this.factor = Primitives.parseInt(cl.getOptionValue("factors"), factor);
+ }
+ if (cl.hasOption("iter")) {
+ this.iterations = Primitives.parseInt(cl.getOptionValue("iter"), iterations);
+ } else {
+ this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations);
+ }
if (iterations < 1) {
throw new UDFArgumentException(
"'-iterations' must be greater than or equals to 1: " + iterations);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
index 537706e..9d7e1d1 100644
--- a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
@@ -106,7 +106,9 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions
@Override
protected Options getOptions() {
Options opts = new Options();
- opts.addOption("k", "factor", true, "The number of latent factor [default: 10]");
+ opts.addOption("k", "factor", true, "The number of latent factor [default: 10] "
+ + "Note this is an alias for the `-factors` option.");
+ opts.addOption("f", "factors", true, "The number of latent factor [default: 10]");
opts.addOption("r", "lambda", true, "The regularization factor [default: 0.03]");
opts.addOption("mu", "mean_rating", true, "The mean rating [default: 0.0]");
opts.addOption("update_mean", "update_mu", false,
@@ -117,7 +119,9 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions
"The maximum initial value in the rank matrix [default: 1.0]");
opts.addOption("min_init_stddev", true,
"The minimum standard deviation of initial rank matrix [default: 0.1]");
- opts.addOption("iter", "iterations", true, "The number of iterations [default: 1]");
+ opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]");
+ opts.addOption("iter", true,
+ "The number of iterations [default: 1] Alias for `-iterations`");
opts.addOption("disable_cv", "disable_cvtest", false,
"Whether to disable convergence check [default: enabled]");
opts.addOption("cv_rate", "convergence_rate", true,
@@ -138,14 +142,22 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions
if (argOIs.length >= 4) {
String rawArgs = HiveUtils.getConstString(argOIs[3]);
cl = parseOptions(rawArgs);
- this.factor = Primitives.parseInt(cl.getOptionValue("factor"), 10);
+ if (cl.hasOption("factors")) {
+ this.factor = Primitives.parseInt(cl.getOptionValue("factors"), 10);
+ } else {
+ this.factor = Primitives.parseInt(cl.getOptionValue("factor"), 10);
+ }
this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 0.03f);
this.meanRating = Primitives.parseFloat(cl.getOptionValue("mu"), 0.f);
this.updateMeanRating = cl.hasOption("update_mean");
rankInitOpt = cl.getOptionValue("rankinit");
maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.f);
initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d);
- this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1);
+ if (cl.hasOption("iter")) {
+ this.iterations = Primitives.parseInt(cl.getOptionValue("iter"), 1);
+ } else {
+ this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1);
+ }
if (iterations < 1) {
throw new UDFArgumentException(
"'-iterations' must be greater than or equal to 1: " + iterations);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java b/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java
index 95c97dc..ca85cee 100644
--- a/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java
+++ b/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java
@@ -28,7 +28,8 @@ import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.io.LongWritable;
-@Description(name = "rownum", value = "_FUNC_() - Returns a generated row number `sprintf(`%d%04d`,sequence,taskId)` in long",
+@Description(name = "rownum",
+ value = "_FUNC_() - Returns a generated row number `sprintf(`%d%04d`,sequence,taskId)` in long",
extended = "SELECT rownum() as rownum, xxx from ...")
@UDFType(deterministic = false, stateful = true)
public final class RowNumberUDF extends UDF {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/utils/lang/Primitives.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java
index 7d43da1..ab3be9a 100644
--- a/core/src/main/java/hivemall/utils/lang/Primitives.java
+++ b/core/src/main/java/hivemall/utils/lang/Primitives.java
@@ -24,6 +24,9 @@ public final class Primitives {
public static final int INT_BYTES = Integer.SIZE / Byte.SIZE;
public static final int DOUBLE_BYTES = Double.SIZE / Byte.SIZE;
+ public static final Byte TRUE_BYTE = 1;
+ public static final Byte FALSE_BYTE = 0;
+
private Primitives() {}
public static short parseShort(final String s, final short defaultValue) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/utils/lang/SizeOf.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/SizeOf.java b/core/src/main/java/hivemall/utils/lang/SizeOf.java
index 9e0ef4c..08cf664 100644
--- a/core/src/main/java/hivemall/utils/lang/SizeOf.java
+++ b/core/src/main/java/hivemall/utils/lang/SizeOf.java
@@ -29,4 +29,5 @@ public final class SizeOf {
public static final int CHAR = Character.SIZE / Byte.SIZE;
private SizeOf() {}
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java
index 64da212..b6b83c5 100644
--- a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java
+++ b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java
@@ -92,6 +92,141 @@ public class FactorizationMachineUDTFTest {
}
@Test
+ public void testAdaptiveRegularization() throws HiveException, IOException {
+ println("Adaptive regularization test");
+
+ final String options = "-factors 5 -min 1 -max 5 -init_v gaussian -eta0 0.01 -seed 31 ";
+
+ FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
+ ObjectInspector[] argOIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, options)};
+
+ udtf.initialize(argOIs);
+
+ BufferedReader data = readFile("5107786.txt.gz");
+ List<List<String>> featureVectors = new ArrayList<>();
+ List<Double> ys = new ArrayList<>();
+ String line = data.readLine();
+ while (line != null) {
+ StringTokenizer tokenizer = new StringTokenizer(line, " ");
+ double y = Double.parseDouble(tokenizer.nextToken());
+ List<String> features = new ArrayList<String>();
+ while (tokenizer.hasMoreTokens()) {
+ String f = tokenizer.nextToken();
+ features.add(f);
+ }
+ udtf.process(new Object[] {features, y});
+ featureVectors.add(features);
+ ys.add(y);
+ line = data.readLine();
+ }
+ udtf.finalizeTraining();
+ data.close();
+
+ double loss = udtf._cvState.getAverageLoss(featureVectors.size());
+ println("Average loss without adaptive regularization: " + loss);
+
+ // train with adaptive regularization
+ udtf = new FactorizationMachineUDTF();
+ argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ options + "-adaptive_regularization -validation_threshold 1");
+ udtf.initialize(argOIs);
+ udtf.initModel(udtf._params);
+ for (int i = 0, n = featureVectors.size(); i < n; i++) {
+ udtf.process(new Object[] {featureVectors.get(i), ys.get(i)});
+ }
+ udtf.finalizeTraining();
+
+ double loss_adareg = udtf._cvState.getAverageLoss(featureVectors.size());
+ println("Average loss with adaptive regularization: " + loss_adareg);
+ Assert.assertTrue("Adaptive regularization should achieve lower loss", loss > loss_adareg);
+ }
+
+ @Test
+ public void testEarlyStopping() throws HiveException, IOException {
+ println("Early stopping test");
+
+ int iters = 20;
+
+ // train with 20 iterations
+ FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
+ ObjectInspector[] argOIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ "-factors 5 -min 1 -max 5 -init_v gaussian -eta0 0.002 -seed 31 -iters " + iters
+ + " -early_stopping -validation_threshold 1 -disable_cv")};
+
+ udtf.initialize(argOIs);
+
+ BufferedReader data = readFile("5107786.txt.gz");
+ List<List<String>> featureVectors = new ArrayList<>();
+ List<Double> ys = new ArrayList<>();
+ String line = data.readLine();
+ while (line != null) {
+ StringTokenizer tokenizer = new StringTokenizer(line, " ");
+ double y = Double.parseDouble(tokenizer.nextToken());
+ List<String> features = new ArrayList<String>();
+ while (tokenizer.hasMoreTokens()) {
+ String f = tokenizer.nextToken();
+ features.add(f);
+ }
+ udtf.process(new Object[] {features, y});
+ featureVectors.add(features);
+ ys.add(y);
+ line = data.readLine();
+ }
+ udtf.finalizeTraining();
+ data.close();
+
+ double loss = udtf._validationState.getAverageLoss(featureVectors.size());
+ Assert.assertTrue(
+ "Training seems to be failed because average loss is greater than 0.1: " + loss,
+ loss <= 0.1);
+
+ Assert.assertNotNull("Early stopping validation has not been conducted",
+ udtf._validationState);
+ println("Performed " + udtf._validationState.getCurrentIteration() + " iterations out of "
+ + iters);
+ Assert.assertNotEquals("Early stopping did not happen", iters,
+ udtf._validationState.getCurrentIteration());
+
+ // store the best state achieved by early stopping
+ iters = udtf._validationState.getCurrentIteration() - 2; // best loss was at (N-2)-th iter
+ double cumulativeLoss = udtf._validationState.getCumulativeLoss();
+ println("Cumulative loss: " + cumulativeLoss);
+
+ // train with the number of early-stopped iterations
+ udtf = new FactorizationMachineUDTF();
+ argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ "-factors 5 -min 1 -max 5 -init_v gaussian -eta0 0.002 -seed 31 -iters " + iters
+ + " -early_stopping -validation_threshold 1 -disable_cv");
+ udtf.initialize(argOIs);
+ udtf.initModel(udtf._params);
+ for (int i = 0, n = featureVectors.size(); i < n; i++) {
+ udtf.process(new Object[] {featureVectors.get(i), ys.get(i)});
+ }
+ udtf.finalizeTraining();
+
+ println("Performed " + udtf._validationState.getCurrentIteration() + " iterations out of "
+ + iters);
+ Assert.assertEquals("Training finished earlier than expected", iters,
+ udtf._validationState.getCurrentIteration());
+
+ println("Cumulative loss: " + udtf._validationState.getCumulativeLoss());
+ Assert.assertTrue("Cumulative loss should be better than " + cumulativeLoss,
+ cumulativeLoss > udtf._validationState.getCumulativeLoss());
+ }
+
+ @Test
public void testEnableL2Norm() throws HiveException, IOException {
FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
ObjectInspector[] argOIs = new ObjectInspector[] {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
index 5b7aa8f..67040a1 100644
--- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
+++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
@@ -44,59 +44,61 @@ import org.junit.Test;
public class FieldAwareFactorizationMachineUDTFTest {
private static final boolean DEBUG = false;
- private static final int ITERATIONS = 50;
- private static final int MAX_LINES = 200;
// ----------------------------------------------------
// bigdata.tr.txt
@Test
public void testSGD() throws HiveException, IOException {
- runIterations("Pure SGD test", "bigdata.tr.txt.gz",
- "-opt sgd -classification -factors 10 -w0 -seed 43", 0.60f);
+ run("Pure SGD test", "bigdata.tr.txt.gz",
+ "-opt sgd -linear_term -classification -factors 10 -w0 -eta 0.4 -iters 20 -seed 43",
+ 0.30f);
}
@Test
public void testAdaGrad() throws HiveException, IOException {
- runIterations("AdaGrad test", "bigdata.tr.txt.gz",
- "-opt adagrad -classification -factors 10 -w0 -seed 43", 0.30f);
+ run("AdaGrad test", "bigdata.tr.txt.gz",
+ "-opt adagrad -linear_term -classification -factors 10 -w0 -eta 0.4 -iters 30 -seed 43",
+ 0.30f);
}
@Test
public void testAdaGradNoCoeff() throws HiveException, IOException {
- runIterations("AdaGrad No Coeff test", "bigdata.tr.txt.gz",
- "-opt adagrad -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f);
+ run("AdaGrad No Coeff test", "bigdata.tr.txt.gz",
+ "-opt adagrad -classification -factors 10 -w0 -eta 0.4 -iters 30 -seed 43", 0.30f);
}
@Test
public void testFTRL() throws HiveException, IOException {
- runIterations("FTRL test", "bigdata.tr.txt.gz",
- "-opt ftrl -classification -factors 10 -w0 -seed 43", 0.30f);
+ run("FTRL test", "bigdata.tr.txt.gz",
+ "-opt ftrl -linear_term -classification -factors 10 -w0 -alphaFTRL 10.0 -seed 43",
+ 0.30f);
}
@Test
public void testFTRLNoCoeff() throws HiveException, IOException {
- runIterations("FTRL Coeff test", "bigdata.tr.txt.gz",
- "-opt ftrl -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f);
+ run("FTRL No Coeff test", "bigdata.tr.txt.gz",
+ "-opt ftrl -classification -factors 10 -w0 -alphaFTRL 10.0 -seed 43", 0.30f);
}
// ----------------------------------------------------
// https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz
@Test
- public void testSample() throws IOException, HiveException {
+ public void testSampleDisableNorm() throws IOException, HiveException {
System.setProperty("https.protocols", "TLSv1,TLSv1.1,TLSv1.2");
run("[Sample.ffm] default option",
"https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz",
- "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43", 0.01f);
+ "-disable_norm -linear_term -classification -factors 2 -feature_hashing 20 -seed 43",
+ 0.01f);
}
- // TODO @Test
- public void testSampleEnableNorm() throws IOException, HiveException {
+ @Test
+ public void testSample() throws IOException, HiveException {
System.setProperty("https.protocols", "TLSv1,TLSv1.1,TLSv1.2");
run("[Sample.ffm] default option",
"https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz",
- "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43 -enable_norm",
+ "-linear_term -classification -factors 2 -alphaFTRL 10.0 -feature_hashing 20 -seed 43",
0.01f);
}
@@ -161,66 +163,112 @@ public class FieldAwareFactorizationMachineUDTFTest {
avgLoss < lossThreshold);
}
- private static void runIterations(String testName, String testFile, String testOptions,
- float lossThreshold) throws IOException, HiveException {
- println(testName);
+ @Test
+ public void testEarlyStopping() throws HiveException, IOException {
+ println("Early stopping");
+
+ int iters = 20;
FieldAwareFactorizationMachineUDTF udtf = new FieldAwareFactorizationMachineUDTF();
- ObjectInspector[] argOIs =
- new ObjectInspector[] {
- ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.javaStringObjectInspector),
- PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
- ObjectInspectorUtils.getConstantObjectInspector(
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- testOptions)};
+ ObjectInspector[] argOIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ "-opt sgd -linear_term -classification -factors 10 -w0 -eta 0.4 -iters " + iters
+ + " -early_stopping -validation_threshold 1 -disable_cv -seed 43")};
udtf.initialize(argOIs);
- FieldAwareFactorizationMachineModel model = udtf.initModel(udtf._params);
- Assert.assertTrue("Actual class: " + model.getClass().getName(),
- model instanceof FFMStringFeatureMapModel);
- double loss = 0.d;
- double cumul = 0.d;
- for (int trainingIteration = 1; trainingIteration <= ITERATIONS; ++trainingIteration) {
- BufferedReader data = readFile(testFile);
- loss = udtf._cvState.getCumulativeLoss();
- int lines = 0;
- for (int lineNumber = 0; lineNumber < MAX_LINES; ++lineNumber, ++lines) {
- //gather features in current line
- final String input = data.readLine();
- if (input == null) {
- break;
- }
- String[] featureStrings = input.split(" ");
+ BufferedReader data = readFile("bigdata.tr.txt.gz");
+ List<List<String>> featureVectors = new ArrayList<>();
+ List<Double> ys = new ArrayList<>();
+ while (true) {
+ //gather features in current line
+ final String input = data.readLine();
+ if (input == null) {
+ break;
+ }
+ String[] featureStrings = input.split(" ");
- double y = Double.parseDouble(featureStrings[0]);
- if (y == 0) {
- y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1}
- }
+ double y = Double.parseDouble(featureStrings[0]);
+ if (y == 0) {
+ y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1}
+ }
+ ys.add(y);
- final List<String> features = new ArrayList<String>(featureStrings.length - 1);
- for (int j = 1; j < featureStrings.length; ++j) {
- String fj = featureStrings[j];
- String[] splitted = fj.split(":");
- Assert.assertEquals(3, splitted.length);
- String indexStr = splitted[1];
- String f = fj;
- if (NumberUtils.isDigits(indexStr)) {
- int index = Integer.parseInt(indexStr) + 1; // avoid 0 index
- f = splitted[0] + ':' + index + ':' + splitted[2];
- }
- features.add(f);
+ final List<String> features = new ArrayList<String>(featureStrings.length - 1);
+ for (int j = 1; j < featureStrings.length; ++j) {
+ String fj = featureStrings[j];
+ String[] splitted = fj.split(":");
+ Assert.assertEquals(3, splitted.length);
+ String indexStr = splitted[1];
+ String f = fj;
+ if (NumberUtils.isDigits(indexStr)) {
+ int index = Integer.parseInt(indexStr) + 1; // avoid 0 index
+ f = splitted[0] + ':' + index + ':' + splitted[2];
}
- udtf.process(new Object[] {features, y});
+ features.add(f);
}
- cumul = udtf._cvState.getCumulativeLoss();
- loss = (cumul - loss) / lines;
- println(trainingIteration + " " + loss + " " + cumul / (trainingIteration * lines));
- data.close();
+ featureVectors.add(features);
+
+ udtf.process(new Object[] {features, y});
}
- println("model size=" + udtf._model.getSize());
- Assert.assertTrue("Last loss was greater than expected: " + loss, loss < lossThreshold);
+ udtf.finalizeTraining();
+ data.close();
+
+ double loss = udtf._validationState.getAverageLoss(featureVectors.size());
+ Assert.assertTrue(
+ "Training seems to be failed because average loss is greater than 0.6: " + loss,
+ loss <= 0.6);
+
+ Assert.assertNotNull("Early stopping validation has not been conducted",
+ udtf._validationState);
+ println("Performed " + udtf._validationState.getCurrentIteration() + " iterations out of "
+ + iters);
+ Assert.assertNotEquals("Early stopping did not happen", iters,
+ udtf._validationState.getCurrentIteration());
+
+ // store the best state achieved by early stopping
+ iters = udtf._validationState.getCurrentIteration() - 2; // best loss was at (N-2)-th iter
+ double cumulativeLoss = udtf._validationState.getCumulativeLoss();
+ println("Cumulative loss: " + cumulativeLoss);
+
+ // train with the number of early-stopped iterations
+ udtf = new FieldAwareFactorizationMachineUDTF();
+ argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ "-opt sgd -linear_term -classification -factors 10 -w0 -eta 0.4 -iters " + iters
+ + " -early_stopping -validation_threshold 1 -disable_cv -seed 43");
+ udtf.initialize(argOIs);
+ udtf.initModel(udtf._params);
+ for (int i = 0, n = featureVectors.size(); i < n; i++) {
+ udtf.process(new Object[] {featureVectors.get(i), ys.get(i)});
+ }
+ udtf.finalizeTraining();
+
+ println("Performed " + udtf._validationState.getCurrentIteration() + " iterations out of "
+ + iters);
+ Assert.assertEquals("Training finished earlier than expected", iters,
+ udtf._validationState.getCurrentIteration());
+
+ println("Cumulative loss: " + udtf._validationState.getCumulativeLoss());
+ Assert.assertTrue("Cumulative loss should be better than " + cumulativeLoss,
+ cumulativeLoss > udtf._validationState.getCumulativeLoss());
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testUnsupportedAdaptiveRegularizationOption() throws Exception {
+ TestUtils.testGenericUDTFSerialization(FieldAwareFactorizationMachineUDTF.class,
+ new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+ PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+ "-seed 43 -adaptive_regularization")},
+ new Object[][] {{Arrays.asList("0:1:-2", "1:2:-1"), 1.0}});
}
@Test
@@ -231,8 +279,7 @@ public class FieldAwareFactorizationMachineUDTFTest {
PrimitiveObjectInspectorFactory.javaStringObjectInspector),
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
ObjectInspectorUtils.getConstantObjectInspector(
- PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- "-opt sgd -classification -factors 10 -w0 -seed 43")},
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-seed 43")},
new Object[][] {{Arrays.asList("0:1:-2", "1:2:-1"), 1.0}});
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java b/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java
index 7d997f7..bfb37fc 100644
--- a/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java
+++ b/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java
@@ -59,6 +59,12 @@ public class L1NormalizationUDFTest {
WritableUtils.val(new String[] {"aaa:" + normalized[0], "bbb:" + normalized[1]}),
udf.evaluate(WritableUtils.val(new String[] {"aaa:1.0", "bbb:-0.5"})));
+ normalized = MathUtils.l1normalize(new float[] {1.0f, 2.0f, 3.0f});
+ assertEquals(
+ WritableUtils.val(new String[] {"1:123:" + normalized[0], "2:456:" + normalized[1],
+ "3:789:" + normalized[2]}),
+ udf.evaluate(WritableUtils.val(new String[] {"1:123:1", "2:456:2", "3:789:3"})));
+
List<Text> expected = udf.evaluate(WritableUtils.val(new String[] {"bbb:-0.5", "aaa:1.0"}));
Collections.sort(expected);
List<Text> actual = udf.evaluate(WritableUtils.val(new String[] {"aaa:1.0", "bbb:-0.5"}));
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java b/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java
index 30e2aba..393a9d2 100644
--- a/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java
+++ b/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java
@@ -59,6 +59,12 @@ public class L2NormalizationUDFTest {
WritableUtils.val(new String[] {"aaa:" + 1.0f / l2norm, "bbb:" + -0.5f / l2norm}),
udf.evaluate(WritableUtils.val(new String[] {"aaa:1.0", "bbb:-0.5"})));
+ l2norm = MathUtils.l2norm(new float[] {1.0f, 2.0f, 3.0f});
+ assertEquals(
+ WritableUtils.val(new String[] {"1:123:" + 1.0f / l2norm, "2:456:" + 2.0f / l2norm,
+ "3:789:" + 3.0f / l2norm}),
+ udf.evaluate(WritableUtils.val(new String[] {"1:123:1", "2:456:2", "3:789:3"})));
+
List<Text> expected = udf.evaluate(WritableUtils.val(new String[] {"bbb:-0.5", "aaa:1.0"}));
Collections.sort(expected);
List<Text> actual = udf.evaluate(WritableUtils.val(new String[] {"aaa:1.0", "bbb:-0.5"}));
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 56e416f..155a221 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -111,6 +111,9 @@
* [Kaggle Titanic Tutorial](binaryclass/titanic_rf.md)
+* [Criteo Tutorial](binaryclass/criteo.md)
+ * [Data preparation](binaryclass/criteo_dataset.md)
+ * [Field-Aware Factorization Machines](binaryclass/criteo_ffm.md)
## Part VII - Multiclass Classification
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/docs/gitbook/binaryclass/criteo.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/criteo.md b/docs/gitbook/binaryclass/criteo.md
new file mode 100644
index 0000000..3ad5f81
--- /dev/null
+++ b/docs/gitbook/binaryclass/criteo.md
@@ -0,0 +1,20 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+This tutorial tackles [Kaggle Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge).
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/docs/gitbook/binaryclass/criteo_dataset.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/criteo_dataset.md b/docs/gitbook/binaryclass/criteo_dataset.md
new file mode 100644
index 0000000..c4c12ea
--- /dev/null
+++ b/docs/gitbook/binaryclass/criteo_dataset.md
@@ -0,0 +1,97 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+<!-- toc -->
+
+# Download data
+
+Get dataset of [Kaggle Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge) from one of the following sources:
+
+1. [Original competition data](http://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset/) (by Criteo Labs) [~20GB]
+2. [Subset of the original competition data](http://labs.criteo.com/2014/02/dataset/) (by Criteo Labs) [~30MB]
+3. [Tiny sample data](https://github.com/guestwalk/kaggle-2014-criteo) (by the winners of the competition) [~20bytes]
+
+Note that you must accept and agree to the **CRITEO LABS DATA TERMS OF USE** before downloading the data.
+
+# Convert data into CSV format
+
+Here, you can use a script prepared by one of the Hivemall PPMC members: **[takuti/criteo-ffm](https://github.com/takuti/criteo-ffm)**.
+
+Clone the repository:
+
+```sh
+git clone git@github.com:takuti/criteo-ffm.git
+cd criteo-ffm
+```
+
+A script [`data.sh`](https://github.com/takuti/criteo-ffm/blob/master/data.sh) downloads the original data and converts them into CSV format:
+
+```sh
+./data.sh # downloads the original data and generates `train.csv` and `test.csv`
+ln -s train.csv tr.csv
+ln -s test.csv te.csv
+```
+
+Alternatively, since the original data is very large, it may be better to start with the tiny sample data bundled in the repository:
+
+```sh
+ln -s train.tiny.csv tr.csv
+ln -s test.tiny.csv te.csv
+```
+
+# Create tables
+
+Load the CSV files to Hive tables as:
+
+```sh
+hadoop fs -put tr.csv /criteo/train
+hadoop fs -put te.csv /criteo/test
+```
+
+```sql
+CREATE DATABASE IF NOT EXISTS criteo;
+use criteo;
+```
+
+```sql
+DROP TABLE IF EXISTS train;
+CREATE EXTERNAL TABLE train (
+ id bigint,
+ label int,
+ -- quantitative features
+ i1 int,i2 int,i3 int,i4 int,i5 int,i6 int,i7 int,i8 int,i9 int,i10 int,i11 int,i12 int,i13 int,
+ -- categorical features
+ c1 string,c2 string,c3 string,c4 string,c5 string,c6 string,c7 string,c8 string,c9 string,c10 string,c11 string,c12 string,c13 string,c14 string,c15 string,c16 string,c17 string,c18 string,c19 string,c20 string,c21 string,c22 string,c23 string,c24 string,c25 string,c26 string
+) ROW FORMAT
+DELIMITED FIELDS TERMINATED BY ','
+STORED AS TEXTFILE LOCATION '/criteo/train';
+```
+
+```sql
+DROP TABLE IF EXISTS test;
+CREATE EXTERNAL TABLE test (
+ label int,
+ -- quantitative features
+ i1 int,i2 int,i3 int,i4 int,i5 int,i6 int,i7 int,i8 int,i9 int,i10 int,i11 int,i12 int,i13 int,
+ -- categorical features
+ c1 string,c2 string,c3 string,c4 string,c5 string,c6 string,c7 string,c8 string,c9 string,c10 string,c11 string,c12 string,c13 string,c14 string,c15 string,c16 string,c17 string,c18 string,c19 string,c20 string,c21 string,c22 string,c23 string,c24 string,c25 string,c26 string
+) ROW FORMAT
+DELIMITED FIELDS TERMINATED BY ','
+STORED AS TEXTFILE LOCATION '/criteo/test';
+```
\ No newline at end of file