You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/06/14 17:51:42 UTC
[4/4] incubator-hivemall git commit: [HIVEMALL-101] refactored the
previous commit
[HIVEMALL-101] refactored the previous commit
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/5e27993b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/5e27993b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/5e27993b
Branch: refs/heads/master
Commit: 5e27993b6617a9bd817dc906fed7ca79e9bdc6ad
Parents: 3848ea6
Author: Makoto Yui <my...@apache.org>
Authored: Thu Jun 15 02:50:27 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Thu Jun 15 02:50:27 2017 +0900
----------------------------------------------------------------------
.../java/hivemall/GeneralLearnerBaseUDTF.java | 31 ++---
.../src/main/java/hivemall/LearnerBaseUDTF.java | 6 +-
core/src/main/java/hivemall/UDFWithOptions.java | 6 +-
.../src/main/java/hivemall/UDTFWithOptions.java | 5 +-
.../classifier/BinaryOnlineClassifierUDTF.java | 6 +-
.../classifier/GeneralClassifierUDTF.java | 13 +-
core/src/main/java/hivemall/fm/Feature.java | 2 +-
.../main/java/hivemall/model/NewDenseModel.java | 8 +-
.../model/NewSpaceEfficientDenseModel.java | 19 ++-
.../java/hivemall/model/NewSparseModel.java | 4 +-
.../main/java/hivemall/model/SparseModel.java | 4 +-
.../hivemall/model/WeightValueWithClock.java | 8 +-
.../optimizer/DenseOptimizerFactory.java | 62 +++++-----
.../java/hivemall/optimizer/EtaEstimator.java | 14 +--
.../java/hivemall/optimizer/LossFunctions.java | 75 ++++++-----
.../main/java/hivemall/optimizer/Optimizer.java | 123 +++++++++----------
.../hivemall/optimizer/OptimizerOptions.java | 17 +--
.../java/hivemall/optimizer/Regularization.java | 13 +-
.../optimizer/SparseOptimizerFactory.java | 107 ++++++++--------
.../regression/GeneralRegressionUDTF.java | 5 +-
.../utils/collections/maps/IntOpenHashMap.java | 4 +-
.../classifier/GeneralClassifierUDTFTest.java | 10 +-
.../regression/GeneralRegressionUDTFTest.java | 6 +-
23 files changed, 276 insertions(+), 272 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
index e798fdf..34c7ec9 100644
--- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
@@ -70,14 +70,15 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
private Optimizer optimizer;
private LossFunction lossFunction;
- protected PredictionModel model;
- protected int count;
+ private PredictionModel model;
+ private long count;
// The accumulated delta of each weight values.
- protected transient Map<Object, FloatAccumulator> accumulated;
- protected int sampled;
+ @Nullable
+ private transient Map<Object, FloatAccumulator> accumulated;
+ private int sampled;
- private float cumLoss;
+ private double cumLoss;
public GeneralLearnerBaseUDTF() {
this(true);
@@ -122,12 +123,12 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
try {
this.optimizer = createOptimizer(optimizerOptions);
} catch (Throwable e) {
- throw new UDFArgumentException(e.getMessage());
+ throw new UDFArgumentException(e);
}
- this.count = 0;
+ this.count = 0L;
this.sampled = 0;
- this.cumLoss = 0.f;
+ this.cumLoss = 0.d;
return getReturnOI(featureOutputOI);
}
@@ -160,7 +161,8 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
return cl;
}
- protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg)
+ @Nonnull
+ protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg)
throws UDFArgumentException {
this.featureListOI = (ListObjectInspector) arg;
ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector();
@@ -169,7 +171,8 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
return HiveUtils.asPrimitiveObjectInspector(featureRawOI);
}
- protected StructObjectInspector getReturnOI(ObjectInspector featureOutputOI) {
+ @Nonnull
+ protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
@@ -241,7 +244,7 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
final float v = f.getValueAsFloat();
float old_w = model.getWeight(k);
- if (old_w != 0f) {
+ if (old_w != 0.f) {
score += (old_w * v);
}
}
@@ -302,7 +305,7 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
this.sampled = 0;
}
- protected void onlineUpdate(@Nonnull final FeatureValue[] features, float dloss) {
+ protected void onlineUpdate(@Nonnull final FeatureValue[] features, final float dloss) {
for (FeatureValue f : features) {
Object feature = f.getFeature();
float xi = f.getValueAsFloat();
@@ -368,13 +371,13 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
}
@VisibleForTesting
- public float getCumulativeLoss() {
+ public double getCumulativeLoss() {
return cumLoss;
}
@VisibleForTesting
public void resetCumulativeLoss() {
- this.cumLoss = 0.f;
+ this.cumLoss = 0.d;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/LearnerBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java
index bb15bb3..fdb22f8 100644
--- a/core/src/main/java/hivemall/LearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java
@@ -255,9 +255,9 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions {
}
}
- protected MixClient configureMixClient(String connectURIs, String label, PredictionModel model) {
- assert (connectURIs != null);
- assert (model != null);
+ @Nonnull
+ protected MixClient configureMixClient(@Nonnull String connectURIs, @Nullable String label,
+ @Nonnull PredictionModel model) {
String jobId = (mixSessionName == null) ? MixClient.DUMMY_JOB_ID : mixSessionName;
if (label != null) {
jobId = jobId + '-' + label;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/UDFWithOptions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/UDFWithOptions.java b/core/src/main/java/hivemall/UDFWithOptions.java
index 3aaf9ef..9908cd9 100644
--- a/core/src/main/java/hivemall/UDFWithOptions.java
+++ b/core/src/main/java/hivemall/UDFWithOptions.java
@@ -77,9 +77,12 @@ public abstract class UDFWithOptions extends GenericUDF {
}
}
+ @Nonnull
protected abstract Options getOptions();
- protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException {
+ @Nonnull
+ protected final CommandLine parseOptions(@Nonnull String optionValue)
+ throws UDFArgumentException {
String[] args = optionValue.split("\\s+");
Options opts = getOptions();
opts.addOption("help", false, "Show function help");
@@ -109,6 +112,7 @@ public abstract class UDFWithOptions extends GenericUDF {
return cl;
}
+ @Nonnull
protected abstract CommandLine processOptions(@Nonnull String optionValue)
throws UDFArgumentException;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/UDTFWithOptions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/UDTFWithOptions.java b/core/src/main/java/hivemall/UDTFWithOptions.java
index 1556a4f..39ab233 100644
--- a/core/src/main/java/hivemall/UDTFWithOptions.java
+++ b/core/src/main/java/hivemall/UDTFWithOptions.java
@@ -123,8 +123,9 @@ public abstract class UDTFWithOptions extends GenericUDTF {
protected abstract CommandLine processOptions(ObjectInspector[] argOIs)
throws UDFArgumentException;
- protected final List<FeatureValue> parseFeatures(final List<?> features,
- final ObjectInspector featureInspector, final boolean parseFeature) {
+ @Nonnull
+ protected final List<FeatureValue> parseFeatures(@Nonnull final List<?> features,
+ @Nonnull final ObjectInspector featureInspector, final boolean parseFeature) {
final int numFeatures = features.size();
if (numFeatures == 0) {
return Collections.emptyList();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
index d25f254..2dcf521 100644
--- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
@@ -167,8 +167,10 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
return featureVector;
}
- protected void checkLabelValue(int label) throws UDFArgumentException {
- assert (label == -1 || label == 0 || label == 1) : label;
+ protected void checkLabelValue(final int label) throws UDFArgumentException {
+ if (label != -1 && label != 0 && label != 1) {
+ throw new UDFArgumentException("Invalid label value for classification: " + label);
+ }
}
@VisibleForTesting
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
index 753a498..d7cb539 100644
--- a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
@@ -51,11 +51,18 @@ public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF {
}
@Override
- protected void checkLossFunction(LossFunction lossFunction) throws UDFArgumentException {};
+ protected void checkLossFunction(LossFunction lossFunction) throws UDFArgumentException {
+ if(!lossFunction.forBinaryClassification()) {
+ throw new UDFArgumentException("The loss function `" + lossFunction.getType()
+ + "` is not designed for binary classification");
+ }
+ }
@Override
- protected void checkTargetValue(float label) throws UDFArgumentException {
- assert (label == -1.f || label == 0.f || label == 1.f) : label;
+ protected void checkTargetValue(final float label) throws UDFArgumentException {
+ if (label != -1 && label != 0 && label != 1) {
+ throw new UDFArgumentException("Invalid label value for classification: " + label);
+ }
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/fm/Feature.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/Feature.java b/core/src/main/java/hivemall/fm/Feature.java
index f2d977e..2966a02 100644
--- a/core/src/main/java/hivemall/fm/Feature.java
+++ b/core/src/main/java/hivemall/fm/Feature.java
@@ -262,7 +262,7 @@ public abstract class Feature {
if (asIntFeature) {
int index = parseFeatureIndex(indexStr);
probe.setFeatureIndex(index);
- probe.value = parseFeatureValue(valueStr);;
+ probe.value = parseFeatureValue(valueStr);
} else {
probe.setFeature(indexStr);
probe.value = parseFeatureValue(valueStr);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/model/NewDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewDenseModel.java b/core/src/main/java/hivemall/model/NewDenseModel.java
index aab3c2b..b5db580 100644
--- a/core/src/main/java/hivemall/model/NewDenseModel.java
+++ b/core/src/main/java/hivemall/model/NewDenseModel.java
@@ -53,7 +53,7 @@ public final class NewDenseModel extends AbstractPredictionModel {
this.weights = new float[size];
if (withCovar) {
float[] covars = new float[size];
- Arrays.fill(covars, 1f);
+ Arrays.fill(covars, 1.f);
this.covars = covars;
} else {
this.covars = null;
@@ -99,8 +99,10 @@ public final class NewDenseModel extends AbstractPredictionModel {
int bits = MathUtils.bitsRequired(index);
int newSize = (1 << bits) + 1;
int oldSize = size;
- logger.info("Expands internal array size from " + oldSize + " to " + newSize + " ("
- + bits + " bits)");
+ if (logger.isInfoEnabled()) {
+ logger.info("Expands internal array size from " + oldSize + " to " + newSize + " ("
+ + bits + " bits)");
+ }
this.size = newSize;
this.weights = Arrays.copyOf(weights, newSize);
if (covars != null) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
index 1848529..0a473b4 100644
--- a/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
+++ b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
@@ -18,7 +18,6 @@
*/
package hivemall.model;
-import hivemall.annotations.InternalAPI;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
@@ -105,12 +104,8 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
return HalfFloat.halfFloatToFloat(covars[i]);
}
- @InternalAPI
- private void _setWeight(final int i, final float v) {
- if (Math.abs(v) >= HalfFloat.MAX_FLOAT) {
- throw new IllegalArgumentException("Acceptable maximum weight is "
- + HalfFloat.MAX_FLOAT + ": " + v);
- }
+ private void setWeight(final int i, final float v) {
+ HalfFloat.checkRange(v);
weights[i] = HalfFloat.floatToHalfFloat(v);
}
@@ -159,7 +154,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
float weight = value.get();
- _setWeight(i, weight);
+ setWeight(i, weight);
float covar = 1.f;
boolean hasCovar = value.hasCovariance();
if (hasCovar) {
@@ -185,7 +180,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
if (i >= size) {
return;
}
- _setWeight(i, 0.f);
+ setWeight(i, 0.f);
if (covars != null) {
setCovar(i, 1.f);
}
@@ -205,7 +200,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
public void setWeight(@Nonnull final Object feature, final float value) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
- _setWeight(i, value);
+ setWeight(i, value);
}
@Override
@@ -221,7 +216,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
protected void _set(@Nonnull final Object feature, final float weight, final short clock) {
int i = ((Integer) feature).intValue();
ensureCapacity(i);
- _setWeight(i, weight);
+ setWeight(i, weight);
clocks[i] = clock;
deltaUpdates[i] = 0;
}
@@ -231,7 +226,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
final short clock) {
int i = ((Integer) feature).intValue();
ensureCapacity(i);
- _setWeight(i, weight);
+ setWeight(i, weight);
setCovar(i, covar);
clocks[i] = clock;
deltaUpdates[i] = 0;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/model/NewSparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java
index e312ae4..8326d22 100644
--- a/core/src/main/java/hivemall/model/NewSparseModel.java
+++ b/core/src/main/java/hivemall/model/NewSparseModel.java
@@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory;
public final class NewSparseModel extends AbstractPredictionModel {
private static final Log logger = LogFactory.getLog(NewSparseModel.class);
+ @Nonnull
private final OpenHashMap<Object, IWeightValue> weights;
private final boolean hasCovar;
private boolean clockEnabled;
@@ -80,9 +81,6 @@ public final class NewSparseModel extends AbstractPredictionModel {
@Override
public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) {
- assert (feature != null);
- assert (value != null);
-
final IWeightValue wrapperValue = wrapIfRequired(value);
if (clockEnabled && value.isTouched()) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/model/SparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java
index ec26552..cb8ab9f 100644
--- a/core/src/main/java/hivemall/model/SparseModel.java
+++ b/core/src/main/java/hivemall/model/SparseModel.java
@@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory;
public final class SparseModel extends AbstractPredictionModel {
private static final Log logger = LogFactory.getLog(SparseModel.class);
+ @Nonnull
private final OpenHashMap<Object, IWeightValue> weights;
private final boolean hasCovar;
private boolean clockEnabled;
@@ -76,9 +77,6 @@ public final class SparseModel extends AbstractPredictionModel {
@Override
public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) {
- assert (feature != null);
- assert (value != null);
-
final IWeightValue wrapperValue = wrapIfRequired(value);
if (clockEnabled && value.isTouched()) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/model/WeightValueWithClock.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/WeightValueWithClock.java b/core/src/main/java/hivemall/model/WeightValueWithClock.java
index 524fa94..679c519 100644
--- a/core/src/main/java/hivemall/model/WeightValueWithClock.java
+++ b/core/src/main/java/hivemall/model/WeightValueWithClock.java
@@ -276,6 +276,7 @@ public class WeightValueWithClock implements IWeightValue {
public void setSumOfGradients(float value) {
this.f2 = value;
}
+
@Override
public float getM() {
return f1;
@@ -324,11 +325,11 @@ public class WeightValueWithClock implements IWeightValue {
@Override
public float getFloatParams(@Nonnegative final int i) {
- if(i == 1) {
+ if (i == 1) {
return f1;
- } else if(i == 2) {
+ } else if (i == 2) {
return f2;
- } else if(i == 3) {
+ } else if (i == 3) {
return f3;
}
throw new IllegalArgumentException("getFloatParams(" + i + ") should not be called");
@@ -363,6 +364,7 @@ public class WeightValueWithClock implements IWeightValue {
public void setSumOfGradients(float value) {
this.f3 = value;
}
+
@Override
public float getM() {
return f1;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
index 2bf030b..e273f91 100644
--- a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
+++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
@@ -20,7 +20,6 @@ package hivemall.optimizer;
import hivemall.model.IWeightValue;
import hivemall.model.WeightValue;
-import hivemall.optimizer.Optimizer.OptimizerBase;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.math.MathUtils;
@@ -34,37 +33,40 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
public final class DenseOptimizerFactory {
- private static final Log logger = LogFactory.getLog(DenseOptimizerFactory.class);
+ private static final Log LOG = LogFactory.getLog(DenseOptimizerFactory.class);
@Nonnull
public static Optimizer create(int ndims, @Nonnull Map<String, String> options) {
final String optimizerName = options.get("optimizer");
- if (optimizerName != null) {
- OptimizerBase optimizerImpl;
- if (optimizerName.toLowerCase().equals("sgd")) {
- optimizerImpl = new Optimizer.SGD(options);
- } else if (optimizerName.toLowerCase().equals("adadelta")) {
- optimizerImpl = new AdaDelta(ndims, options);
- } else if (optimizerName.toLowerCase().equals("adagrad")) {
- optimizerImpl = new AdaGrad(ndims, options);
- } else if (optimizerName.toLowerCase().equals("adam")) {
- optimizerImpl = new Adam(ndims, options);
- } else {
- throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
- }
-
- logger.info("set " + optimizerImpl.getClass().getSimpleName() + " as an optimizer: "
- + options);
+ if (optimizerName == null) {
+ throw new IllegalArgumentException("`optimizer` not defined");
+ }
+ final Optimizer optimizerImpl;
+ if ("sgd".equalsIgnoreCase(optimizerName)) {
+ optimizerImpl = new Optimizer.SGD(options);
+ } else if ("adadelta".equalsIgnoreCase(optimizerName)) {
+ optimizerImpl = new AdaDelta(ndims, options);
+ } else if ("adagrad".equalsIgnoreCase(optimizerName)) {
// If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`.
- if (options.get("regularization") != null
- && options.get("regularization").toLowerCase().equals("rda")) {
- optimizerImpl = new AdagradRDA(ndims, optimizerImpl, options);
+ if ("rda".equalsIgnoreCase(options.get("regularization"))) {
+ AdaGrad adagrad = new AdaGrad(ndims, options);
+ optimizerImpl = new AdagradRDA(ndims, adagrad, options);
+ } else {
+ optimizerImpl = new AdaGrad(ndims, options);
}
+ } else if ("adam".equalsIgnoreCase(optimizerName)) {
+ optimizerImpl = new Adam(ndims, options);
+ } else {
+ throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
+ }
- return optimizerImpl;
+ if (LOG.isInfoEnabled()) {
+ LOG.info("Configured " + optimizerImpl.getOptimizerName() + " as the optimizer: "
+ + options);
}
- throw new IllegalArgumentException("`optimizer` not defined");
+
+ return optimizerImpl;
}
@NotThreadSafe
@@ -86,7 +88,7 @@ public final class DenseOptimizerFactory {
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
weightValueReused.set(weight);
@@ -112,8 +114,9 @@ public final class DenseOptimizerFactory {
@NotThreadSafe
static final class AdaGrad extends Optimizer.AdaGrad {
+ @Nonnull
private final IWeightValue weightValueReused;
-
+ @Nonnull
private float[] sum_of_squared_gradients;
public AdaGrad(int ndims, Map<String, String> options) {
@@ -123,7 +126,7 @@ public final class DenseOptimizerFactory {
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
weightValueReused.set(weight);
@@ -162,7 +165,7 @@ public final class DenseOptimizerFactory {
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
weightValueReused.set(weight);
@@ -194,14 +197,15 @@ public final class DenseOptimizerFactory {
@Nonnull
private float[] sum_of_gradients;
- public AdagradRDA(int ndims, final OptimizerBase optimizerImpl, Map<String, String> options) {
+ public AdagradRDA(int ndims, @Nonnull Optimizer.AdaGrad optimizerImpl,
+ @Nonnull Map<String, String> options) {
super(optimizerImpl, options);
this.weightValueReused = new WeightValue.WeightValueParamsF3(0.f, 0.f, 0.f, 0.f);
this.sum_of_gradients = new float[ndims];
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
weightValueReused.set(weight);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/optimizer/EtaEstimator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/EtaEstimator.java b/core/src/main/java/hivemall/optimizer/EtaEstimator.java
index 17b39d1..1a4c07d 100644
--- a/core/src/main/java/hivemall/optimizer/EtaEstimator.java
+++ b/core/src/main/java/hivemall/optimizer/EtaEstimator.java
@@ -161,14 +161,12 @@ public abstract class EtaEstimator {
@Nonnull
public static EtaEstimator get(@Nonnull final Map<String, String> options)
throws IllegalArgumentException {
+ final float eta0 = Primitives.parseFloat(options.get("eta0"), 0.1f);
+ final double power_t = Primitives.parseDouble(options.get("power_t"), 0.1d);
+
final String etaScheme = options.get("eta");
if (etaScheme == null) {
- return new InvscalingEtaEstimator(0.1f, 0.1d);
- }
-
- float eta0 = 0.1f;
- if (options.containsKey("eta0")) {
- eta0 = Float.parseFloat(options.get("eta0"));
+ return new InvscalingEtaEstimator(eta0, power_t);
}
if ("fixed".equalsIgnoreCase(etaScheme)) {
@@ -183,10 +181,6 @@ public abstract class EtaEstimator {
}
return new SimpleEtaEstimator(eta0, t);
} else if ("inv".equalsIgnoreCase(etaScheme) || "inverse".equalsIgnoreCase(etaScheme)) {
- double power_t = 0.1;
- if (options.containsKey("power_t")) {
- power_t = Double.parseDouble(options.get("power_t"));
- }
return new InvscalingEtaEstimator(eta0, power_t);
} else {
throw new IllegalArgumentException("Unsupported ETA name: " + etaScheme);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/optimizer/LossFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java
index 0dff4aa..a1ade3d 100644
--- a/core/src/main/java/hivemall/optimizer/LossFunctions.java
+++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java
@@ -20,6 +20,9 @@ package hivemall.optimizer;
import hivemall.utils.math.MathUtils;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
/**
* @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions
*/
@@ -30,7 +33,8 @@ public final class LossFunctions {
SquaredHingeLoss, ModifiedHuberLoss
}
- public static LossFunction getLossFunction(String type) {
+ @Nonnull
+ public static LossFunction getLossFunction(@Nullable final String type) {
if ("SquaredLoss".equalsIgnoreCase(type)) {
return new SquaredLoss();
} else if ("QuantileLoss".equalsIgnoreCase(type)) {
@@ -41,7 +45,7 @@ public final class LossFunctions {
return new HuberLoss();
} else if ("HingeLoss".equalsIgnoreCase(type)) {
return new HingeLoss();
- } else if ("LogLoss".equalsIgnoreCase(type)) {
+ } else if ("LogLoss".equalsIgnoreCase(type) || "LogisticLoss".equalsIgnoreCase(type)) {
return new LogLoss();
} else if ("SquaredHingeLoss".equalsIgnoreCase(type)) {
return new SquaredHingeLoss();
@@ -51,7 +55,8 @@ public final class LossFunctions {
throw new IllegalArgumentException("Unsupported loss function name: " + type);
}
- public static LossFunction getLossFunction(LossType type) {
+ @Nonnull
+ public static LossFunction getLossFunction(@Nonnull final LossType type) {
switch (type) {
case SquaredLoss:
return new SquaredLoss();
@@ -100,6 +105,7 @@ public final class LossFunctions {
public boolean forRegression();
+ @Nonnull
public LossType getType();
}
@@ -119,13 +125,13 @@ public final class LossFunctions {
public static abstract class BinaryLoss implements LossFunction {
- protected static void checkTarget(float y) {
+ protected static void checkTarget(final float y) {
if (!(y == 1.f || y == -1.f)) {
throw new IllegalArgumentException("target must be [+1,-1]: " + y);
}
}
- protected static void checkTarget(double y) {
+ protected static void checkTarget(final double y) {
if (!(y == 1.d || y == -1.d)) {
throw new IllegalArgumentException("target must be [+1,-1]: " + y);
}
@@ -150,19 +156,19 @@ public final class LossFunctions {
public static final class SquaredLoss extends RegressionLoss {
@Override
- public float loss(float p, float y) {
+ public float loss(final float p, final float y) {
final float z = p - y;
return z * z * 0.5f;
}
@Override
- public double loss(double p, double y) {
+ public double loss(final double p, final double y) {
final double z = p - y;
return z * z * 0.5d;
}
@Override
- public float dloss(float p, float y) {
+ public float dloss(final float p, final float y) {
return p - y; // 2 (p - y) / 2
}
@@ -197,7 +203,7 @@ public final class LossFunctions {
}
@Override
- public float loss(float p, float y) {
+ public float loss(final float p, final float y) {
float e = y - p;
if (e > 0.f) {
return tau * e;
@@ -207,7 +213,7 @@ public final class LossFunctions {
}
@Override
- public double loss(double p, double y) {
+ public double loss(final double p, final double y) {
double e = y - p;
if (e > 0.d) {
return tau * e;
@@ -217,7 +223,7 @@ public final class LossFunctions {
}
@Override
- public float dloss(float p, float y) {
+ public float dloss(final float p, final float y) {
float e = y - p;
if (e == 0.f) {
return 0.f;
@@ -251,19 +257,19 @@ public final class LossFunctions {
}
@Override
- public float loss(float p, float y) {
+ public float loss(final float p, final float y) {
float loss = Math.abs(y - p) - epsilon;
return (loss > 0.f) ? loss : 0.f;
}
@Override
- public double loss(double p, double y) {
+ public double loss(final double p, final double y) {
double loss = Math.abs(y - p) - epsilon;
return (loss > 0.d) ? loss : 0.d;
}
@Override
- public float dloss(float p, float y) {
+ public float dloss(final float p, final float y) {
if ((y - p) > epsilon) {// real value > predicted value - epsilon
return -1.f;
}
@@ -303,7 +309,7 @@ public final class LossFunctions {
}
@Override
- public float loss(float p, float y) {
+ public float loss(final float p, final float y) {
final float r = p - y;
final float rAbs = Math.abs(r);
if (rAbs <= c) {
@@ -313,7 +319,7 @@ public final class LossFunctions {
}
@Override
- public double loss(double p, double y) {
+ public double loss(final double p, final double y) {
final double r = p - y;
final double rAbs = Math.abs(r);
if (rAbs <= c) {
@@ -323,7 +329,7 @@ public final class LossFunctions {
}
@Override
- public float dloss(float p, float y) {
+ public float dloss(final float p, final float y) {
final float r = p - y;
final float rAbs = Math.abs(r);
if (rAbs <= c) {
@@ -364,19 +370,19 @@ public final class LossFunctions {
}
@Override
- public float loss(float p, float y) {
+ public float loss(final float p, final float y) {
float loss = hingeLoss(p, y, threshold);
return (loss > 0.f) ? loss : 0.f;
}
@Override
- public double loss(double p, double y) {
+ public double loss(final double p, final double y) {
double loss = hingeLoss(p, y, threshold);
return (loss > 0.d) ? loss : 0.d;
}
@Override
- public float dloss(float p, float y) {
+ public float dloss(final float p, final float y) {
float loss = hingeLoss(p, y, threshold);
return (loss > 0.f) ? -y : 0.f;
}
@@ -396,7 +402,7 @@ public final class LossFunctions {
* <code>logloss(p,y) = log(1+exp(-p*y))</code>
*/
@Override
- public float loss(float p, float y) {
+ public float loss(final float p, final float y) {
checkTarget(y);
final float z = y * p;
@@ -410,7 +416,7 @@ public final class LossFunctions {
}
@Override
- public double loss(double p, double y) {
+ public double loss(final double p, final double y) {
checkTarget(y);
final double z = y * p;
@@ -424,7 +430,7 @@ public final class LossFunctions {
}
@Override
- public float dloss(float p, float y) {
+ public float dloss(final float p, final float y) {
checkTarget(y);
float z = y * p;
@@ -449,17 +455,17 @@ public final class LossFunctions {
public static final class SquaredHingeLoss extends BinaryLoss {
@Override
- public float loss(float p, float y) {
+ public float loss(final float p, final float y) {
return squaredHingeLoss(p, y);
}
@Override
- public double loss(double p, double y) {
+ public double loss(final double p, final double y) {
return squaredHingeLoss(p, y);
}
@Override
- public float dloss(float p, float y) {
+ public float dloss(final float p, final float y) {
checkTarget(y);
float d = 1 - (y * p);
@@ -480,7 +486,7 @@ public final class LossFunctions {
public static final class ModifiedHuberLoss extends BinaryLoss {
@Override
- public float loss(float p, float y) {
+ public float loss(final float p, final float y) {
final float z = p * y;
if (z >= 1.f) {
return 0.f;
@@ -491,7 +497,7 @@ public final class LossFunctions {
}
@Override
- public double loss(double p, double y) {
+ public double loss(final double p, final double y) {
final double z = p * y;
if (z >= 1.d) {
return 0.d;
@@ -502,7 +508,7 @@ public final class LossFunctions {
}
@Override
- public float dloss(float p, float y) {
+ public float dloss(final float p, final float y) {
final float z = p * y;
if (z >= 1.f) {
return 0.f;
@@ -552,12 +558,12 @@ public final class LossFunctions {
return Math.log(1.d + Math.exp(-z));
}
- public static float squaredLoss(float p, float y) {
+ public static float squaredLoss(final float p, final float y) {
final float z = p - y;
return z * z * 0.5f;
}
- public static double squaredLoss(double p, double y) {
+ public static double squaredLoss(final double p, final double y) {
final double z = p - y;
return z * z * 0.5d;
}
@@ -576,11 +582,11 @@ public final class LossFunctions {
return threshold - z;
}
- public static float hingeLoss(float p, float y) {
+ public static float hingeLoss(final float p, final float y) {
return hingeLoss(p, y, 1.f);
}
- public static double hingeLoss(double p, double y) {
+ public static double hingeLoss(final double p, final double y) {
return hingeLoss(p, y, 1.d);
}
@@ -603,7 +609,8 @@ public final class LossFunctions {
/**
* Math.abs(target - predicted) - epsilon
*/
- public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) {
+ public static float epsilonInsensitiveLoss(final float predicted, final float target,
+ final float epsilon) {
return Math.abs(target - predicted) - epsilon;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/optimizer/Optimizer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java
index ad70e61..c6f9877 100644
--- a/core/src/main/java/hivemall/optimizer/Optimizer.java
+++ b/core/src/main/java/hivemall/optimizer/Optimizer.java
@@ -20,6 +20,7 @@ package hivemall.optimizer;
import hivemall.model.IWeightValue;
import hivemall.model.WeightValue;
+import hivemall.utils.lang.Primitives;
import java.util.Map;
@@ -39,6 +40,9 @@ public interface Optimizer {
*/
void proceedStep();
+ @Nonnull
+ String getOptimizerName();
+
@NotThreadSafe
static abstract class OptimizerBase implements Optimizer {
@@ -49,7 +53,7 @@ public interface Optimizer {
@Nonnegative
protected int _numStep = 1;
- public OptimizerBase(final Map<String, String> options) {
+ public OptimizerBase(@Nonnull Map<String, String> options) {
this._eta = EtaEstimator.get(options);
this._reg = Regularization.get(options);
}
@@ -61,11 +65,14 @@ public interface Optimizer {
/**
* Update the given weight by the given gradient.
+ *
+ * @return new weight to be set
*/
- protected float update(@Nonnull final IWeightValue weight, float gradient) {
- float g = _reg.regularize(weight.get(), gradient);
+ protected float update(@Nonnull final IWeightValue weight, final float gradient) {
+ float oldWeight = weight.get();
+ float g = _reg.regularize(oldWeight, gradient);
float delta = computeDelta(weight, g);
- float newWeight = weight.get() - _eta.eta(_numStep) * delta;
+ float newWeight = oldWeight - _eta.eta(_numStep) * delta;
weight.set(newWeight);
return newWeight;
}
@@ -73,7 +80,7 @@ public interface Optimizer {
/**
* Compute a delta to update
*/
- protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) {
+ protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) {
return gradient;
}
@@ -89,12 +96,17 @@ public interface Optimizer {
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
weightValueReused.set(weight);
update(weightValueReused, gradient);
return weightValueReused.get();
}
+ @Override
+ public String getOptimizerName() {
+ return "sgd";
+ }
+
}
static abstract class AdaGrad extends OptimizerBase {
@@ -104,26 +116,23 @@ public interface Optimizer {
public AdaGrad(Map<String, String> options) {
super(options);
- float eps = 1.0f;
- float scale = 100.0f;
- if (options.containsKey("eps")) {
- eps = Float.parseFloat(options.get("eps"));
- }
- if (options.containsKey("scale")) {
- scale = Float.parseFloat(options.get("scale"));
- }
- this.eps = eps;
- this.scale = scale;
+ this.eps = Primitives.parseFloat(options.get("eps"), 1.0f);
+ this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
}
@Override
- protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) {
+ protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) {
float new_scaled_sum_sqgrad = weight.getSumOfSquaredGradients() + gradient
* (gradient / scale);
weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
return gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps);
}
+ @Override
+ public String getOptimizerName() {
+ return "adagrad";
+ }
+
}
static abstract class AdaDelta extends OptimizerBase {
@@ -134,25 +143,13 @@ public interface Optimizer {
public AdaDelta(Map<String, String> options) {
super(options);
- float decay = 0.95f;
- float eps = 1e-6f;
- float scale = 100.0f;
- if (options.containsKey("decay")) {
- decay = Float.parseFloat(options.get("decay"));
- }
- if (options.containsKey("eps")) {
- eps = Float.parseFloat(options.get("eps"));
- }
- if (options.containsKey("scale")) {
- scale = Float.parseFloat(options.get("scale"));
- }
- this.decay = decay;
- this.eps = eps;
- this.scale = scale;
+ this.decay = Primitives.parseFloat(options.get("decay"), 0.95f);
+ this.eps = Primitives.parseFloat(options.get("eps"), 1e-6f);
+ this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
}
@Override
- protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) {
+ protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) {
float old_scaled_sum_sqgrad = weight.getSumOfSquaredGradients();
float old_sum_squared_delta_x = weight.getSumOfSquaredDeltaX();
float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad)
@@ -167,6 +164,11 @@ public interface Optimizer {
return delta;
}
+ @Override
+ public String getOptimizerName() {
+ return "adadelta";
+ }
+
}
/**
@@ -183,25 +185,13 @@ public interface Optimizer {
public Adam(Map<String, String> options) {
super(options);
- float beta = 0.9f;
- float gamma = 0.999f;
- float eps_hat = 1e-8f;
- if (options.containsKey("beta")) {
- beta = Float.parseFloat(options.get("beta"));
- }
- if (options.containsKey("gamma")) {
- gamma = Float.parseFloat(options.get("gamma"));
- }
- if (options.containsKey("eps_hat")) {
- eps_hat = Float.parseFloat(options.get("eps_hat"));
- }
- this.beta = beta;
- this.gamma = gamma;
- this.eps_hat = eps_hat;
+ this.beta = Primitives.parseFloat(options.get("beta"), 0.9f);
+ this.gamma = Primitives.parseFloat(options.get("gamma"), 0.999f);
+ this.eps_hat = Primitives.parseFloat(options.get("eps_hat"), 1e-8f);
}
@Override
- protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) {
+ protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) {
float val_m = beta * weight.getM() + (1.f - beta) * gradient;
float val_v = gamma * weight.getV() + (float) ((1.f - gamma) * Math.pow(gradient, 2.0));
float val_m_hat = val_m / (float) (1.f - Math.pow(beta, _numStep));
@@ -212,36 +202,32 @@ public interface Optimizer {
return delta;
}
+ @Override
+ public String getOptimizerName() {
+ return "adam";
+ }
+
}
static abstract class AdagradRDA extends OptimizerBase {
- private final OptimizerBase optimizerImpl;
-
+ @Nonnull
+ private final AdaGrad optimizerImpl;
private final float lambda;
- public AdagradRDA(final OptimizerBase optimizerImpl, Map<String, String> options) {
+ public AdagradRDA(@Nonnull AdaGrad optimizerImpl, @Nonnull Map<String, String> options) {
super(options);
- // We assume `optimizerImpl` has the `AdaGrad` implementation only
- if (!(optimizerImpl instanceof AdaGrad)) {
- throw new IllegalArgumentException(optimizerImpl.getClass().getSimpleName()
- + " currently does not support RDA regularization");
- }
- float lambda = 1e-6f;
- if (options.containsKey("lambda")) {
- lambda = Float.parseFloat(options.get("lambda"));
- }
this.optimizerImpl = optimizerImpl;
- this.lambda = lambda;
+ this.lambda = Primitives.parseFloat(options.get("lambda"), 1e-6f);
}
@Override
- protected float update(@Nonnull final IWeightValue weight, float gradient) {
- float new_sum_grad = weight.getSumOfGradients() + gradient;
+ protected float update(@Nonnull final IWeightValue weight, final float gradient) {
+ final float new_sum_grad = weight.getSumOfGradients() + gradient;
// sign(u_{t,i})
- float sign = (new_sum_grad > 0.f) ? 1.f : -1.f;
+ final float sign = (new_sum_grad > 0.f) ? 1.f : -1.f;
// |u_{t,i}|/t - \lambda
- float meansOfGradients = (sign * new_sum_grad / _numStep) - lambda;
+ final float meansOfGradients = (sign * new_sum_grad / _numStep) - lambda;
if (meansOfGradients < 0.f) {
// x_{t,i} = 0
weight.set(0.f);
@@ -258,6 +244,11 @@ public interface Optimizer {
}
}
+ @Override
+ public String getOptimizerName() {
+ return "adagrad_rda";
+ }
+
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/OptimizerOptions.java b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java
index 19fecb1..be65609 100644
--- a/core/src/main/java/hivemall/optimizer/OptimizerOptions.java
+++ b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java
@@ -25,8 +25,8 @@ import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
public final class OptimizerOptions {
@@ -63,14 +63,15 @@ public final class OptimizerOptions {
public static void propcessOptions(@Nullable CommandLine cl,
@Nonnull Map<String, String> options) {
- if (cl != null) {
- for (Option opt : cl.getOptions()) {
- String optName = opt.getLongOpt();
- if (optName == null) {
- optName = opt.getOpt();
- }
- options.put(optName, opt.getValue());
+ if (cl == null) {
+ return;
+ }
+ for (Option opt : cl.getOptions()) {
+ String optName = opt.getLongOpt();
+ if (optName == null) {
+ optName = opt.getOpt();
}
+ options.put(optName, opt.getValue());
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/optimizer/Regularization.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Regularization.java b/core/src/main/java/hivemall/optimizer/Regularization.java
index 4939f60..b85ab6f 100644
--- a/core/src/main/java/hivemall/optimizer/Regularization.java
+++ b/core/src/main/java/hivemall/optimizer/Regularization.java
@@ -18,9 +18,10 @@
*/
package hivemall.optimizer;
-import javax.annotation.Nonnull;
import java.util.Map;
+import javax.annotation.Nonnull;
+
public abstract class Regularization {
/** the default regularization term 0.0001 */
public static final float DEFAULT_LAMBDA = 0.0001f;
@@ -120,15 +121,15 @@ public abstract class Regularization {
return new PassThrough(options);
}
- if (regName.toLowerCase().equals("no")) {
+ if ("no".equalsIgnoreCase(regName)) {
return new PassThrough(options);
- } else if (regName.toLowerCase().equals("l1")) {
+ } else if ("l1".equalsIgnoreCase(regName)) {
return new L1(options);
- } else if (regName.toLowerCase().equals("l2")) {
+ } else if ("l2".equalsIgnoreCase(regName)) {
return new L2(options);
- } else if (regName.toLowerCase().equals("elasticnet")) {
+ } else if ("elasticnet".equalsIgnoreCase(regName)) {
return new ElasticNet(options);
- } else if (regName.toLowerCase().equals("rda")) {
+ } else if ("rda".equalsIgnoreCase(regName)) {
// Return `PassThrough` because we need special handling for RDA.
// See an implementation of `Optimizer#RDA`.
return new PassThrough(options);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
index 4a003f3..7bcac1b 100644
--- a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
+++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
@@ -20,7 +20,6 @@ package hivemall.optimizer;
import hivemall.model.IWeightValue;
import hivemall.model.WeightValue;
-import hivemall.optimizer.Optimizer.OptimizerBase;
import hivemall.utils.collections.maps.OpenHashMap;
import java.util.Map;
@@ -37,34 +36,35 @@ public final class SparseOptimizerFactory {
@Nonnull
public static Optimizer create(int ndims, @Nonnull Map<String, String> options) {
final String optimizerName = options.get("optimizer");
- if (optimizerName != null) {
- OptimizerBase optimizerImpl;
- if (optimizerName.toLowerCase().equals("sgd")) {
- optimizerImpl = new Optimizer.SGD(options);
- } else if (optimizerName.toLowerCase().equals("adadelta")) {
- optimizerImpl = new AdaDelta(ndims, options);
- } else if (optimizerName.toLowerCase().equals("adagrad")) {
- optimizerImpl = new AdaGrad(ndims, options);
- } else if (optimizerName.toLowerCase().equals("adam")) {
- optimizerImpl = new Adam(ndims, options);
- } else {
- throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
- }
+ if (optimizerName == null) {
+ throw new IllegalArgumentException("`optimizer` not defined");
+ }
+ final Optimizer optimizerImpl;
+ if ("sgd".equalsIgnoreCase(optimizerName)) {
+ optimizerImpl = new Optimizer.SGD(options);
+ } else if ("adadelta".equalsIgnoreCase(optimizerName)) {
+ optimizerImpl = new AdaDelta(ndims, options);
+ } else if ("adagrad".equalsIgnoreCase(optimizerName)) {
// If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`.
- if (options.get("regularization") != null
- && options.get("regularization").toLowerCase().equals("rda")) {
- optimizerImpl = new AdagradRDA(ndims, optimizerImpl, options);
- }
-
- if (LOG.isInfoEnabled()) {
- LOG.info("set " + optimizerImpl.getClass().getSimpleName() + " as an optimizer: "
- + options);
+ if ("rda".equalsIgnoreCase(options.get("regularization"))) {
+ AdaGrad adagrad = new AdaGrad(ndims, options);
+ optimizerImpl = new AdagradRDA(ndims, adagrad, options);
+ } else {
+ optimizerImpl = new AdaGrad(ndims, options);
}
+ } else if ("adam".equalsIgnoreCase(optimizerName)) {
+ optimizerImpl = new Adam(ndims, options);
+ } else {
+ throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
+ }
- return optimizerImpl;
+ if (LOG.isInfoEnabled()) {
+ LOG.info("Configured " + optimizerImpl.getOptimizerName() + " as the optimizer: "
+ + options);
}
- throw new IllegalArgumentException("`optimizer` not defined");
+
+ return optimizerImpl;
}
@NotThreadSafe
@@ -79,17 +79,15 @@ public final class SparseOptimizerFactory {
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
- IWeightValue auxWeight;
- if (auxWeights.containsKey(feature)) {
- auxWeight = auxWeights.get(feature);
- auxWeight.set(weight);
- } else {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
+ IWeightValue auxWeight = auxWeights.get(feature);
+ if (auxWeight == null) {
auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
auxWeights.put(feature, auxWeight);
+ } else {
+ auxWeight.set(weight);
}
- update(auxWeight, gradient);
- return auxWeight.get();
+ return update(auxWeight, gradient);
}
}
@@ -106,17 +104,15 @@ public final class SparseOptimizerFactory {
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
- IWeightValue auxWeight;
- if (auxWeights.containsKey(feature)) {
- auxWeight = auxWeights.get(feature);
- auxWeight.set(weight);
- } else {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
+ IWeightValue auxWeight = auxWeights.get(feature);
+ if (auxWeight == null) {
auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
auxWeights.put(feature, auxWeight);
+ } else {
+ auxWeight.set(weight);
}
- update(auxWeight, gradient);
- return auxWeight.get();
+ return update(auxWeight, gradient);
}
}
@@ -133,17 +129,15 @@ public final class SparseOptimizerFactory {
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
- IWeightValue auxWeight;
- if (auxWeights.containsKey(feature)) {
- auxWeight = auxWeights.get(feature);
- auxWeight.set(weight);
- } else {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
+ IWeightValue auxWeight = auxWeights.get(feature);
+ if (auxWeight == null) {
auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
auxWeights.put(feature, auxWeight);
+ } else {
+ auxWeight.set(weight);
}
- update(auxWeight, gradient);
- return auxWeight.get();
+ return update(auxWeight, gradient);
}
}
@@ -154,23 +148,22 @@ public final class SparseOptimizerFactory {
@Nonnull
private final OpenHashMap<Object, IWeightValue> auxWeights;
- public AdagradRDA(int size, OptimizerBase optimizerImpl, Map<String, String> options) {
+ public AdagradRDA(int size, @Nonnull Optimizer.AdaGrad optimizerImpl,
+ @Nonnull Map<String, String> options) {
super(optimizerImpl, options);
this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
}
@Override
- public float update(@Nonnull Object feature, float weight, float gradient) {
- IWeightValue auxWeight;
- if (auxWeights.containsKey(feature)) {
- auxWeight = auxWeights.get(feature);
- auxWeight.set(weight);
- } else {
+ public float update(@Nonnull final Object feature, final float weight, final float gradient) {
+ IWeightValue auxWeight = auxWeights.get(feature);
+ if (auxWeight == null) {
auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
auxWeights.put(feature, auxWeight);
+ } else {
+ auxWeight.set(weight);
}
- update(auxWeight, gradient);
- return auxWeight.get();
+ return update(auxWeight, gradient);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
index 5137dd3..160d92d 100644
--- a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
@@ -50,8 +50,9 @@ public final class GeneralRegressionUDTF extends GeneralLearnerBaseUDTF {
}
@Override
- protected void checkLossFunction(@Nonnull LossFunction lossFunction) throws UDFArgumentException {
- if (lossFunction.forBinaryClassification()) {
+ protected void checkLossFunction(@Nonnull LossFunction lossFunction)
+ throws UDFArgumentException {
+ if (!lossFunction.forRegression()) {
throw new UDFArgumentException("The loss function `" + lossFunction.getType()
+ "` is not designed for regression");
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
index d7ae8d6..5ce34a4 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java
@@ -354,9 +354,9 @@ public class IntOpenHashMap<V> implements Externalizable {
return key & 0x7fffffff;
}
- protected void recordAccess(int idx) {};
+ protected void recordAccess(int idx) {}
- protected void recordRemoval(int idx) {};
+ protected void recordRemoval(int idx) {}
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(_threshold);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
index e558e67..6ed783c 100644
--- a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
@@ -38,7 +38,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
import org.junit.Assert;
import org.junit.Test;
@@ -107,8 +106,8 @@ public class GeneralClassifierUDTFTest {
udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});
- float cumLossPrev = Float.MAX_VALUE;
- float cumLoss = 0.f;
+ double cumLossPrev = Double.MAX_VALUE;
+ double cumLoss = 0.d;
int it = 0;
while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) {
cumLossPrev = cumLoss;
@@ -119,7 +118,7 @@ public class GeneralClassifierUDTFTest {
cumLoss = udtf.getCumulativeLoss();
println("Iter: " + ++it + ", Cumulative loss: " + cumLoss);
}
- Assert.assertTrue(cumLoss / samplesList.size() < 0.5f);
+ Assert.assertTrue(cumLoss / samplesList.size() < 0.5d);
int numTests = 0;
int numCorrect = 0;
@@ -176,6 +175,7 @@ public class GeneralClassifierUDTFTest {
}
}
+ @SuppressWarnings("unchecked")
@Test
public void testNews20() throws IOException, ParseException, HiveException {
int nIter = 10;
@@ -205,7 +205,7 @@ public class GeneralClassifierUDTFTest {
udtf.process(new Object[] {words, label});
labels.add(label);
- wordsList.add((ArrayList) words.clone());
+ wordsList.add((ArrayList<String>) words.clone());
words.clear();
line = news20.readLine();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5e27993b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java
index 15dcc22..df5c643 100644
--- a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java
+++ b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java
@@ -121,8 +121,8 @@ public class GeneralRegressionUDTFTest {
udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params});
- float cumLossPrev = Float.MAX_VALUE;
- float cumLoss = 0.f;
+ double cumLossPrev = Double.MAX_VALUE;
+ double cumLoss = 0.d;
int it = 0;
while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) {
cumLossPrev = cumLoss;
@@ -133,7 +133,7 @@ public class GeneralRegressionUDTFTest {
cumLoss = udtf.getCumulativeLoss();
println("Iter: " + ++it + ", Cumulative loss: " + cumLoss);
}
- Assert.assertTrue(cumLoss / numTrain < 0.1f);
+ Assert.assertTrue(cumLoss / numTrain < 0.1d);
float accum = 0.f;