You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2018/08/24 09:44:45 UTC
incubator-hivemall git commit: [HIVEMALL-211][BUGFIX] Fixed Optimizer
for regularization updates
Repository: incubator-hivemall
Updated Branches:
refs/heads/master 61711fbc2 -> f8beee36b
[HIVEMALL-211][BUGFIX] Fixed Optimizer for regularization updates
## What changes were proposed in this pull request?
This PR fixes a bug in the regularization scheme of the Optimizer.
## What type of PR is it?
Bug Fix
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-211
## How was this patch tested?
unit tests, manual tests on EMR
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #156 from myui/HIVEMALL-211.
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f8beee36
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f8beee36
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f8beee36
Branch: refs/heads/master
Commit: f8beee36b36274d8eb8948f6838aacd955817eba
Parents: 61711fb
Author: Makoto Yui <my...@apache.org>
Authored: Fri Aug 24 18:44:40 2018 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Fri Aug 24 18:44:40 2018 +0900
----------------------------------------------------------------------
.../hivemall/optimizer/DenseOptimizerFactory.java | 2 +-
.../main/java/hivemall/optimizer/LossFunctions.java | 3 +++
.../src/main/java/hivemall/optimizer/Optimizer.java | 16 ++++++++--------
3 files changed, 12 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
index b1fe917..5985868 100644
--- a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
+++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
@@ -48,7 +48,7 @@ public final class DenseOptimizerFactory {
&& "adagrad".equalsIgnoreCase(optimizerName) == false) {
throw new IllegalArgumentException(
"`-regularization rda` is only supported for AdaGrad but `-optimizer "
- + optimizerName);
+ + optimizerName + "`. Please specify `-regularization l1` and so on.");
}
final Optimizer optimizerImpl;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/LossFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java
index c4705c0..f76eb0e 100644
--- a/core/src/main/java/hivemall/optimizer/LossFunctions.java
+++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java
@@ -584,6 +584,9 @@ public final class LossFunctions {
}
}
+ /**
+ * logistic loss function where target is 0 (negative) or 1 (positive).
+ */
public static float logisticLoss(final float target, final float predicted) {
if (predicted > -100.d) {
return target - (float) MathUtils.sigmoid(predicted);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f8beee36/core/src/main/java/hivemall/optimizer/Optimizer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java
index 4b1ef0a..0cbac42 100644
--- a/core/src/main/java/hivemall/optimizer/Optimizer.java
+++ b/core/src/main/java/hivemall/optimizer/Optimizer.java
@@ -70,9 +70,8 @@ public interface Optimizer {
*/
protected float update(@Nonnull final IWeightValue weight, final float gradient) {
float oldWeight = weight.get();
- float g = _reg.regularize(oldWeight, gradient);
- float delta = computeDelta(weight, g);
- float newWeight = oldWeight - _eta.eta(_numStep) * delta;
+ float delta = computeDelta(weight, gradient);
+ float newWeight = oldWeight - _eta.eta(_numStep) * _reg.regularize(oldWeight, delta);
weight.set(newWeight);
return newWeight;
}
@@ -123,10 +122,10 @@ public interface Optimizer {
@Override
protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) {
- float new_scaled_sum_sqgrad =
- weight.getSumOfSquaredGradients() + gradient * (gradient / scale);
- weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
- return gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps);
+ float old_scaled_gg = weight.getSumOfSquaredGradients();
+ float new_scaled_gg = old_scaled_gg + gradient * (gradient / scale);
+ weight.setSumOfSquaredGradients(new_scaled_gg);
+ return (float) (gradient / Math.sqrt(eps + ((double) old_scaled_gg) * scale));
}
@Override
@@ -156,7 +155,8 @@ public interface Optimizer {
float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad)
+ ((1.f - decay) * gradient * (gradient / scale));
float delta = (float) Math.sqrt(
- (old_sum_squared_delta_x + eps) / (new_scaled_sum_sqgrad * scale + eps)) * gradient;
+ (old_sum_squared_delta_x + eps) / ((double) new_scaled_sum_sqgrad * scale + eps))
+ * gradient;
float new_sum_squared_delta_x =
(decay * old_sum_squared_delta_x) + ((1.f - decay) * delta * delta);
weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);