You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2018/08/28 15:42:50 UTC
incubator-hivemall git commit: [HIVEMALL-212] Fix
Classifier/Regressor not to forward zero weighted values
Repository: incubator-hivemall
Updated Branches:
refs/heads/master f8beee36b -> 4ca1c19c7
[HIVEMALL-212] Fix Classifier/Regressor not to forward zero weighted values
## What changes were proposed in this pull request?
Features whose weight is 0.0 do not need to be saved in the prediction model, and omitting them reduces the size of the model. This PR therefore changes the Classifier/Regressor so that zero-weighted values are not forwarded.
## What type of PR is it?
Improvement
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-212
## How was this patch tested?
unit tests and manual tests
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #157 from myui/HIVEMALL-212.
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4ca1c19c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4ca1c19c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4ca1c19c
Branch: refs/heads/master
Commit: 4ca1c19c7e00120cb0abd42d6c1f1b48176846e8
Parents: f8beee3
Author: Makoto Yui <my...@apache.org>
Authored: Wed Aug 29 00:42:45 2018 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Wed Aug 29 00:42:45 2018 +0900
----------------------------------------------------------------------
.../java/hivemall/GeneralLearnerBaseUDTF.java | 33 ++++++++++++++++----
.../hivemall/ensemble/ArgminKLDistanceUDAF.java | 14 ++++++---
.../optimizer/SparseOptimizerFactory.java | 6 +++-
3 files changed, 42 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
index 5c3967b..0198e77 100644
--- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
@@ -452,9 +452,13 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
cvState.incrLoss(loss); // retain cumulative loss to check convergence
final float dloss = lossFunction.dloss(predicted, target);
+ if (dloss == 0.f) {
+ optimizer.proceedStep();
+ return;
+ }
+
if (is_mini_batch) {
accumulateUpdate(features, dloss);
-
if (sampled >= mini_batch_size) {
batchUpdate();
}
@@ -494,7 +498,11 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
for (Map.Entry<Object, FloatAccumulator> e : accumulated.entrySet()) {
Object feature = e.getKey();
FloatAccumulator v = e.getValue();
- float new_weight = v.get(); // w_i - (eta / M) * (delta_1 + delta_2 + ... + delta_M)
+ final float new_weight = v.get(); // w_i - (eta / M) * (delta_1 + delta_2 + ... + delta_M)
+ if (new_weight == 0.f) {
+ model.delete(feature);
+ continue;
+ }
model.setWeight(feature, new_weight);
}
@@ -507,7 +515,11 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
Object feature = f.getFeature();
float xi = f.getValueAsFloat();
float weight = model.getWeight(feature);
- float new_weight = optimizer.update(feature, weight, dloss * xi);
+ final float new_weight = optimizer.update(feature, weight, dloss * xi);
+ if (new_weight == 0.f) {
+ model.delete(feature);
+ continue;
+ }
model.setWeight(feature, new_weight);
}
}
@@ -701,9 +713,14 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
if (!probe.isTouched()) {
continue; // skip outputting untouched weights
}
+ final float v = probe.get();
+ final float cv = probe.getCovariance();
+ if (v == 0.f && cv == 0.f) {
+ continue;
+ }
+ fv.set(v);
+ cov.set(cv);
Object k = itor.getKey();
- fv.set(probe.get());
- cov.set(probe.getCovariance());
forwardMapObj[0] = k;
forwardMapObj[1] = fv;
forwardMapObj[2] = cov;
@@ -720,8 +737,12 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
if (!probe.isTouched()) {
continue; // skip outputting untouched weights
}
+ final float v = probe.get();
+ if (v == 0.f) {
+ continue;
+ }
+ fv.set(v);
Object k = itor.getKey();
- fv.set(probe.get());
forwardMapObj[0] = k;
forwardMapObj[1] = fv;
forward(forwardMapObj);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java b/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java
index 136ca0d..774db6f 100644
--- a/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java
+++ b/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java
@@ -38,8 +38,8 @@ public final class ArgminKLDistanceUDAF extends UDAF {
float sum_inv_covar;
PartialResult() {
- this.sum_mean_div_covar = 0f;
- this.sum_inv_covar = 0f;
+ this.sum_mean_div_covar = 0.f;
+ this.sum_inv_covar = 0.f;
}
}
@@ -54,7 +54,10 @@ public final class ArgminKLDistanceUDAF extends UDAF {
if (partial == null) {
this.partial = new PartialResult();
}
- float covar_f = covar.get();
+ final float covar_f = covar.get();
+ if (covar_f == 0.f) {// avoid null division
+ return true;
+ }
partial.sum_mean_div_covar += (mean.get() / covar_f);
partial.sum_inv_covar += (1.f / covar_f);
return true;
@@ -80,7 +83,10 @@ public final class ArgminKLDistanceUDAF extends UDAF {
if (partial == null) {
return null;
}
- float mean = (1f / partial.sum_inv_covar) * partial.sum_mean_div_covar;
+ if (partial.sum_inv_covar == 0.f) {// avoid null division
+ return new FloatWritable(0.f);
+ }
+ float mean = (1.f / partial.sum_inv_covar) * partial.sum_mean_div_covar;
return new FloatWritable(mean);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
index 7cf61d8..1254740 100644
--- a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
+++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
@@ -177,7 +177,11 @@ public final class SparseOptimizerFactory {
} else {
auxWeight.set(weight);
}
- return update(auxWeight, gradient);
+ final float newWeight = update(auxWeight, gradient);
+ if (newWeight == 0.f) {
+ auxWeights.remove(feature);
+ }
+ return newWeight;
}
}