You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2018/08/28 15:42:50 UTC

incubator-hivemall git commit: [HIVEMALL-212] Fix Classifier/Regressor not to forward zero weighted values

Repository: incubator-hivemall
Updated Branches:
  refs/heads/master f8beee36b -> 4ca1c19c7


[HIVEMALL-212] Fix Classifier/Regressor not to forward zero weighted values

## What changes were proposed in this pull request?

Features whose weight is 0.0 do not need to be saved in the prediction model, and omitting them reduces the size of the prediction model. This PR therefore changes Classifier/Regressor so that zero-weighted values are not forwarded.

## What type of PR is it?

Improvement

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-212

## How was this patch tested?

unit tests and manual tests

## Checklist

(Please remove this section if not needed; check `x` for YES, blank for NO)

- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <my...@apache.org>

Closes #157 from myui/HIVEMALL-212.


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4ca1c19c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4ca1c19c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4ca1c19c

Branch: refs/heads/master
Commit: 4ca1c19c7e00120cb0abd42d6c1f1b48176846e8
Parents: f8beee3
Author: Makoto Yui <my...@apache.org>
Authored: Wed Aug 29 00:42:45 2018 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Wed Aug 29 00:42:45 2018 +0900

----------------------------------------------------------------------
 .../java/hivemall/GeneralLearnerBaseUDTF.java   | 33 ++++++++++++++++----
 .../hivemall/ensemble/ArgminKLDistanceUDAF.java | 14 ++++++---
 .../optimizer/SparseOptimizerFactory.java       |  6 +++-
 3 files changed, 42 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
index 5c3967b..0198e77 100644
--- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
@@ -452,9 +452,13 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
         cvState.incrLoss(loss); // retain cumulative loss to check convergence
 
         final float dloss = lossFunction.dloss(predicted, target);
+        if (dloss == 0.f) {
+            optimizer.proceedStep();
+            return;
+        }
+
         if (is_mini_batch) {
             accumulateUpdate(features, dloss);
-
             if (sampled >= mini_batch_size) {
                 batchUpdate();
             }
@@ -494,7 +498,11 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
         for (Map.Entry<Object, FloatAccumulator> e : accumulated.entrySet()) {
             Object feature = e.getKey();
             FloatAccumulator v = e.getValue();
-            float new_weight = v.get(); // w_i - (eta / M) * (delta_1 + delta_2 + ... + delta_M)
+            final float new_weight = v.get(); // w_i - (eta / M) * (delta_1 + delta_2 + ... + delta_M)
+            if (new_weight == 0.f) {
+                model.delete(feature);
+                continue;
+            }
             model.setWeight(feature, new_weight);
         }
 
@@ -507,7 +515,11 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
             Object feature = f.getFeature();
             float xi = f.getValueAsFloat();
             float weight = model.getWeight(feature);
-            float new_weight = optimizer.update(feature, weight, dloss * xi);
+            final float new_weight = optimizer.update(feature, weight, dloss * xi);
+            if (new_weight == 0.f) {
+                model.delete(feature);
+                continue;
+            }
             model.setWeight(feature, new_weight);
         }
     }
@@ -701,9 +713,14 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
                 if (!probe.isTouched()) {
                     continue; // skip outputting untouched weights
                 }
+                final float v = probe.get();
+                final float cv = probe.getCovariance();
+                if (v == 0.f && cv == 0.f) {
+                    continue;
+                }
+                fv.set(v);
+                cov.set(cv);
                 Object k = itor.getKey();
-                fv.set(probe.get());
-                cov.set(probe.getCovariance());
                 forwardMapObj[0] = k;
                 forwardMapObj[1] = fv;
                 forwardMapObj[2] = cov;
@@ -720,8 +737,12 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
                 if (!probe.isTouched()) {
                     continue; // skip outputting untouched weights
                 }
+                final float v = probe.get();
+                if (v == 0.f) {
+                    continue;
+                }
+                fv.set(v);
                 Object k = itor.getKey();
-                fv.set(probe.get());
                 forwardMapObj[0] = k;
                 forwardMapObj[1] = fv;
                 forward(forwardMapObj);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java b/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java
index 136ca0d..774db6f 100644
--- a/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java
+++ b/core/src/main/java/hivemall/ensemble/ArgminKLDistanceUDAF.java
@@ -38,8 +38,8 @@ public final class ArgminKLDistanceUDAF extends UDAF {
             float sum_inv_covar;
 
             PartialResult() {
-                this.sum_mean_div_covar = 0f;
-                this.sum_inv_covar = 0f;
+                this.sum_mean_div_covar = 0.f;
+                this.sum_inv_covar = 0.f;
             }
         }
 
@@ -54,7 +54,10 @@ public final class ArgminKLDistanceUDAF extends UDAF {
             if (partial == null) {
                 this.partial = new PartialResult();
             }
-            float covar_f = covar.get();
+            final float covar_f = covar.get();
+            if (covar_f == 0.f) {// avoid null division
+                return true;
+            }
             partial.sum_mean_div_covar += (mean.get() / covar_f);
             partial.sum_inv_covar += (1.f / covar_f);
             return true;
@@ -80,7 +83,10 @@ public final class ArgminKLDistanceUDAF extends UDAF {
             if (partial == null) {
                 return null;
             }
-            float mean = (1f / partial.sum_inv_covar) * partial.sum_mean_div_covar;
+            if (partial.sum_inv_covar == 0.f) {// avoid null division
+                return new FloatWritable(0.f);
+            }
+            float mean = (1.f / partial.sum_inv_covar) * partial.sum_mean_div_covar;
             return new FloatWritable(mean);
         }
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4ca1c19c/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
index 7cf61d8..1254740 100644
--- a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
+++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
@@ -177,7 +177,11 @@ public final class SparseOptimizerFactory {
             } else {
                 auxWeight.set(weight);
             }
-            return update(auxWeight, gradient);
+            final float newWeight = update(auxWeight, gradient);
+            if (newWeight == 0.f) {
+                auxWeights.remove(feature);
+            }
+            return newWeight;
         }
 
     }