You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 07:04:16 UTC
[14/50] [abbrv] incubator-hivemall git commit: Revert some modifications
Revert some modifications
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/3620eb89
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/3620eb89
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/3620eb89
Branch: refs/heads/JIRA-22/pr-285
Commit: 3620eb89993db22ce8aee924d3cc0df33a5f9618
Parents: f81948c
Author: Takeshi YAMAMURO <li...@gmail.com>
Authored: Wed Sep 21 01:52:22 2016 +0900
Committer: Takeshi YAMAMURO <li...@gmail.com>
Committed: Wed Sep 21 01:55:59 2016 +0900
----------------------------------------------------------------------
.../src/main/java/hivemall/LearnerBaseUDTF.java | 33 ++
.../hivemall/classifier/AROWClassifierUDTF.java | 2 +-
.../hivemall/classifier/AdaGradRDAUDTF.java | 125 +++++++-
.../classifier/BinaryOnlineClassifierUDTF.java | 10 +
.../classifier/GeneralClassifierUDTF.java | 1 +
.../classifier/PassiveAggressiveUDTF.java | 2 +-
.../main/java/hivemall/model/DenseModel.java | 86 ++++-
.../main/java/hivemall/model/NewDenseModel.java | 293 +++++++++++++++++
.../model/NewSpaceEfficientDenseModel.java | 317 +++++++++++++++++++
.../java/hivemall/model/NewSparseModel.java | 197 ++++++++++++
.../java/hivemall/model/PredictionModel.java | 3 +
.../model/SpaceEfficientDenseModel.java | 92 +++++-
.../main/java/hivemall/model/SparseModel.java | 19 +-
.../model/SynchronizedModelWrapper.java | 6 +
.../hivemall/regression/AROWRegressionUDTF.java | 2 +-
.../java/hivemall/regression/AdaDeltaUDTF.java | 118 ++++++-
.../java/hivemall/regression/AdaGradUDTF.java | 119 ++++++-
.../regression/GeneralRegressionUDTF.java | 1 +
.../java/hivemall/regression/LogressUDTF.java | 65 +++-
.../PassiveAggressiveRegressionUDTF.java | 2 +-
.../hivemall/regression/RegressionBaseUDTF.java | 12 +-
.../NewSpaceEfficientNewDenseModelTest.java | 60 ++++
.../model/SpaceEfficientDenseModelTest.java | 60 ----
.../java/hivemall/mix/server/MixServerTest.java | 14 +-
.../hivemall/mix/server/MixServerSuite.scala | 4 +-
25 files changed, 1512 insertions(+), 131 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/LearnerBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java
index 7fd5190..4cf3c7f 100644
--- a/core/src/main/java/hivemall/LearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java
@@ -25,6 +25,9 @@ import hivemall.model.DenseModel;
import hivemall.model.PredictionModel;
import hivemall.model.SpaceEfficientDenseModel;
import hivemall.model.SparseModel;
+import hivemall.model.NewDenseModel;
+import hivemall.model.NewSparseModel;
+import hivemall.model.NewSpaceEfficientDenseModel;
import hivemall.model.SynchronizedModelWrapper;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
@@ -199,6 +202,36 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions {
return model;
}
+ protected PredictionModel createNewModel(String label) {
+ PredictionModel model;
+ final boolean useCovar = useCovariance();
+ if (dense_model) {
+ if (disable_halffloat == false && model_dims > 16777216) {
+ logger.info("Build a space efficient dense model with " + model_dims
+ + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
+ model = new NewSpaceEfficientDenseModel(model_dims, useCovar);
+ } else {
+ logger.info("Build a dense model with initial with " + model_dims
+ + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
+ model = new NewDenseModel(model_dims, useCovar);
+ }
+ } else {
+ int initModelSize = getInitialModelSize();
+ logger.info("Build a sparse model with initial with " + initModelSize
+ + " initial dimensions");
+ model = new NewSparseModel(initModelSize, useCovar);
+ }
+ if (mixConnectInfo != null) {
+ model.configureClock();
+ model = new SynchronizedModelWrapper(model);
+ MixClient client = configureMixClient(mixConnectInfo, label, model);
+ model.configureMix(client, mixCancel);
+ this.mixClient = client;
+ }
+ assert (model != null);
+ return model;
+ }
+
// If a model implements a optimizer, it must override this
protected Map<String, String> getOptimzierOptions() {
return null;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
index ac8afcb..b42ab05 100644
--- a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
@@ -18,11 +18,11 @@
*/
package hivemall.classifier;
-import hivemall.optimizer.LossFunctions;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
index a6714f4..b512a34 100644
--- a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
+++ b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
@@ -18,13 +18,128 @@
*/
package hivemall.classifier;
+import hivemall.model.FeatureValue;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue.WeightValueParamsF2;
+import hivemall.optimizer.LossFunctions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
+/**
+ * @deprecated Use {@link hivemall.classifier.GeneralClassifierUDTF} instead
+ */
@Deprecated
-public final class AdaGradRDAUDTF extends GeneralClassifierUDTF {
+@Description(name = "train_adagrad_rda",
+ value = "_FUNC_(list<string|int|bigint> features, int label [, const string options])"
+ + " - Returns a relation consists of <string|int|bigint feature, float weight>",
+ extended = "Build a prediction model by Adagrad+RDA regularization binary classifier")
+public final class AdaGradRDAUDTF extends BinaryOnlineClassifierUDTF {
+
+ private float eta;
+ private float lambda;
+ private float scaling;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 2 && numArgs != 3) {
+ throw new UDFArgumentException(
+ "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, int label [, constant string options]");
+ }
+
+ StructObjectInspector oi = super.initialize(argOIs);
+ model.configureParams(true, false, true);
+ return oi;
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("eta", "eta0", true, "The learning rate \\eta [default 0.1]");
+ opts.addOption("lambda", true, "lambda constant of RDA [default: 1E-6f]");
+ opts.addOption("scale", true,
+ "Internal scaling/descaling factor for cumulative weights [default: 100]");
+ return opts;
+ }
- public AdaGradRDAUDTF() {
- optimizerOptions.put("optimizer", "AdaGrad");
- optimizerOptions.put("regularization", "RDA");
- optimizerOptions.put("lambda", "1e-6");
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+ if (cl == null) {
+ this.eta = 0.1f;
+ this.lambda = 1E-6f;
+ this.scaling = 100f;
+ } else {
+ this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.1f);
+ this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 1E-6f);
+ this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
+ }
+ return cl;
}
+ @Override
+ protected void train(@Nonnull final FeatureValue[] features, final int label) {
+ final float y = label > 0 ? 1.f : -1.f;
+
+ float p = predict(features);
+ float loss = LossFunctions.hingeLoss(p, y); // 1.0 - y * p
+ if (loss <= 0.f) { // max(0, 1 - y * p)
+ return;
+ }
+ // subgradient => -y * W dot xi
+ update(features, y, count);
+ }
+
+ protected void update(@Nonnull final FeatureValue[] features, final float y, final int t) {
+ for (FeatureValue f : features) {// w[f] += y * x[f]
+ if (f == null) {
+ continue;
+ }
+ Object x = f.getFeature();
+ float xi = f.getValueAsFloat();
+
+ updateWeight(x, xi, y, t);
+ }
+ }
+
+ protected void updateWeight(@Nonnull final Object x, final float xi, final float y,
+ final float t) {
+ final float gradient = -y * xi;
+ final float scaled_gradient = gradient * scaling;
+
+ float scaled_sum_sqgrad = 0.f;
+ float scaled_sum_grad = 0.f;
+ IWeightValue old = model.get(x);
+ if (old != null) {
+ scaled_sum_sqgrad = old.getSumOfSquaredGradients();
+ scaled_sum_grad = old.getSumOfGradients();
+ }
+ scaled_sum_grad += scaled_gradient;
+ scaled_sum_sqgrad += (scaled_gradient * scaled_gradient);
+
+ float sum_grad = scaled_sum_grad * scaling;
+ double sum_sqgrad = scaled_sum_sqgrad * scaling;
+
+ // sign(u_{t,i})
+ float sign = (sum_grad > 0.f) ? 1.f : -1.f;
+ // |u_{t,i}|/t - \lambda
+ float meansOfGradients = sign * sum_grad / t - lambda;
+ if (meansOfGradients < 0.f) {
+ // x_{t,i} = 0
+ model.delete(x);
+ } else {
+ // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda)
+ float weight = -1.f * sign * eta * t * meansOfGradients / (float) Math.sqrt(sum_sqgrad);
+ IWeightValue new_w = new WeightValueParamsF2(weight, scaled_sum_sqgrad, scaled_sum_grad);
+ model.set(x, new_w);
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
index 0ee5d5f..efeeb9d 100644
--- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
@@ -60,6 +60,16 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
protected Optimizer optimizerImpl;
protected int count;
+ private boolean enableNewModel;
+
+ public BinaryOnlineClassifierUDTF() {
+ this.enableNewModel = false;
+ }
+
+ public BinaryOnlineClassifierUDTF(boolean enableNewModel) {
+ this.enableNewModel = enableNewModel;
+ }
+
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length < 2) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
index feebadd..12bd481 100644
--- a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
@@ -39,6 +39,7 @@ public class GeneralClassifierUDTF extends BinaryOnlineClassifierUDTF {
protected final Map<String, String> optimizerOptions;
public GeneralClassifierUDTF() {
+ super(true); // This enables new model interfaces
this.optimizerOptions = new HashMap<String, String>();
// Set default values
optimizerOptions.put("optimizer", "adagrad");
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
index 9e404cd..191a7b5 100644
--- a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
+++ b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
@@ -18,9 +18,9 @@
*/
package hivemall.classifier;
-import hivemall.optimizer.LossFunctions;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
+import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/DenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/DenseModel.java b/core/src/main/java/hivemall/model/DenseModel.java
index 6956875..f142cc1 100644
--- a/core/src/main/java/hivemall/model/DenseModel.java
+++ b/core/src/main/java/hivemall/model/DenseModel.java
@@ -18,18 +18,21 @@
*/
package hivemall.model;
-import java.util.Arrays;
-import javax.annotation.Nonnull;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
+import hivemall.model.WeightValue.WeightValueParamsF1;
+import hivemall.model.WeightValue.WeightValueParamsF2;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Copyable;
import hivemall.utils.math.MathUtils;
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
public final class DenseModel extends AbstractPredictionModel {
private static final Log logger = LogFactory.getLog(DenseModel.class);
@@ -37,6 +40,13 @@ public final class DenseModel extends AbstractPredictionModel {
private float[] weights;
private float[] covars;
+ // optional values for adagrad
+ private float[] sum_of_squared_gradients;
+ // optional value for adadelta
+ private float[] sum_of_squared_delta_x;
+ // optional value for adagrad+rda
+ private float[] sum_of_gradients;
+
// optional value for MIX
private short[] clocks;
private byte[] deltaUpdates;
@@ -57,6 +67,9 @@ public final class DenseModel extends AbstractPredictionModel {
} else {
this.covars = null;
}
+ this.sum_of_squared_gradients = null;
+ this.sum_of_squared_delta_x = null;
+ this.sum_of_gradients = null;
this.clocks = null;
this.deltaUpdates = null;
}
@@ -72,6 +85,20 @@ public final class DenseModel extends AbstractPredictionModel {
}
@Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {
+ if (sum_of_squared_gradients) {
+ this.sum_of_squared_gradients = new float[size];
+ }
+ if (sum_of_squared_delta_x) {
+ this.sum_of_squared_delta_x = new float[size];
+ }
+ if (sum_of_gradients) {
+ this.sum_of_gradients = new float[size];
+ }
+ }
+
+ @Override
public void configureClock() {
if (clocks == null) {
this.clocks = new short[size];
@@ -102,7 +129,16 @@ public final class DenseModel extends AbstractPredictionModel {
this.covars = Arrays.copyOf(covars, newSize);
Arrays.fill(covars, oldSize, newSize, 1.f);
}
- if(clocks != null) {
+ if (sum_of_squared_gradients != null) {
+ this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
+ }
+ if (sum_of_squared_delta_x != null) {
+ this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize);
+ }
+ if (sum_of_gradients != null) {
+ this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize);
+ }
+ if (clocks != null) {
this.clocks = Arrays.copyOf(clocks, newSize);
this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
}
@@ -116,7 +152,17 @@ public final class DenseModel extends AbstractPredictionModel {
if (i >= size) {
return null;
}
- if(covars != null) {
+ if (sum_of_squared_gradients != null) {
+ if (sum_of_squared_delta_x != null) {
+ return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i],
+ sum_of_squared_delta_x[i]);
+ } else if (sum_of_gradients != null) {
+ return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i],
+ sum_of_gradients[i]);
+ } else {
+ return (T) new WeightValueParamsF1(weights[i], sum_of_squared_gradients[i]);
+ }
+ } else if (covars != null) {
return (T) new WeightValueWithCovar(weights[i], covars[i]);
} else {
return (T) new WeightValue(weights[i]);
@@ -135,6 +181,15 @@ public final class DenseModel extends AbstractPredictionModel {
covar = value.getCovariance();
covars[i] = covar;
}
+ if (sum_of_squared_gradients != null) {
+ sum_of_squared_gradients[i] = value.getSumOfSquaredGradients();
+ }
+ if (sum_of_squared_delta_x != null) {
+ sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX();
+ }
+ if (sum_of_gradients != null) {
+ sum_of_gradients[i] = value.getSumOfGradients();
+ }
short clock = 0;
int delta = 0;
if (clocks != null && value.isTouched()) {
@@ -158,6 +213,15 @@ public final class DenseModel extends AbstractPredictionModel {
if (covars != null) {
covars[i] = 1.f;
}
+ if (sum_of_squared_gradients != null) {
+ sum_of_squared_gradients[i] = 0.f;
+ }
+ if (sum_of_squared_delta_x != null) {
+ sum_of_squared_delta_x[i] = 0.f;
+ }
+ if (sum_of_gradients != null) {
+ sum_of_gradients[i] = 0.f;
+ }
// avoid clock/delta
}
@@ -171,10 +235,8 @@ public final class DenseModel extends AbstractPredictionModel {
}
@Override
- public void setWeight(Object feature, float value) {
- int i = HiveUtils.parseInt(feature);
- ensureCapacity(i);
- weights[i] = value;
+ public void setWeight(@Nonnull Object feature, float value) {
+ throw new UnsupportedOperationException();
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewDenseModel.java b/core/src/main/java/hivemall/model/NewDenseModel.java
new file mode 100644
index 0000000..920794c
--- /dev/null
+++ b/core/src/main/java/hivemall/model/NewDenseModel.java
@@ -0,0 +1,293 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.model;
+
+import java.util.Arrays;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.math.MathUtils;
+
+public final class NewDenseModel extends AbstractPredictionModel {
+ private static final Log logger = LogFactory.getLog(NewDenseModel.class);
+
+ private int size;
+ private float[] weights;
+ private float[] covars;
+
+ // optional value for MIX
+ private short[] clocks;
+ private byte[] deltaUpdates;
+
+ public NewDenseModel(int ndims) {
+ this(ndims, false);
+ }
+
+ public NewDenseModel(int ndims, boolean withCovar) {
+ super();
+ int size = ndims + 1;
+ this.size = size;
+ this.weights = new float[size];
+ if (withCovar) {
+ float[] covars = new float[size];
+ Arrays.fill(covars, 1f);
+ this.covars = covars;
+ } else {
+ this.covars = null;
+ }
+ this.clocks = null;
+ this.deltaUpdates = null;
+ }
+
+ @Override
+ protected boolean isDenseModel() {
+ return true;
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return covars != null;
+ }
+
+ @Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {}
+
+ @Override
+ public void configureClock() {
+ if (clocks == null) {
+ this.clocks = new short[size];
+ this.deltaUpdates = new byte[size];
+ }
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clocks != null;
+ }
+
+ @Override
+ public void resetDeltaUpdates(int feature) {
+ deltaUpdates[feature] = 0;
+ }
+
+ private void ensureCapacity(final int index) {
+ if (index >= size) {
+ int bits = MathUtils.bitsRequired(index);
+ int newSize = (1 << bits) + 1;
+ int oldSize = size;
+ logger.info("Expands internal array size from " + oldSize + " to " + newSize + " ("
+ + bits + " bits)");
+ this.size = newSize;
+ this.weights = Arrays.copyOf(weights, newSize);
+ if (covars != null) {
+ this.covars = Arrays.copyOf(covars, newSize);
+ Arrays.fill(covars, oldSize, newSize, 1.f);
+ }
+ if(clocks != null) {
+ this.clocks = Arrays.copyOf(clocks, newSize);
+ this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <T extends IWeightValue> T get(Object feature) {
+ final int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return null;
+ }
+ if(covars != null) {
+ return (T) new WeightValueWithCovar(weights[i], covars[i]);
+ } else {
+ return (T) new WeightValue(weights[i]);
+ }
+ }
+
+ @Override
+ public <T extends IWeightValue> void set(Object feature, T value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ float weight = value.get();
+ weights[i] = weight;
+ float covar = 1.f;
+ boolean hasCovar = value.hasCovariance();
+ if (hasCovar) {
+ covar = value.getCovariance();
+ covars[i] = covar;
+ }
+ short clock = 0;
+ int delta = 0;
+ if (clocks != null && value.isTouched()) {
+ clock = (short) (clocks[i] + 1);
+ clocks[i] = clock;
+ delta = deltaUpdates[i] + 1;
+ assert (delta > 0) : delta;
+ deltaUpdates[i] = (byte) delta;
+ }
+
+ onUpdate(i, weight, covar, clock, delta, hasCovar);
+ }
+
+ @Override
+ public void delete(@Nonnull Object feature) {
+ final int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return;
+ }
+ weights[i] = 0.f;
+ if (covars != null) {
+ covars[i] = 1.f;
+ }
+ // avoid clock/delta
+ }
+
+ @Override
+ public float getWeight(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return 0f;
+ }
+ return weights[i];
+ }
+
+ @Override
+ public void setWeight(Object feature, float value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ weights[i] = value;
+ }
+
+ @Override
+ public float getCovariance(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return 1f;
+ }
+ return covars[i];
+ }
+
+ @Override
+ protected void _set(Object feature, float weight, short clock) {
+ int i = ((Integer) feature).intValue();
+ ensureCapacity(i);
+ weights[i] = weight;
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ }
+
+ @Override
+ protected void _set(Object feature, float weight, float covar, short clock) {
+ int i = ((Integer) feature).intValue();
+ ensureCapacity(i);
+ weights[i] = weight;
+ covars[i] = covar;
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public boolean contains(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return false;
+ }
+ float w = weights[i];
+ return w != 0.f;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
+ return (IMapIterator<K, V>) new Itr();
+ }
+
+ private final class Itr implements IMapIterator<Number, IWeightValue> {
+
+ private int cursor;
+ private final WeightValueWithCovar tmpWeight;
+
+ private Itr() {
+ this.cursor = -1;
+ this.tmpWeight = new WeightValueWithCovar();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return cursor < size;
+ }
+
+ @Override
+ public int next() {
+ ++cursor;
+ if (!hasNext()) {
+ return -1;
+ }
+ return cursor;
+ }
+
+ @Override
+ public Integer getKey() {
+ return cursor;
+ }
+
+ @Override
+ public IWeightValue getValue() {
+ if (covars == null) {
+ float w = weights[cursor];
+ WeightValue v = new WeightValue(w);
+ v.setTouched(w != 0f);
+ return v;
+ } else {
+ float w = weights[cursor];
+ float cov = covars[cursor];
+ WeightValueWithCovar v = new WeightValueWithCovar(w, cov);
+ v.setTouched(w != 0.f || cov != 1.f);
+ return v;
+ }
+ }
+
+ @Override
+ public <T extends Copyable<IWeightValue>> void getValue(T probe) {
+ float w = weights[cursor];
+ tmpWeight.value = w;
+ float cov = 1.f;
+ if (covars != null) {
+ cov = covars[cursor];
+ tmpWeight.setCovariance(cov);
+ }
+ tmpWeight.setTouched(w != 0.f || cov != 1.f);
+ probe.copyFrom(tmpWeight);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
new file mode 100644
index 0000000..48eb62a
--- /dev/null
+++ b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
@@ -0,0 +1,317 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.model;
+
+import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.lang.HalfFloat;
+import hivemall.utils.math.MathUtils;
+
+import java.util.Arrays;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
+ private static final Log logger = LogFactory.getLog(NewSpaceEfficientDenseModel.class);
+
+ private int size;
+ private short[] weights;
+ private short[] covars;
+
+ // optional value for MIX
+ private short[] clocks;
+ private byte[] deltaUpdates;
+
+ public NewSpaceEfficientDenseModel(int ndims) {
+ this(ndims, false);
+ }
+
+ public NewSpaceEfficientDenseModel(int ndims, boolean withCovar) {
+ super();
+ int size = ndims + 1;
+ this.size = size;
+ this.weights = new short[size];
+ if (withCovar) {
+ short[] covars = new short[size];
+ Arrays.fill(covars, HalfFloat.ONE);
+ this.covars = covars;
+ } else {
+ this.covars = null;
+ }
+ this.clocks = null;
+ this.deltaUpdates = null;
+ }
+
+ @Override
+ protected boolean isDenseModel() {
+ return true;
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return covars != null;
+ }
+
+ @Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {}
+
+ @Override
+ public void configureClock() {
+ if (clocks == null) {
+ this.clocks = new short[size];
+ this.deltaUpdates = new byte[size];
+ }
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clocks != null;
+ }
+
+ @Override
+ public void resetDeltaUpdates(int feature) {
+ deltaUpdates[feature] = 0;
+ }
+
+ private float getWeight(final int i) {
+ final short w = weights[i];
+ return (w == HalfFloat.ZERO) ? HalfFloat.ZERO : HalfFloat.halfFloatToFloat(w);
+ }
+
+ private float getCovar(final int i) {
+ return HalfFloat.halfFloatToFloat(covars[i]);
+ }
+
+ private void _setWeight(final int i, final float v) {
+ if(Math.abs(v) >= HalfFloat.MAX_FLOAT) {
+ throw new IllegalArgumentException("Acceptable maximum weight is "
+ + HalfFloat.MAX_FLOAT + ": " + v);
+ }
+ weights[i] = HalfFloat.floatToHalfFloat(v);
+ }
+
+ private void setCovar(final int i, final float v) {
+ HalfFloat.checkRange(v);
+ covars[i] = HalfFloat.floatToHalfFloat(v);
+ }
+
+ private void ensureCapacity(final int index) {
+ if (index >= size) {
+ int bits = MathUtils.bitsRequired(index);
+ int newSize = (1 << bits) + 1;
+ int oldSize = size;
+ logger.info("Expands internal array size from " + oldSize + " to " + newSize + " ("
+ + bits + " bits)");
+ this.size = newSize;
+ this.weights = Arrays.copyOf(weights, newSize);
+ if (covars != null) {
+ this.covars = Arrays.copyOf(covars, newSize);
+ Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE);
+ }
+ if(clocks != null) {
+ this.clocks = Arrays.copyOf(clocks, newSize);
+ this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <T extends IWeightValue> T get(Object feature) {
+ final int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return null;
+ }
+
+ if(covars != null) {
+ return (T) new WeightValueWithCovar(getWeight(i), getCovar(i));
+ } else {
+ return (T) new WeightValue(getWeight(i));
+ }
+ }
+
+ @Override
+ public <T extends IWeightValue> void set(Object feature, T value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ float weight = value.get();
+ _setWeight(i, weight);
+ float covar = 1.f;
+ boolean hasCovar = value.hasCovariance();
+ if (hasCovar) {
+ covar = value.getCovariance();
+ setCovar(i, covar);
+ }
+ short clock = 0;
+ int delta = 0;
+ if (clocks != null && value.isTouched()) {
+ clock = (short) (clocks[i] + 1);
+ clocks[i] = clock;
+ delta = deltaUpdates[i] + 1;
+ assert (delta > 0) : delta;
+ deltaUpdates[i] = (byte) delta;
+ }
+
+ onUpdate(i, weight, covar, clock, delta, hasCovar);
+ }
+
+ @Override
+ public void delete(@Nonnull Object feature) {
+ final int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return;
+ }
+ _setWeight(i, 0.f);
+ if(covars != null) {
+ setCovar(i, 1.f);
+ }
+ // avoid clock/delta
+ }
+
+ @Override
+ public float getWeight(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return 0f;
+ }
+ return getWeight(i);
+ }
+
+ @Override
+ public void setWeight(Object feature, float value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ _setWeight(i, value);
+ }
+
+ @Override
+ public float getCovariance(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return 1f;
+ }
+ return getCovar(i);
+ }
+
+ @Override
+ protected void _set(Object feature, float weight, short clock) {
+ int i = ((Integer) feature).intValue();
+ ensureCapacity(i);
+ _setWeight(i, weight);
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ }
+
+ @Override
+ protected void _set(Object feature, float weight, float covar, short clock) {
+ int i = ((Integer) feature).intValue();
+ ensureCapacity(i);
+ _setWeight(i, weight);
+ setCovar(i, covar);
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public boolean contains(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return false;
+ }
+ float w = getWeight(i);
+ return w != 0.f;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
+ return (IMapIterator<K, V>) new Itr();
+ }
+
+ private final class Itr implements IMapIterator<Number, IWeightValue> {
+
+ private int cursor;
+ private final WeightValueWithCovar tmpWeight;
+
+ private Itr() {
+ this.cursor = -1;
+ this.tmpWeight = new WeightValueWithCovar();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return cursor < size;
+ }
+
+ @Override
+ public int next() {
+ ++cursor;
+ if (!hasNext()) {
+ return -1;
+ }
+ return cursor;
+ }
+
+ @Override
+ public Integer getKey() {
+ return cursor;
+ }
+
+ @Override
+ public IWeightValue getValue() {
+ if (covars == null) {
+ float w = getWeight(cursor);
+ WeightValue v = new WeightValue(w);
+ v.setTouched(w != 0f);
+ return v;
+ } else {
+ float w = getWeight(cursor);
+ float cov = getCovar(cursor);
+ WeightValueWithCovar v = new WeightValueWithCovar(w, cov);
+ v.setTouched(w != 0.f || cov != 1.f);
+ return v;
+ }
+ }
+
+ @Override
+ public <T extends Copyable<IWeightValue>> void getValue(T probe) {
+ float w = getWeight(cursor);
+ tmpWeight.value = w;
+ float cov = 1.f;
+ if (covars != null) {
+ cov = getCovar(cursor);
+ tmpWeight.setCovariance(cov);
+ }
+ tmpWeight.setTouched(w != 0.f || cov != 1.f);
+ probe.copyFrom(tmpWeight);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewSparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java
new file mode 100644
index 0000000..4c21830
--- /dev/null
+++ b/core/src/main/java/hivemall/model/NewSparseModel.java
@@ -0,0 +1,197 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.model;
+
+import hivemall.model.WeightValueWithClock.WeightValueParamsF1Clock;
+import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock;
+import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.OpenHashMap;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+/**
+ * Prediction model that keeps its weights in an {@link OpenHashMap} keyed by the feature
+ * object instead of in index-addressed arrays.
+ * <p>
+ * Once the clock is enabled through {@link #configureClock()}, values stored with
+ * {@link #set(Object, IWeightValue)} are wrapped in a clock-carrying variant, and the clock
+ * and delta-update counters of an existing entry are advanced on each touched update.
+ */
+public final class NewSparseModel extends AbstractPredictionModel {
+ private static final Log logger = LogFactory.getLog(NewSparseModel.class);
+
+ private final OpenHashMap<Object, IWeightValue> weights;
+ private final boolean hasCovar;
+ private boolean clockEnabled;
+
+ /**
+ * Creates a model without covariances.
+ *
+ * @param size value passed to the {@link OpenHashMap} constructor
+ */
+ public NewSparseModel(int size) {
+ this(size, false);
+ }
+
+ /**
+ * @param size value passed to the {@link OpenHashMap} constructor
+ * @param hasCovar value later reported by {@link #hasCovariance()}
+ */
+ public NewSparseModel(int size, boolean hasCovar) {
+ super();
+ this.weights = new OpenHashMap<Object, IWeightValue>(size);
+ this.hasCovar = hasCovar;
+ this.clockEnabled = false;
+ }
+
+ @Override
+ protected boolean isDenseModel() {
+ return false;
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return hasCovar;
+ }
+
+ /** No-op: this model has no per-feature arrays to allocate; values are stored as given. */
+ @Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {}
+
+ @Override
+ public void configureClock() {
+ this.clockEnabled = true;
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clockEnabled;
+ }
+
+ /** @return the stored value for {@code feature}, or {@code null} if the map has none */
+ @SuppressWarnings("unchecked")
+ @Override
+ public <T extends IWeightValue> T get(final Object feature) {
+ return (T) weights.get(feature);
+ }
+
+ /**
+ * Stores {@code value} for {@code feature}, wrapping it when the clock is enabled, and
+ * then notifies {@code onUpdate}.
+ */
+ @Override
+ public <T extends IWeightValue> void set(final Object feature, final T value) {
+ assert (feature != null);
+ assert (value != null);
+
+ final IWeightValue wrapperValue = wrapIfRequired(value);
+
+ if (clockEnabled && value.isTouched()) {
+ IWeightValue old = weights.get(feature);
+ if (old != null) {
+ // carry the previous entry's counters forward, each advanced by one
+ short newclock = (short) (old.getClock() + (short) 1);
+ wrapperValue.setClock(newclock);
+ int newDelta = old.getDeltaUpdates() + 1;
+ wrapperValue.setDeltaUpdates((byte) newDelta);
+ }
+ }
+ weights.put(feature, wrapperValue);
+
+ onUpdate(feature, wrapperValue);
+ }
+
+ @Override
+ public void delete(@Nonnull Object feature) {
+ weights.remove(feature);
+ }
+
+ /**
+ * Returns {@code value} unchanged when the clock is disabled; otherwise wraps it in the
+ * clock-carrying class that matches {@code value.getType()}.
+ *
+ * @throws IllegalStateException if the clock is enabled and the type is not one of the
+ * four handled cases
+ */
+ private IWeightValue wrapIfRequired(final IWeightValue value) {
+ if (!clockEnabled) {
+ return value;
+ }
+ switch (value.getType()) {
+ case NoParams:
+ return new WeightValueWithClock(value);
+ case ParamsCovar:
+ return new WeightValueWithCovarClock(value);
+ case ParamsF1:
+ return new WeightValueParamsF1Clock(value);
+ case ParamsF2:
+ return new WeightValueParamsF2Clock(value);
+ default:
+ throw new IllegalStateException("Unexpected value type: " + value.getType());
+ }
+ }
+
+ /** @return the stored weight, or 0 when the feature has no entry */
+ @Override
+ public float getWeight(final Object feature) {
+ IWeightValue v = weights.get(feature);
+ return v == null ? 0.f : v.get();
+ }
+
+ /**
+ * Overwrites the weight of an existing entry in place, or inserts a plain
+ * {@link WeightValue} when the feature has no entry yet.
+ * <p>
+ * NOTE(review): a newly inserted value is not passed through the clock wrapping used by
+ * {@link #set(Object, IWeightValue)} and {@code onUpdate} is not called; confirm that this
+ * is intended when the clock is enabled.
+ */
+ @Override
+ public void setWeight(Object feature, float value) {
+ // a single lookup serves both the presence check and the update
+ final IWeightValue existing = weights.get(feature);
+ if (existing != null) {
+ existing.set(value);
+ } else {
+ weights.put(feature, new WeightValue(value));
+ }
+ }
+
+ /** @return the stored covariance, or 1 when the feature has no entry */
+ @Override
+ public float getCovariance(final Object feature) {
+ IWeightValue v = weights.get(feature);
+ return v == null ? 1.f : v.getCovariance();
+ }
+
+ /**
+ * Overwrites weight and clock of an existing entry and resets its delta-update counter.
+ *
+ * @throws IllegalStateException if the feature has no entry
+ */
+ @Override
+ protected void _set(final Object feature, final float weight, final short clock) {
+ final IWeightValue w = getExisting(feature);
+ w.set(weight);
+ w.setClock(clock);
+ w.setDeltaUpdates(BYTE0);
+ }
+
+ /**
+ * Overwrites weight, covariance and clock of an existing entry and resets its delta-update
+ * counter.
+ *
+ * @throws IllegalStateException if the feature has no entry
+ */
+ @Override
+ protected void _set(final Object feature, final float weight, final float covar,
+ final short clock) {
+ final IWeightValue w = getExisting(feature);
+ w.set(weight);
+ w.setCovariance(covar);
+ w.setClock(clock);
+ w.setDeltaUpdates(BYTE0);
+ }
+
+ /** Looks up the stored value for {@code feature}; logs a warning and throws if there is none. */
+ @Nonnull
+ private IWeightValue getExisting(final Object feature) {
+ final IWeightValue w = weights.get(feature);
+ if (w == null) {
+ logger.warn("Previous weight not found: " + feature);
+ throw new IllegalStateException("Previous weight not found: " + feature);
+ }
+ return w;
+ }
+
+ @Override
+ public int size() {
+ return weights.size();
+ }
+
+ @Override
+ public boolean contains(final Object feature) {
+ return weights.containsKey(feature);
+ }
+
+ /** Exposes the backing map's iterator; the cast is unchecked because the caller picks K and V. */
+ @SuppressWarnings("unchecked")
+ @Override
+ public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
+ return (IMapIterator<K, V>) weights.entries();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/PredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/PredictionModel.java b/core/src/main/java/hivemall/model/PredictionModel.java
index 8d8dd2b..ea82f62 100644
--- a/core/src/main/java/hivemall/model/PredictionModel.java
+++ b/core/src/main/java/hivemall/model/PredictionModel.java
@@ -34,6 +34,9 @@ public interface PredictionModel extends MixedModel {
boolean hasCovariance();
+ /**
+ * Tells the model which optional per-feature values it should be prepared to hold.
+ * Each flag corresponds to the accessor of the same name on the weight values
+ * (sum of squared gradients, sum of squared delta x, sum of gradients).
+ * Implementations may ignore the call.
+ */
+ void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients);
+
void configureClock();
boolean hasClock();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
index 8b668e7..caa9fea 100644
--- a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
+++ b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
@@ -18,6 +18,8 @@
*/
package hivemall.model;
+import hivemall.model.WeightValue.WeightValueParamsF1;
+import hivemall.model.WeightValue.WeightValueParamsF2;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
@@ -26,6 +28,7 @@ import hivemall.utils.lang.HalfFloat;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
+
import javax.annotation.Nonnull;
import org.apache.commons.logging.Log;
@@ -38,6 +41,13 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
private short[] weights;
private short[] covars;
+ // optional value for adagrad
+ private float[] sum_of_squared_gradients;
+ // optional value for adadelta
+ private float[] sum_of_squared_delta_x;
+ // optional value for adagrad+rda
+ private float[] sum_of_gradients;
+
// optional value for MIX
private short[] clocks;
private byte[] deltaUpdates;
@@ -58,6 +68,9 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
} else {
this.covars = null;
}
+ this.sum_of_squared_gradients = null;
+ this.sum_of_squared_delta_x = null;
+ this.sum_of_gradients = null;
this.clocks = null;
this.deltaUpdates = null;
}
@@ -73,6 +86,20 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
}
+ /**
+ * Allocates the optional per-feature float arrays selected by the flags, each with
+ * {@code size} slots. A {@code false} flag leaves the corresponding field as it is; a
+ * {@code true} flag replaces whatever array was there with a fresh, zero-filled one.
+ */
@Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {
+ if (sum_of_squared_gradients) {
+ this.sum_of_squared_gradients = new float[size];
+ }
+ if (sum_of_squared_delta_x) {
+ this.sum_of_squared_delta_x = new float[size];
+ }
+ if (sum_of_gradients) {
+ this.sum_of_gradients = new float[size];
+ }
+ }
+
+ @Override
public void configureClock() {
if (clocks == null) {
this.clocks = new short[size];
@@ -99,11 +126,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
return HalfFloat.halfFloatToFloat(covars[i]);
}
- private void _setWeight(final int i, final float v) {
- if(Math.abs(v) >= HalfFloat.MAX_FLOAT) {
- throw new IllegalArgumentException("Acceptable maximum weight is "
- + HalfFloat.MAX_FLOAT + ": " + v);
- }
+ private void setWeight(final int i, final float v) {
+ HalfFloat.checkRange(v);
weights[i] = HalfFloat.floatToHalfFloat(v);
}
@@ -125,7 +149,16 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
this.covars = Arrays.copyOf(covars, newSize);
Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE);
}
- if(clocks != null) {
+ if (sum_of_squared_gradients != null) {
+ this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
+ }
+ if (sum_of_squared_delta_x != null) {
+ this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize);
+ }
+ if (sum_of_gradients != null) {
+ this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize);
+ }
+ if (clocks != null) {
this.clocks = Arrays.copyOf(clocks, newSize);
this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
}
@@ -139,8 +172,17 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
if (i >= size) {
return null;
}
-
- if(covars != null) {
+ if (sum_of_squared_gradients != null) {
+ if (sum_of_squared_delta_x != null) {
+ return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i],
+ sum_of_squared_delta_x[i]);
+ } else if (sum_of_gradients != null) {
+ return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i],
+ sum_of_gradients[i]);
+ } else {
+ return (T) new WeightValueParamsF1(getWeight(i), sum_of_squared_gradients[i]);
+ }
+ } else if (covars != null) {
return (T) new WeightValueWithCovar(getWeight(i), getCovar(i));
} else {
return (T) new WeightValue(getWeight(i));
@@ -152,13 +194,22 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
float weight = value.get();
- _setWeight(i, weight);
+ setWeight(i, weight);
float covar = 1.f;
boolean hasCovar = value.hasCovariance();
if (hasCovar) {
covar = value.getCovariance();
setCovar(i, covar);
}
+ if (sum_of_squared_gradients != null) {
+ sum_of_squared_gradients[i] = value.getSumOfSquaredGradients();
+ }
+ if (sum_of_squared_delta_x != null) {
+ sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX();
+ }
+ if (sum_of_gradients != null) {
+ sum_of_gradients[i] = value.getSumOfGradients();
+ }
short clock = 0;
int delta = 0;
if (clocks != null && value.isTouched()) {
@@ -178,10 +229,19 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
if (i >= size) {
return;
}
- _setWeight(i, 0.f);
- if(covars != null) {
+ setWeight(i, 0.f);
+ if (covars != null) {
setCovar(i, 1.f);
}
+ if (sum_of_squared_gradients != null) {
+ sum_of_squared_gradients[i] = 0.f;
+ }
+ if (sum_of_squared_delta_x != null) {
+ sum_of_squared_delta_x[i] = 0.f;
+ }
+ if (sum_of_gradients != null) {
+ sum_of_gradients[i] = 0.f;
+ }
// avoid clock/delta
}
@@ -195,10 +255,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
}
@Override
- public void setWeight(Object feature, float value) {
- int i = HiveUtils.parseInt(feature);
- ensureCapacity(i);
- _setWeight(i, value);
+ public void setWeight(@Nonnull Object feature, float value) {
+ throw new UnsupportedOperationException();
}
@Override
@@ -214,7 +272,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
protected void _set(Object feature, float weight, short clock) {
int i = ((Integer) feature).intValue();
ensureCapacity(i);
- _setWeight(i, weight);
+ setWeight(i, weight);
clocks[i] = clock;
deltaUpdates[i] = 0;
}
@@ -223,7 +281,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
protected void _set(Object feature, float weight, float covar, short clock) {
int i = ((Integer) feature).intValue();
ensureCapacity(i);
- _setWeight(i, weight);
+ setWeight(i, weight);
setCovar(i, covar);
clocks[i] = clock;
deltaUpdates[i] = 0;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java
index bab982f..f4c4c55 100644
--- a/core/src/main/java/hivemall/model/SparseModel.java
+++ b/core/src/main/java/hivemall/model/SparseModel.java
@@ -36,10 +36,6 @@ public final class SparseModel extends AbstractPredictionModel {
private final boolean hasCovar;
private boolean clockEnabled;
- public SparseModel(int size) {
- this(size, false);
- }
-
public SparseModel(int size, boolean hasCovar) {
super();
this.weights = new OpenHashMap<Object, IWeightValue>(size);
@@ -58,6 +54,10 @@ public final class SparseModel extends AbstractPredictionModel {
}
+ /** Intentionally empty: nothing is allocated here, all three flags are ignored. */
@Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {}
+
+ @Override
public void configureClock() {
this.clockEnabled = true;
}
@@ -131,15 +131,8 @@ public final class SparseModel extends AbstractPredictionModel {
}
@Override
- public void setWeight(Object feature, float value) {
- if(weights.containsKey(feature)) {
- IWeightValue weight = weights.get(feature);
- weight.set(value);
- } else {
- IWeightValue weight = new WeightValue(value);
- weight.set(value);
- weights.put(feature, weight);
- }
+ public void setWeight(@Nonnull Object feature, float value) {
+ throw new UnsupportedOperationException();
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
index 87e89b6..dcb0bc9 100644
--- a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
+++ b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
@@ -63,6 +63,12 @@ public final class SynchronizedModelWrapper implements PredictionModel {
}
+ /** Forwards the three flags unchanged to the wrapped {@code model}. */
@Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {
+ model.configureParams(sum_of_squared_gradients, sum_of_squared_delta_x, sum_of_gradients);
+ }
+
+ @Override
public void configureClock() {
model.configureClock();
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
index 0c964c8..0503145 100644
--- a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
@@ -18,12 +18,12 @@
*/
package hivemall.regression;
-import hivemall.optimizer.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
index 50dc9b5..93453c1 100644
--- a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
+++ b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
@@ -18,14 +18,126 @@
*/
package hivemall.regression;
+import hivemall.model.FeatureValue;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue.WeightValueParamsF2;
+import hivemall.optimizer.LossFunctions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
/**
* ADADELTA: AN ADAPTIVE LEARNING RATE METHOD.
+ *
+ * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead
*/
@Deprecated
-public final class AdaDeltaUDTF extends GeneralRegressionUDTF {
+@Description(
+ name = "train_adadelta_regr",
+ value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
+ + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
+public final class AdaDeltaUDTF extends RegressionBaseUDTF {
+
+ private float decay;
+ private float eps;
+ private float scaling;
+
+ /**
+ * Validates the argument count, lets the base class set up the model, and then enables
+ * the two per-feature accumulators that {@code getNewWeight} reads back.
+ *
+ * @param argOIs inspectors for (features, target [, options])
+ * @return the output inspector produced by the base class
+ * @throws UDFArgumentException if the number of arguments is neither 2 nor 3
+ */
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 2 && numArgs != 3) {
+ throw new UDFArgumentException(
+ "AdaDeltaUDTF takes 2 or 3 arguments: List<Text|Int|BigInt> features, float target [, constant string options]");
+ }
+
+ StructObjectInspector oi = super.initialize(argOIs);
+ // sum_of_squared_gradients and sum_of_squared_delta_x are used; sum_of_gradients is not
+ model.configureParams(true, true, false);
+ return oi;
+ }
+
+ /**
+ * Adds the AdaDelta-specific options to those of the base class.
+ * The long name {@code decay} and the names {@code eps} and {@code scale} are the ones
+ * looked up in {@code processOptions}.
+ */
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("rho", "decay", true, "Decay rate [default 0.95]");
+ opts.addOption("eps", true, "A constant used in the denominator of AdaDelta [default 1e-6]");
+ opts.addOption("scale", true,
+ "Internal scaling/descaling factor for cumulative weights [default 100]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+ if (cl == null) {
+ this.decay = 0.95f;
+ this.eps = 1e-6f;
+ this.scaling = 100f;
+ } else {
+ this.decay = Primitives.parseFloat(cl.getOptionValue("decay"), 0.95f);
+ this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1E-6f);
+ this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
+ }
+ return cl;
+ }
+
+ /**
+ * Rejects targets outside the closed interval [0, 1].
+ * A NaN target is not rejected, because both comparisons are false for NaN.
+ *
+ * @throws UDFArgumentException if {@code target} is below 0 or above 1
+ */
+ @Override
+ protected final void checkTargetValue(final float target) throws UDFArgumentException {
+ final boolean outOfRange = target < 0.f || target > 1.f;
+ if (outOfRange) {
+ throw new UDFArgumentException("target must be in range 0 to 1: " + target);
+ }
+ }
+
+ /**
+ * Passes the value of {@code LossFunctions.logisticLoss(target, predicted)} to
+ * {@code onlineUpdate} as the gradient.
+ */
+ @Override
+ protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
+ onlineUpdate(features, LossFunctions.logisticLoss(target, predicted));
+ }
+
+ /**
+ * Replaces the stored value of every non-null feature with the one computed by
+ * {@code getNewWeight}.
+ *
+ * @param features features of the current row; {@code null} elements are skipped
+ * @param gradient gradient shared by all features of the row
+ */
+ @Override
+ protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
+ // squared gradient, divided by the scaling factor before it gets accumulated
+ final float scaledSqGrad = gradient * (gradient / scaling);
+
+ for (final FeatureValue fv : features) {
+ if (fv != null) {
+ final Object key = fv.getFeature();
+ final float value = fv.getValueAsFloat();
+ final IWeightValue previous = model.get(key);
+ model.set(key, getNewWeight(previous, value, gradient, scaledSqGrad));
+ }
+ }
+ }
+
+ /**
+ * Computes the replacement value for one feature.
+ *
+ * @param old previously stored value; when {@code null}, weight and both accumulators
+ * start from 0
+ * @param xi value of the feature in the current row
+ * @param gradient gradient of the current row
+ * @param g_g squared gradient already divided by {@code scaling}
+ * @return a new value holding the updated weight, the decayed sum of scaled squared
+ * gradients and the decayed sum of squared deltas
+ */
+ @Nonnull
+ protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi,
+ final float gradient, final float g_g) {
+ float old_w = 0.f;
+ float old_scaled_sum_sqgrad = 0.f;
+ float old_sum_squared_delta_x = 0.f;
+ if (old != null) {
+ old_w = old.get();
+ old_scaled_sum_sqgrad = old.getSumOfSquaredGradients();
+ old_sum_squared_delta_x = old.getSumOfSquaredDeltaX();
+ }
- public AdaDeltaUDTF() {
- optimizerOptions.put("optimizer", "AdaDelta");
+ // exponentially decayed average of the scaled squared gradient
+ float new_scaled_sum_sq_grad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * g_g);
+ // the stored accumulator is multiplied by scaling again before use
+ // NOTE(review): the denominator uses the previous accumulator, not
+ // new_scaled_sum_sq_grad - confirm this is intended
+ float dx = (float) Math.sqrt((old_sum_squared_delta_x + eps)
+ / (old_scaled_sum_sqgrad * scaling + eps))
+ * gradient;
+ float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x)
+ + ((1.f - decay) * dx * dx);
+ float new_w = old_w + (dx * xi);
+ return new WeightValueParamsF2(new_w, new_scaled_sum_sq_grad, new_sum_squared_delta_x);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AdaGradUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AdaGradUDTF.java b/core/src/main/java/hivemall/regression/AdaGradUDTF.java
index 4b5f019..87188fc 100644
--- a/core/src/main/java/hivemall/regression/AdaGradUDTF.java
+++ b/core/src/main/java/hivemall/regression/AdaGradUDTF.java
@@ -18,14 +18,127 @@
*/
package hivemall.regression;
+import hivemall.model.FeatureValue;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue.WeightValueParamsF1;
+import hivemall.optimizer.LossFunctions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
/**
* ADAGRAD algorithm with element-wise adaptive learning rates.
+ *
+ * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead
*/
@Deprecated
-public final class AdaGradUDTF extends GeneralRegressionUDTF {
+@Description(
+ name = "train_adagrad_regr",
+ value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
+ + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
+public final class AdaGradUDTF extends RegressionBaseUDTF {
+
+ private float eta;
+ private float eps;
+ private float scaling;
+
+ /**
+ * Validates the argument count, lets the base class set up the model, and then enables
+ * the per-feature sum of squared gradients that {@code getNewWeight} reads back.
+ *
+ * @param argOIs inspectors for (features, target [, options])
+ * @return the output inspector produced by the base class
+ * @throws UDFArgumentException if the number of arguments is neither 2 nor 3
+ */
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 2 && numArgs != 3) {
+ // name the class explicitly: this string is a plain exception message
+ throw new UDFArgumentException(
+ "AdaGradUDTF takes 2 or 3 arguments: List<Text|Int|BigInt> features, float target [, constant string options]");
+ }
+
+ StructObjectInspector oi = super.initialize(argOIs);
+ // only sum_of_squared_gradients is used by this learner
+ model.configureParams(true, false, false);
+ return oi;
+ }
+
+ /**
+ * Adds the AdaGrad-specific options ({@code eta}/{@code eta0}, {@code eps}, {@code scale})
+ * to those of the base class.
+ */
+ @Override
+ protected Options getOptions() {
+ // addOption returns the Options instance it was called on, so the calls can be chained
+ return super.getOptions()
+ .addOption("eta", "eta0", true, "The initial learning rate [default 1.0]")
+ .addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]")
+ .addOption("scale", true,
+ "Internal scaling/descaling factor for cumulative weights [100]");
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+ if (cl == null) {
+ this.eta = 1.f;
+ this.eps = 1.f;
+ this.scaling = 100f;
+ } else {
+ this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f);
+ this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1.f);
+ this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
+ }
+ return cl;
+ }
+
+ /**
+ * Rejects targets outside the closed interval [0, 1].
+ * A NaN target is not rejected, because both comparisons are false for NaN.
+ *
+ * @throws UDFArgumentException if {@code target} is below 0 or above 1
+ */
+ @Override
+ protected final void checkTargetValue(final float target) throws UDFArgumentException {
+ final boolean outOfRange = target < 0.f || target > 1.f;
+ if (outOfRange) {
+ throw new UDFArgumentException("target must be in range 0 to 1: " + target);
+ }
+ }
+
+ /**
+ * Passes the value of {@code LossFunctions.logisticLoss(target, predicted)} to
+ * {@code onlineUpdate} as the gradient.
+ */
+ @Override
+ protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
+ onlineUpdate(features, LossFunctions.logisticLoss(target, predicted));
+ }
+
+ /**
+ * Replaces the stored value of every non-null feature with the one computed by
+ * {@code getNewWeight}.
+ *
+ * @param features features of the current row; {@code null} elements are skipped
+ * @param gradient gradient shared by all features of the row
+ */
+ @Override
+ protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
+ // squared gradient, divided by the scaling factor before it gets accumulated
+ final float scaledSqGrad = gradient * (gradient / scaling);
+
+ for (final FeatureValue fv : features) {
+ if (fv != null) {
+ final Object key = fv.getFeature();
+ final float value = fv.getValueAsFloat();
+ final IWeightValue previous = model.get(key);
+ model.set(key, getNewWeight(previous, value, gradient, scaledSqGrad));
+ }
+ }
+ }
+
+ @Nonnull
+ protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi,
+ final float gradient, final float g_g) {
+ float old_w = 0.f;
+ float scaled_sum_sqgrad = 0.f;
+
+ if (old != null) {
+ old_w = old.get();
+ scaled_sum_sqgrad = old.getSumOfSquaredGradients();
+ }
+ scaled_sum_sqgrad += g_g;
+
+ float coeff = eta(scaled_sum_sqgrad) * gradient;
+ float new_w = old_w + (coeff * xi);
+ return new WeightValueParamsF1(new_w, scaled_sum_sqgrad);
+ }
- public AdaGradUDTF() {
- optimizerOptions.put("optimizer", "AdaGrad");
+ /**
+ * Returns the learning rate for a feature: {@code eta} divided by the square root of
+ * {@code eps} plus the de-scaled sum of squared gradients.
+ *
+ * @param scaledSumOfSquaredGradients accumulated value; multiplied by {@code scaling}
+ * before it is used
+ */
+ protected float eta(final double scaledSumOfSquaredGradients) {
+ double sumOfSquaredGradients = scaledSumOfSquaredGradients * scaling;
+ //return eta / (float) Math.sqrt(sumOfSquaredGradients);
+ return eta / (float) Math.sqrt(eps + sumOfSquaredGradients); // always less than eta0
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
index 2a8b543..21a784e 100644
--- a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
@@ -40,6 +40,7 @@ public class GeneralRegressionUDTF extends RegressionBaseUDTF {
protected final Map<String, String> optimizerOptions;
public GeneralRegressionUDTF() {
+ super(true); // This enables new model interfaces
this.optimizerOptions = new HashMap<String, String>();
// Set default values
optimizerOptions.put("optimizer", "adadelta");
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/LogressUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/LogressUDTF.java b/core/src/main/java/hivemall/regression/LogressUDTF.java
index ea05da3..78e617d 100644
--- a/core/src/main/java/hivemall/regression/LogressUDTF.java
+++ b/core/src/main/java/hivemall/regression/LogressUDTF.java
@@ -18,12 +18,69 @@
*/
package hivemall.regression;
+import hivemall.optimizer.EtaEstimator;
+import hivemall.optimizer.LossFunctions;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
+/**
+ * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead
+ */
@Deprecated
-public final class LogressUDTF extends GeneralRegressionUDTF {
+@Description(
+ name = "logress",
+ value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
+ + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
+public final class LogressUDTF extends RegressionBaseUDTF {
+
+ private EtaEstimator etaEstimator;
+
+ /**
+ * Validates the argument count and delegates everything else to the base class.
+ *
+ * @param argOIs inspectors for (features, target [, options])
+ * @return the output inspector produced by the base class
+ * @throws UDFArgumentException if the number of arguments is neither 2 nor 3
+ */
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 2 && numArgs != 3) {
+ throw new UDFArgumentException(
+ "LogressUDTF takes 2 or 3 arguments: List<Text|Int|BigInt> features, float target [, constant string options]");
+ }
+
+ return super.initialize(argOIs);
+ }
+
+ /**
+ * Adds the learning-rate options ({@code t}/{@code total_steps}, {@code power_t},
+ * {@code eta0}) to those of the base class.
+ */
+ @Override
+ protected Options getOptions() {
+ // addOption returns the Options instance it was called on, so the calls can be chained
+ return super.getOptions()
+ .addOption("t", "total_steps", true, "a total of n_samples * epochs time steps")
+ .addOption("power_t", true,
+ "The exponent for inverse scaling learning rate [default 0.1]")
+ .addOption("eta0", true, "The initial learning rate [default 0.1]");
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+
+ this.etaEstimator = EtaEstimator.get(cl);
+ return cl;
+ }
+
+ @Override
+ protected void checkTargetValue(final float target) throws UDFArgumentException {
+ if (target < 0.f || target > 1.f) {
+ throw new UDFArgumentException("target must be in range 0 to 1: " + target);
+ }
+ }
- public LogressUDTF() {
- optimizerOptions.put("optimizer", "SGD");
- optimizerOptions.put("eta", "fixed");
+ /**
+ * Returns the value of {@code LossFunctions.logisticLoss(target, predicted)} multiplied by
+ * the learning rate that {@code etaEstimator} yields for the current {@code count}.
+ */
+ @Override
+ protected float computeGradient(final float target, final float predicted) {
+ float eta = etaEstimator.eta(count);
+ float gradient = LossFunctions.logisticLoss(target, predicted);
+ return eta * gradient;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
index e1afe2f..3de56fd 100644
--- a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
@@ -18,10 +18,10 @@
*/
package hivemall.regression;
-import hivemall.optimizer.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
+import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
index 7dc8538..24b0556 100644
--- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
+++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
@@ -72,6 +72,16 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
protected transient Map<Object, FloatAccumulator> accumulated;
protected int sampled;
+ private boolean enableNewModel;
+
+ /** Creates a learner with {@code enableNewModel} switched off. */
+ public RegressionBaseUDTF() {
+ this(false);
+ }
+
+ /**
+ * @param enableNewModel stored as given; {@code initialize} uses it to choose between
+ * {@code createNewModel(null)} and {@code createModel()}
+ */
+ public RegressionBaseUDTF(boolean enableNewModel) {
+ this.enableNewModel = enableNewModel;
+ }
+
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length < 2) {
@@ -85,7 +95,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector
: featureInputOI;
- this.model = createModel();
+ this.model = enableNewModel? createNewModel(null) : createModel();
if (preloadedModelFile != null) {
loadPredictionModel(model, preloadedModelFile, featureOutputOI);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
new file mode 100644
index 0000000..dd9c4ec
--- /dev/null
+++ b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
@@ -0,0 +1,60 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.model;
+
+import static org.junit.Assert.assertEquals;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.lang.HalfFloat;
+
+import java.util.Random;
+
+import org.junit.Test;
+
+public class NewSpaceEfficientNewDenseModelTest {
+ // Stores the same random weights in a NewSpaceEfficientDenseModel and a NewDenseModel, then compares them.
+ @Test
+ public void testGetSet() {
+ final int size = 1 << 12; // 4096, passed to both model constructors
+
+ final NewSpaceEfficientDenseModel model1 = new NewSpaceEfficientDenseModel(size);
+ //model1.configureClock();
+ final NewDenseModel model2 = new NewDenseModel(size);
+ //model2.configureClock();
+
+ final Random rand = new Random(); // unseeded, so the sampled indices and weights are not reproducible across runs
+ for (int t = 0; t < 1000; t++) {
+ int i = rand.nextInt(size);
+ float f = HalfFloat.MAX_FLOAT * rand.nextFloat(); // scales a [0, 1) sample by HalfFloat.MAX_FLOAT
+ IWeightValue w = new WeightValue(f);
+ model1.set(i, w); // the same index and weight object go into both models
+ model2.set(i, w);
+ }
+
+ assertEquals(model2.size(), model1.size());
+
+ IMapIterator<Integer, IWeightValue> itor = model1.entries();
+ while (itor.next() != -1) { // NOTE(review): -1 looks like the end-of-iteration marker; confirm against IMapIterator
+ int k = itor.getKey();
+ float expected = itor.getValue().get();
+ float actual = model2.getWeight(k);
+ assertEquals(expected, actual, 32f); // absolute tolerance of 32; presumably absorbs half-float rounding, TODO confirm
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java b/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java
deleted file mode 100644
index e3a1ed4..0000000
--- a/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2015 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.model;
-
-import static org.junit.Assert.assertEquals;
-import hivemall.utils.collections.IMapIterator;
-import hivemall.utils.lang.HalfFloat;
-
-import java.util.Random;
-
-import org.junit.Test;
-
-public class SpaceEfficientDenseModelTest {
-
- @Test
- public void testGetSet() {
- final int size = 1 << 12;
-
- final SpaceEfficientDenseModel model1 = new SpaceEfficientDenseModel(size);
- //model1.configureClock();
- final DenseModel model2 = new DenseModel(size);
- //model2.configureClock();
-
- final Random rand = new Random();
- for (int t = 0; t < 1000; t++) {
- int i = rand.nextInt(size);
- float f = HalfFloat.MAX_FLOAT * rand.nextFloat();
- IWeightValue w = new WeightValue(f);
- model1.set(i, w);
- model2.set(i, w);
- }
-
- assertEquals(model2.size(), model1.size());
-
- IMapIterator<Integer, IWeightValue> itor = model1.entries();
- while (itor.next() != -1) {
- int k = itor.getKey();
- float expected = itor.getValue().get();
- float actual = model2.getWeight(k);
- assertEquals(expected, actual, 32f);
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
----------------------------------------------------------------------
diff --git a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
index 38792d8..ec6d556 100644
--- a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
+++ b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
@@ -18,9 +18,9 @@
*/
package hivemall.mix.server;
-import hivemall.model.DenseModel;
+import hivemall.model.NewDenseModel;
import hivemall.model.PredictionModel;
-import hivemall.model.SparseModel;
+import hivemall.model.NewSparseModel;
import hivemall.model.WeightValue;
import hivemall.mix.MixMessage.MixEventName;
import hivemall.mix.client.MixClient;
@@ -55,7 +55,7 @@ public class MixServerTest extends HivemallTestBase {
waitForState(server, ServerState.RUNNING);
- PredictionModel model = new DenseModel(16777216);
+ PredictionModel model = new NewDenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -93,7 +93,7 @@ public class MixServerTest extends HivemallTestBase {
waitForState(server, ServerState.RUNNING);
- PredictionModel model = new DenseModel(16777216);
+ PredictionModel model = new NewDenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -151,7 +151,7 @@ public class MixServerTest extends HivemallTestBase {
}
private static void invokeClient(String groupId, int serverPort) throws InterruptedException {
- PredictionModel model = new DenseModel(16777216);
+ PredictionModel model = new NewDenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -298,8 +298,8 @@ public class MixServerTest extends HivemallTestBase {
private static void invokeClient01(String groupId, int serverPort, boolean denseModel, boolean cancelMix)
throws InterruptedException {
- PredictionModel model = denseModel ? new DenseModel(100)
- : new SparseModel(100, false);
+ PredictionModel model = denseModel ? new NewDenseModel(100)
+ : new NewSparseModel(100, false);
model.configureClock();
MixClient client = null;
try {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
index 4fb74f1..c0ee72f 100644
--- a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
+++ b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
@@ -23,7 +23,7 @@ import java.util.logging.Logger
import org.scalatest.{BeforeAndAfter, FunSuite}
-import hivemall.model.{DenseModel, PredictionModel, WeightValue}
+import hivemall.model.{NewDenseModel, PredictionModel, WeightValue}
import hivemall.mix.MixMessage.MixEventName
import hivemall.mix.client.MixClient
import hivemall.mix.server.MixServer.ServerState
@@ -95,7 +95,7 @@ class MixServerSuite extends FunSuite with BeforeAndAfter {
ignore(testName) {
val clients = Executors.newCachedThreadPool()
val numClients = nclient
- val models = (0 until numClients).map(i => new DenseModel(ndims, false))
+ val models = (0 until numClients).map(i => new NewDenseModel(ndims, false))
(0 until numClients).map { i =>
clients.submit(new Runnable() {
override def run(): Unit = {