Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 08:02:01 UTC
[11/50] [abbrv] incubator-hivemall git commit: Add optimizer implementations
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/LossFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java
new file mode 100644
index 0000000..d11be9b
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java
@@ -0,0 +1,467 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import hivemall.utils.math.MathUtils;
+
+/**
+ * @see <a href="https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions">Vowpal Wabbit loss functions</a>
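+ *
+ * <p>A minimal usage sketch (illustrative):
+ * <pre>{@code
+ * LossFunction logloss = LossFunctions.getLossFunction(LossType.LogLoss);
+ * float loss = logloss.loss(0.8f, 1.f);  // loss at prediction p = 0.8 for label y = +1
+ * float dldp = logloss.dloss(0.8f, 1.f); // derivative of the loss w.r.t. p
+ * }</pre>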
+ */
+public final class LossFunctions {
+
+ public enum LossType {
+ SquaredLoss, LogLoss, HingeLoss, SquaredHingeLoss, QuantileLoss, EpsilonInsensitiveLoss
+ }
+
+ public static LossFunction getLossFunction(String type) {
+ if ("SquaredLoss".equalsIgnoreCase(type)) {
+ return new SquaredLoss();
+ } else if ("LogLoss".equalsIgnoreCase(type)) {
+ return new LogLoss();
+ } else if ("HingeLoss".equalsIgnoreCase(type)) {
+ return new HingeLoss();
+ } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) {
+ return new SquaredHingeLoss();
+ } else if ("QuantileLoss".equalsIgnoreCase(type)) {
+ return new QuantileLoss();
+ } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) {
+ return new EpsilonInsensitiveLoss();
+ }
+ throw new IllegalArgumentException("Unsupported type: " + type);
+ }
+
+ public static LossFunction getLossFunction(LossType type) {
+ switch (type) {
+ case SquaredLoss:
+ return new SquaredLoss();
+ case LogLoss:
+ return new LogLoss();
+ case HingeLoss:
+ return new HingeLoss();
+ case SquaredHingeLoss:
+ return new SquaredHingeLoss();
+ case QuantileLoss:
+ return new QuantileLoss();
+ case EpsilonInsensitiveLoss:
+ return new EpsilonInsensitiveLoss();
+ default:
+ throw new IllegalArgumentException("Unsupported type: " + type);
+ }
+ }
+
+ public interface LossFunction {
+
+ /**
+ * Evaluate the loss function.
+ *
+ * @param p The prediction, p = w^T x
+ * @param y The true value (aka target)
+ * @return The loss evaluated at `p` and `y`.
+ */
+ public float loss(float p, float y);
+
+ public double loss(double p, double y);
+
+ /**
+ * Evaluate the derivative of the loss function with respect to the prediction `p`.
+ *
+ * @param p The prediction, p = w^T x
+ * @param y The true value (aka target)
+ * @return The derivative of the loss function w.r.t. `p`.
+ */
+ public float dloss(float p, float y);
+
+ public boolean forBinaryClassification();
+
+ public boolean forRegression();
+
+ }
+
+ public static abstract class BinaryLoss implements LossFunction {
+
+ protected static void checkTarget(float y) {
+ if (!(y == 1.f || y == -1.f)) {
+ throw new IllegalArgumentException("target must be +1 or -1: " + y);
+ }
+ }
+
+ protected static void checkTarget(double y) {
+ if (!(y == 1.d || y == -1.d)) {
+ throw new IllegalArgumentException("target must be +1 or -1: " + y);
+ }
+ }
+
+ @Override
+ public boolean forBinaryClassification() {
+ return true;
+ }
+
+ @Override
+ public boolean forRegression() {
+ return false;
+ }
+ }
+
+ public static abstract class RegressionLoss implements LossFunction {
+
+ @Override
+ public boolean forBinaryClassification() {
+ return false;
+ }
+
+ @Override
+ public boolean forRegression() {
+ return true;
+ }
+
+ }
+
+ /**
+ * Squared loss for regression problems.
+ *
+ * If you're trying to minimize the mean error, use squared-loss.
+ */
+ public static final class SquaredLoss extends RegressionLoss {
+
+ @Override
+ public float loss(float p, float y) {
+ final float z = p - y;
+ return z * z * 0.5f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ final double z = p - y;
+ return z * z * 0.5d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ return p - y; // d/dp [0.5 * (p - y)^2] = p - y
+ }
+ }
+
+ /**
+ * Logistic regression loss for binary classification with y in {-1, 1}.
+ */
+ public static final class LogLoss extends BinaryLoss {
+
+ /**
+ * <code>logloss(p,y) = log(1+exp(-p*y))</code>
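+ * <p>
+ * For numerical stability, <code>log(1+exp(-z))</code> (with <code>z = y * p</code>) is
+ * approximated by <code>exp(-z)</code> when <code>z &gt; 18</code> and by <code>-z</code>
+ * when <code>z &lt; -18</code>.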
+ */
+ @Override
+ public float loss(float p, float y) {
+ checkTarget(y);
+
+ final float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z);
+ }
+ if (z < -18.f) {
+ return -z;
+ }
+ return (float) Math.log(1.d + Math.exp(-z));
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ checkTarget(y);
+
+ final double z = y * p;
+ if (z > 18.d) {
+ return Math.exp(-z);
+ }
+ if (z < -18.d) {
+ return -z;
+ }
+ return Math.log(1.d + Math.exp(-z));
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ checkTarget(y);
+
+ float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z) * -y;
+ }
+ if (z < -18.f) {
+ return -y;
+ }
+ return -y / ((float) Math.exp(z) + 1.f);
+ }
+ }
+
+ /**
+ * Hinge loss for binary classification tasks with y in {-1,1}.
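+ *
+ * <code>hingeLoss(p, y) = max(0, threshold - y * p)</code>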
+ */
+ public static final class HingeLoss extends BinaryLoss {
+
+ private float threshold;
+
+ public HingeLoss() {
+ this(1.f);
+ }
+
+ /**
+ * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM.
+ * When threshold=0.0, one gets the loss used by the Perceptron.
+ */
+ public HingeLoss(float threshold) {
+ this.threshold = threshold;
+ }
+
+ public void setThreshold(float threshold) {
+ this.threshold = threshold;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float loss = hingeLoss(p, y, threshold);
+ return (loss > 0.f) ? loss : 0.f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double loss = hingeLoss(p, y, threshold);
+ return (loss > 0.d) ? loss : 0.d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ float loss = hingeLoss(p, y, threshold);
+ return (loss > 0.f) ? -y : 0.f;
+ }
+ }
+
+ /**
+ * Squared Hinge loss for binary classification tasks with y in {-1,1}.
+ */
+ public static final class SquaredHingeLoss extends BinaryLoss {
+
+ @Override
+ public float loss(float p, float y) {
+ return squaredHingeLoss(p, y);
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ return squaredHingeLoss(p, y);
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ checkTarget(y);
+
+ float d = 1 - (y * p);
+ return (d > 0.f) ? -2.f * d * y : 0.f;
+ }
+
+ }
+
+ /**
+ * Quantile loss is useful for predicting rank/order when you do not mind the mean
+ * error increasing as long as the relative order is correct.
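+ *
+ * <code>loss = tau * (y - p)</code> if <code>y &gt; p</code>, else <code>(1 - tau) * (p - y)</code>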
+ *
+ * @see <a href="http://en.wikipedia.org/wiki/Quantile_regression">Quantile regression</a>
+ */
+ public static final class QuantileLoss extends RegressionLoss {
+
+ private float tau;
+
+ public QuantileLoss() {
+ this.tau = 0.5f;
+ }
+
+ public QuantileLoss(float tau) {
+ setTau(tau);
+ }
+
+ public void setTau(float tau) {
+ if (tau <= 0 || tau >= 1.0) {
+ throw new IllegalArgumentException("tau must be in range (0, 1): " + tau);
+ }
+ this.tau = tau;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float e = y - p;
+ if (e > 0.f) {
+ return tau * e;
+ } else {
+ return -(1.f - tau) * e;
+ }
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double e = y - p;
+ if (e > 0.d) {
+ return tau * e;
+ } else {
+ return -(1.d - tau) * e;
+ }
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ float e = y - p;
+ if (e == 0.f) {
+ return 0.f;
+ }
+ return (e > 0.f) ? -tau : (1.f - tau);
+ }
+
+ }
+
+ /**
+ * Epsilon-Insensitive loss used by Support Vector Regression (SVR).
+ * <code>loss = max(0, |y - p| - epsilon)</code>
+ */
+ public static final class EpsilonInsensitiveLoss extends RegressionLoss {
+
+ private float epsilon;
+
+ public EpsilonInsensitiveLoss() {
+ this(0.1f);
+ }
+
+ public EpsilonInsensitiveLoss(float epsilon) {
+ this.epsilon = epsilon;
+ }
+
+ public void setEpsilon(float epsilon) {
+ this.epsilon = epsilon;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float loss = Math.abs(y - p) - epsilon;
+ return (loss > 0.f) ? loss : 0.f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double loss = Math.abs(y - p) - epsilon;
+ return (loss > 0.d) ? loss : 0.d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ if ((y - p) > epsilon) {// real value > predicted value + epsilon
+ return -1.f;
+ }
+ if ((p - y) > epsilon) {// real value < predicted value - epsilon
+ return 1.f;
+ }
+ return 0.f;
+ }
+
+ }
+
+ public static float logisticLoss(final float target, final float predicted) {
+ if (predicted > -100.d) {
+ return target - (float) MathUtils.sigmoid(predicted);
+ } else {
+ return target;
+ }
+ }
+
+ public static float logLoss(final float p, final float y) {
+ BinaryLoss.checkTarget(y);
+
+ final float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z);
+ }
+ if (z < -18.f) {
+ return -z;
+ }
+ return (float) Math.log(1.d + Math.exp(-z));
+ }
+
+ public static double logLoss(final double p, final double y) {
+ BinaryLoss.checkTarget(y);
+
+ final double z = y * p;
+ if (z > 18.d) {
+ return Math.exp(-z);
+ }
+ if (z < -18.d) {
+ return -z;
+ }
+ return Math.log(1.d + Math.exp(-z));
+ }
+
+ public static float squaredLoss(float p, float y) {
+ final float z = p - y;
+ return z * z * 0.5f;
+ }
+
+ public static double squaredLoss(double p, double y) {
+ final double z = p - y;
+ return z * z * 0.5d;
+ }
+
+ public static float hingeLoss(final float p, final float y, final float threshold) {
+ BinaryLoss.checkTarget(y);
+
+ float z = y * p;
+ return threshold - z;
+ }
+
+ public static double hingeLoss(final double p, final double y, final double threshold) {
+ BinaryLoss.checkTarget(y);
+
+ double z = y * p;
+ return threshold - z;
+ }
+
+ public static float hingeLoss(float p, float y) {
+ return hingeLoss(p, y, 1.f);
+ }
+
+ public static double hingeLoss(double p, double y) {
+ return hingeLoss(p, y, 1.d);
+ }
+
+ public static float squaredHingeLoss(final float p, final float y) {
+ BinaryLoss.checkTarget(y);
+
+ float z = y * p;
+ float d = 1.f - z;
+ return (d > 0.f) ? (d * d) : 0.f;
+ }
+
+ public static double squaredHingeLoss(final double p, final double y) {
+ BinaryLoss.checkTarget(y);
+
+ double z = y * p;
+ double d = 1.d - z;
+ return (d > 0.d) ? d * d : 0.d;
+ }
+
+ /**
+ * Math.abs(target - predicted) - epsilon
+ */
+ public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) {
+ return Math.abs(target - predicted) - epsilon;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/Optimizer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java
new file mode 100644
index 0000000..863536c
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/Optimizer.java
@@ -0,0 +1,246 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import java.util.Map;
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.NotThreadSafe;
+
+import hivemall.model.WeightValue;
+import hivemall.model.IWeightValue;
+
+public interface Optimizer {
+
+ /**
+ * Update the weights of models through this interface.
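+ *
+ * <p>A minimal usage sketch (illustrative; the factories are defined in this commit):
+ * <pre>{@code
+ * Map<String, String> options = new HashMap<String, String>();
+ * options.put("optimizer", "AdaGrad");
+ * Optimizer opt = SparseOptimizerFactory.create(1024, options);
+ * float newWeight = opt.computeUpdatedValue("feat1", 0.f, 0.1f);
+ * opt.proceedStep();
+ * }</pre>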
+ */
+ float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient);
+
+ // Increment the step count, used to tune the learning rate
+ void proceedStep();
+
+ static abstract class OptimizerBase implements Optimizer {
+
+ protected final EtaEstimator etaImpl;
+ protected final Regularization regImpl;
+
+ protected int numStep = 1;
+
+ public OptimizerBase(final Map<String, String> options) {
+ this.etaImpl = EtaEstimator.get(options);
+ this.regImpl = Regularization.get(options);
+ }
+
+ @Override
+ public void proceedStep() {
+ numStep++;
+ }
+
+ // Update the given `weight` in-place for performance
+ protected void computeUpdateValue(
+ @Nonnull final IWeightValue weight, float gradient) {
+ float delta = computeUpdateValueImpl(weight, regImpl.regularize(weight.get(), gradient));
+ weight.set(weight.get() - etaImpl.eta(numStep) * delta);
+ }
+
+ // Compute the delta used to update the weight
+ protected float computeUpdateValueImpl(
+ @Nonnull final IWeightValue weight, float gradient) {
+ return gradient;
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class SGD extends OptimizerBase {
+
+ private final IWeightValue weightValueReused;
+
+ public SGD(final Map<String, String> options) {
+ super(options);
+ this.weightValueReused = new WeightValue(0.f);
+ }
+
+ @Override
+ public float computeUpdatedValue(
+ @Nonnull Object feature, float weight, float gradient) {
+ computeUpdateValue(weightValueReused, gradient);
+ return weightValueReused.get();
+ }
+
+ }
+
+ static abstract class AdaDelta extends OptimizerBase {
+
+ private final float decay;
+ private final float eps;
+ private final float scale;
+
+ public AdaDelta(Map<String, String> options) {
+ super(options);
+ float decay = 0.95f;
+ float eps = 1e-6f;
+ float scale = 100.0f;
+ if(options.containsKey("decay")) {
+ decay = Float.parseFloat(options.get("decay"));
+ }
+ if(options.containsKey("eps")) {
+ eps = Float.parseFloat(options.get("eps"));
+ }
+ if(options.containsKey("scale")) {
+ scale = Float.parseFloat(options.get("scale"));
+ }
+ this.decay = decay;
+ this.eps = eps;
+ this.scale = scale;
+ }
+
+ @Override
+ protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) {
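+ // AdaDelta (Zeiler, 2012): delta = sqrt((E[dx^2] + eps) / (E[g^2] + eps)) * gradient,
+ // where both accumulators decay by `decay`; the squared-gradient accumulator is
+ // stored descaled by `scale` to keep the float sum in a safe range.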
+ float old_scaled_sum_sqgrad = weight.getSumOfSquaredGradients();
+ float old_sum_squared_delta_x = weight.getSumOfSquaredDeltaX();
+ float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * gradient * (gradient / scale));
+ float delta = (float) Math.sqrt((old_sum_squared_delta_x + eps) / (new_scaled_sum_sqgrad * scale + eps)) * gradient;
+ float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x) + ((1.f - decay) * delta * delta);
+ weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
+ weight.setSumOfSquaredDeltaX(new_sum_squared_delta_x);
+ return delta;
+ }
+
+ }
+
+ static abstract class AdaGrad extends OptimizerBase {
+
+ private final float eps;
+ private final float scale;
+
+ public AdaGrad(Map<String, String> options) {
+ super(options);
+ float eps = 1.0f;
+ float scale = 100.0f;
+ if(options.containsKey("eps")) {
+ eps = Float.parseFloat(options.get("eps"));
+ }
+ if(options.containsKey("scale")) {
+ scale = Float.parseFloat(options.get("scale"));
+ }
+ this.eps = eps;
+ this.scale = scale;
+ }
+
+ @Override
+ protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) {
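+ // AdaGrad: delta = gradient / (sqrt(sum of squared gradients) + eps); the sum is
+ // stored descaled by `scale` (hence the `* scale` before taking the square root).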
+ float new_scaled_sum_sqgrad = weight.getSumOfSquaredGradients() + gradient * (gradient / scale);
+ float delta = gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps);
+ weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
+ return delta;
+ }
+
+ }
+
+ /**
+ * Adam, an algorithm for first-order gradient-based optimization of stochastic objective
+ * functions, based on adaptive estimates of lower-order moments.
+ *
+ * - D. P. Kingma and J. L. Ba: "ADAM: A Method for Stochastic Optimization." arXiv preprint arXiv:1412.6980v8, 2014.
+ */
+ static abstract class Adam extends OptimizerBase {
+
+ private final float beta;
+ private final float gamma;
+ private final float eps_hat;
+
+ public Adam(Map<String, String> options) {
+ super(options);
+ float beta = 0.9f;
+ float gamma = 0.999f;
+ float eps_hat = 1e-8f;
+ if(options.containsKey("beta")) {
+ beta = Float.parseFloat(options.get("beta"));
+ }
+ if(options.containsKey("gamma")) {
+ gamma = Float.parseFloat(options.get("gamma"));
+ }
+ if(options.containsKey("eps_hat")) {
+ eps_hat = Float.parseFloat(options.get("eps_hat"));
+ }
+ this.beta = beta;
+ this.gamma = gamma;
+ this.eps_hat = eps_hat;
+ }
+
+ @Override
+ protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) {
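+ // Adam: m and v are exponentially decayed estimates of the first and second
+ // moments of the gradient; the `hat` values apply the bias corrections
+ // 1/(1 - beta^t) and 1/(1 - gamma^t), and delta = m_hat / (sqrt(v_hat) + eps_hat).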
+ float val_m = beta * weight.getM() + (1.f - beta) * gradient;
+ float val_v = gamma * weight.getV() + (float) ((1.f - gamma) * Math.pow(gradient, 2.0));
+ float val_m_hat = val_m / (float) (1.f - Math.pow(beta, numStep));
+ float val_v_hat = val_v / (float) (1.f - Math.pow(gamma, numStep));
+ float delta = val_m_hat / (float) (Math.sqrt(val_v_hat) + eps_hat);
+ weight.setM(val_m);
+ weight.setV(val_v);
+ return delta;
+ }
+
+ }
+
+ static abstract class RDA extends OptimizerBase {
+
+ private final OptimizerBase optimizerImpl;
+
+ private final float lambda;
+
+ public RDA(final OptimizerBase optimizerImpl, Map<String, String> options) {
+ super(options);
+ // Currently, only the `AdaGrad` implementation is supported as the base optimizer
+ if(!(optimizerImpl instanceof AdaGrad)) {
+ throw new IllegalArgumentException(
+ optimizerImpl.getClass().getSimpleName()
+ + " currently does not support RDA regularization");
+ }
+ float lambda = 1e-6f;
+ if(options.containsKey("lambda")) {
+ lambda = Float.parseFloat(options.get("lambda"));
+ }
+ this.optimizerImpl = optimizerImpl;
+ this.lambda = lambda;
+ }
+
+ @Override
+ protected void computeUpdateValue(@Nonnull final IWeightValue weight, float gradient) {
+ float new_sum_grad = weight.getSumOfGradients() + gradient;
+ // sign(u_{t,i})
+ float sign = (new_sum_grad > 0.f)? 1.f : -1.f;
+ // |u_{t,i}|/t - \lambda
+ float meansOfGradients = (sign * new_sum_grad / numStep) - lambda;
+ if(meansOfGradients < 0.f) {
+ // x_{t,i} = 0
+ weight.set(0.f);
+ weight.setSumOfSquaredGradients(0.f);
+ weight.setSumOfGradients(0.f);
+ } else {
+ // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda)
+ float new_weight = -1.f * sign * etaImpl.eta(numStep) * numStep * optimizerImpl.computeUpdateValueImpl(weight, meansOfGradients);
+ weight.set(new_weight);
+ weight.setSumOfGradients(new_sum_grad);
+ }
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/Regularization.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Regularization.java b/core/src/main/java/hivemall/optimizer/Regularization.java
new file mode 100644
index 0000000..ce1ef7f
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/Regularization.java
@@ -0,0 +1,99 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import javax.annotation.Nonnull;
+import java.util.Map;
+
+public abstract class Regularization {
+
+ protected final float lambda;
+
+ public Regularization(final Map<String, String> options) {
+ float lambda = 1e-6f;
+ if(options.containsKey("lambda")) {
+ lambda = Float.parseFloat(options.get("lambda"));
+ }
+ this.lambda = lambda;
+ }
+
+ abstract float regularize(float weight, float gradient);
+
+ public static final class PassThrough extends Regularization {
+
+ public PassThrough(final Map<String, String> options) {
+ super(options);
+ }
+
+ @Override
+ public float regularize(float weight, float gradient) {
+ return gradient;
+ }
+
+ }
+
+ public static final class L1 extends Regularization {
+
+ public L1(Map<String, String> options) {
+ super(options);
+ }
+
+ @Override
+ public float regularize(float weight, float gradient) {
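+ // L1: add the subgradient of lambda * |w|, i.e., lambda * sign(w)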
+ return gradient + lambda * (weight > 0.f? 1.f : -1.f);
+ }
+
+ }
+
+ public static final class L2 extends Regularization {
+
+ public L2(final Map<String, String> options) {
+ super(options);
+ }
+
+ @Override
+ public float regularize(float weight, float gradient) {
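+ // L2: add the gradient of (lambda / 2) * w^2, i.e., lambda * w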
+ return gradient + lambda * weight;
+ }
+
+ }
+
+ @Nonnull
+ public static Regularization get(@Nonnull final Map<String, String> options)
+ throws IllegalArgumentException {
+ final String regName = options.get("regularization");
+ if (regName == null) {
+ return new PassThrough(options);
+ }
+ if(regName.toLowerCase().equals("no")) {
+ return new PassThrough(options);
+ } else if(regName.toLowerCase().equals("l1")) {
+ return new L1(options);
+ } else if(regName.toLowerCase().equals("l2")) {
+ return new L2(options);
+ } else if(regName.toLowerCase().equals("rda")) {
+ // Return `PassThrough` because we need special handling for RDA.
+ // See an implementation of `Optimizer#RDA`.
+ return new PassThrough(options);
+ } else {
+ throw new IllegalArgumentException("Unsupported regularization name: " + regName);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
new file mode 100644
index 0000000..a74d0da
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
@@ -0,0 +1,171 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.NotThreadSafe;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import hivemall.optimizer.Optimizer.OptimizerBase;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue;
+import hivemall.utils.collections.OpenHashMap;
+
+public final class SparseOptimizerFactory {
+ private static final Log logger = LogFactory.getLog(SparseOptimizerFactory.class);
+
+ @Nonnull
+ public static Optimizer create(int ndims, @Nonnull Map<String, String> options) {
+ final String optimizerName = options.get("optimizer");
+ if(optimizerName != null) {
+ OptimizerBase optimizerImpl;
+ if(optimizerName.toLowerCase().equals("sgd")) {
+ optimizerImpl = new Optimizer.SGD(options);
+ } else if(optimizerName.toLowerCase().equals("adadelta")) {
+ optimizerImpl = new AdaDelta(ndims, options);
+ } else if(optimizerName.toLowerCase().equals("adagrad")) {
+ optimizerImpl = new AdaGrad(ndims, options);
+ } else if(optimizerName.toLowerCase().equals("adam")) {
+ optimizerImpl = new Adam(ndims, options);
+ } else {
+ throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
+ }
+
+ logger.info("set " + optimizerImpl.getClass().getSimpleName()
+ + " as an optimizer: " + options);
+
+ // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`.
+ if(options.get("regularization") != null
+ && options.get("regularization").toLowerCase().equals("rda")) {
+ optimizerImpl = new RDA(ndims, optimizerImpl, options);
+ }
+
+ return optimizerImpl;
+ }
+ throw new IllegalArgumentException("`optimizer` not defined");
+ }
+
+ @NotThreadSafe
+ static final class AdaDelta extends Optimizer.AdaDelta {
+
+ private final OpenHashMap<Object, IWeightValue> auxWeights;
+
+ public AdaDelta(int size, Map<String, String> options) {
+ super(options);
+ this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ IWeightValue auxWeight;
+ if(auxWeights.containsKey(feature)) {
+ auxWeight = auxWeights.get(feature);
+ auxWeight.set(weight);
+ } else {
+ auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
+ auxWeights.put(feature, auxWeight);
+ }
+ computeUpdateValue(auxWeight, gradient);
+ return auxWeight.get();
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class AdaGrad extends Optimizer.AdaGrad {
+
+ private final OpenHashMap<Object, IWeightValue> auxWeights;
+
+ public AdaGrad(int size, Map<String, String> options) {
+ super(options);
+ this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ IWeightValue auxWeight;
+ if(auxWeights.containsKey(feature)) {
+ auxWeight = auxWeights.get(feature);
+ auxWeight.set(weight);
+ } else {
+ auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
+ auxWeights.put(feature, auxWeight);
+ }
+ computeUpdateValue(auxWeight, gradient);
+ return auxWeight.get();
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class Adam extends Optimizer.Adam {
+
+ private final OpenHashMap<Object, IWeightValue> auxWeights;
+
+ public Adam(int size, Map<String, String> options) {
+ super(options);
+ this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ IWeightValue auxWeight;
+ if(auxWeights.containsKey(feature)) {
+ auxWeight = auxWeights.get(feature);
+ auxWeight.set(weight);
+ } else {
+ auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
+ auxWeights.put(feature, auxWeight);
+ }
+ computeUpdateValue(auxWeight, gradient);
+ return auxWeight.get();
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class RDA extends Optimizer.RDA {
+
+ private final OpenHashMap<Object, IWeightValue> auxWeights;
+
+ public RDA(int size, OptimizerBase optimizerImpl, Map<String, String> options) {
+ super(optimizerImpl, options);
+ this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ IWeightValue auxWeight;
+ if(auxWeights.containsKey(feature)) {
+ auxWeight = auxWeights.get(feature);
+ auxWeight.set(weight);
+ } else {
+ auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
+ auxWeights.put(feature, auxWeight);
+ }
+ computeUpdateValue(auxWeight, gradient);
+ return auxWeight.get();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
index b81a4bf..0c964c8 100644
--- a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
@@ -18,7 +18,7 @@
*/
package hivemall.regression;
-import hivemall.common.LossFunctions;
+import hivemall.optimizer.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
index e807340..50dc9b5 100644
--- a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
+++ b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
@@ -18,123 +18,14 @@
*/
package hivemall.regression;
-import hivemall.common.LossFunctions;
-import hivemall.model.FeatureValue;
-import hivemall.model.IWeightValue;
-import hivemall.model.WeightValue.WeightValueParamsF2;
-import hivemall.utils.lang.Primitives;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-
/**
* ADADELTA: AN ADAPTIVE LEARNING RATE METHOD.
*/
-@Description(
- name = "train_adadelta_regr",
- value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
- + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
-public final class AdaDeltaUDTF extends RegressionBaseUDTF {
-
- private float decay;
- private float eps;
- private float scaling;
-
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- final int numArgs = argOIs.length;
- if (numArgs != 2 && numArgs != 3) {
- throw new UDFArgumentException(
- "AdaDeltaUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
- }
-
- StructObjectInspector oi = super.initialize(argOIs);
- model.configureParams(true, true, false);
- return oi;
- }
-
- @Override
- protected Options getOptions() {
- Options opts = super.getOptions();
- opts.addOption("rho", "decay", true, "Decay rate [default 0.95]");
- opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1e-6]");
- opts.addOption("scale", true,
- "Internal scaling/descaling factor for cumulative weights [100]");
- return opts;
- }
-
- @Override
- protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = super.processOptions(argOIs);
- if (cl == null) {
- this.decay = 0.95f;
- this.eps = 1e-6f;
- this.scaling = 100f;
- } else {
- this.decay = Primitives.parseFloat(cl.getOptionValue("decay"), 0.95f);
- this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1E-6f);
- this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
- }
- return cl;
- }
-
- @Override
- protected final void checkTargetValue(final float target) throws UDFArgumentException {
- if (target < 0.f || target > 1.f) {
- throw new UDFArgumentException("target must be in range 0 to 1: " + target);
- }
- }
-
- @Override
- protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
- float gradient = LossFunctions.logisticLoss(target, predicted);
- onlineUpdate(features, gradient);
- }
-
- @Override
- protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
- final float g_g = gradient * (gradient / scaling);
-
- for (FeatureValue f : features) {// w[i] += y * x[i]
- if (f == null) {
- continue;
- }
- Object x = f.getFeature();
- float xi = f.getValueAsFloat();
-
- IWeightValue old_w = model.get(x);
- IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g);
- model.set(x, new_w);
- }
- }
-
- @Nonnull
- protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi,
- final float gradient, final float g_g) {
- float old_w = 0.f;
- float old_scaled_sum_sqgrad = 0.f;
- float old_sum_squared_delta_x = 0.f;
- if (old != null) {
- old_w = old.get();
- old_scaled_sum_sqgrad = old.getSumOfSquaredGradients();
- old_sum_squared_delta_x = old.getSumOfSquaredDeltaX();
- }
+@Deprecated
+public final class AdaDeltaUDTF extends GeneralRegressionUDTF {
- float new_scaled_sum_sq_grad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * g_g);
- float dx = (float) Math.sqrt((old_sum_squared_delta_x + eps)
- / (old_scaled_sum_sqgrad * scaling + eps))
- * gradient;
- float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x)
- + ((1.f - decay) * dx * dx);
- float new_w = old_w + (dx * xi);
- return new WeightValueParamsF2(new_w, new_scaled_sum_sq_grad, new_sum_squared_delta_x);
+ public AdaDeltaUDTF() {
+ optimizerOptions.put("optimizer", "AdaDelta");
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AdaGradUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AdaGradUDTF.java b/core/src/main/java/hivemall/regression/AdaGradUDTF.java
index de48d97..4b5f019 100644
--- a/core/src/main/java/hivemall/regression/AdaGradUDTF.java
+++ b/core/src/main/java/hivemall/regression/AdaGradUDTF.java
@@ -18,124 +18,14 @@
*/
package hivemall.regression;
-import hivemall.common.LossFunctions;
-import hivemall.model.FeatureValue;
-import hivemall.model.IWeightValue;
-import hivemall.model.WeightValue.WeightValueParamsF1;
-import hivemall.utils.lang.Primitives;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-
/**
* ADAGRAD algorithm with element-wise adaptive learning rates.
*/
-@Description(
- name = "train_adagrad_regr",
- value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
- + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
-public final class AdaGradUDTF extends RegressionBaseUDTF {
-
- private float eta;
- private float eps;
- private float scaling;
-
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- final int numArgs = argOIs.length;
- if (numArgs != 2 && numArgs != 3) {
- throw new UDFArgumentException(
- "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
- }
-
- StructObjectInspector oi = super.initialize(argOIs);
- model.configureParams(true, false, false);
- return oi;
- }
-
- @Override
- protected Options getOptions() {
- Options opts = super.getOptions();
- opts.addOption("eta", "eta0", true, "The initial learning rate [default 1.0]");
- opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]");
- opts.addOption("scale", true,
- "Internal scaling/descaling factor for cumulative weights [100]");
- return opts;
- }
-
- @Override
- protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = super.processOptions(argOIs);
- if (cl == null) {
- this.eta = 1.f;
- this.eps = 1.f;
- this.scaling = 100f;
- } else {
- this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f);
- this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1.f);
- this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
- }
- return cl;
- }
-
- @Override
- protected final void checkTargetValue(final float target) throws UDFArgumentException {
- if (target < 0.f || target > 1.f) {
- throw new UDFArgumentException("target must be in range 0 to 1: " + target);
- }
- }
-
- @Override
- protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
- float gradient = LossFunctions.logisticLoss(target, predicted);
- onlineUpdate(features, gradient);
- }
-
- @Override
- protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
- final float g_g = gradient * (gradient / scaling);
-
- for (FeatureValue f : features) {// w[i] += y * x[i]
- if (f == null) {
- continue;
- }
- Object x = f.getFeature();
- float xi = f.getValueAsFloat();
-
- IWeightValue old_w = model.get(x);
- IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g);
- model.set(x, new_w);
- }
- }
-
- @Nonnull
- protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi,
- final float gradient, final float g_g) {
- float old_w = 0.f;
- float scaled_sum_sqgrad = 0.f;
-
- if (old != null) {
- old_w = old.get();
- scaled_sum_sqgrad = old.getSumOfSquaredGradients();
- }
- scaled_sum_sqgrad += g_g;
-
- float coeff = eta(scaled_sum_sqgrad) * gradient;
- float new_w = old_w + (coeff * xi);
- return new WeightValueParamsF1(new_w, scaled_sum_sqgrad);
- }
+@Deprecated
+public final class AdaGradUDTF extends GeneralRegressionUDTF {
- protected float eta(final double scaledSumOfSquaredGradients) {
- double sumOfSquaredGradients = scaledSumOfSquaredGradients * scaling;
- //return eta / (float) Math.sqrt(sumOfSquaredGradients);
- return eta / (float) Math.sqrt(eps + sumOfSquaredGradients); // always less than eta0
+ public AdaGradUDTF() {
+ optimizerOptions.put("optimizer", "AdaGrad");
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
new file mode 100644
index 0000000..2a8b543
--- /dev/null
+++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
@@ -0,0 +1,125 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.regression;
+
+import java.util.HashMap;
+import java.util.Map;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
+import hivemall.optimizer.LossFunctions;
+import hivemall.model.FeatureValue;
+
+/**
+ * A general regression class with replaceable optimization functions.
+ */
+public class GeneralRegressionUDTF extends RegressionBaseUDTF {
+
+ protected final Map<String, String> optimizerOptions;
+
+ public GeneralRegressionUDTF() {
+ this.optimizerOptions = new HashMap<String, String>();
+ // Set default values
+ optimizerOptions.put("optimizer", "adadelta");
+ optimizerOptions.put("eta", "fixed");
+ optimizerOptions.put("eta0", "1.0");
+ optimizerOptions.put("t", "10000");
+ optimizerOptions.put("power_t", "0.1");
+ optimizerOptions.put("eps", "1e-6");
+ optimizerOptions.put("rho", "0.95");
+ optimizerOptions.put("scale", "100.0");
+ optimizerOptions.put("lambda", "1.0");
+ }
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if(argOIs.length != 2 && argOIs.length != 3) {
+ throw new UDFArgumentException(
+ this.getClass().getSimpleName()
+ + " takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target "
+ + "[, constant string options]");
+ }
+ return super.initialize(argOIs);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("optimizer", "opt", true, "Optimizer to update weights [default: adadelta]");
+ opts.addOption("eta", true, " ETA estimator to compute delta [default: fixed]");
+ opts.addOption("eta0", true, "Initial learning rate [default 1.0]");
+ opts.addOption("t", "total_steps", true, "Total of n_samples * epochs time steps [default: 10000]");
+ opts.addOption("power_t", true, "Exponent for inverse scaling learning rate [default 0.1]");
+ opts.addOption("eps", true, "Denominator value of AdaDelta/AdaGrad [default 1e-6]");
+ opts.addOption("rho", "decay", true, "Decay rate [default 0.95]");
+ opts.addOption("scale", true, "Scaling factor for cumulative weights [100.0]");
+ opts.addOption("regularization", "reg", true, "Regularization type [default not-defined]");
+ opts.addOption("lambda", true, "Regularization term on weights [default 1.0]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final CommandLine cl = super.processOptions(argOIs);
+ if(cl != null) {
+ for(final Option opt: cl.getOptions()) {
+ optimizerOptions.put(opt.getOpt(), opt.getValue());
+ }
+ }
+ return cl;
+ }
+
+ @Override
+ protected Map<String, String> getOptimzierOptions() {
+ return optimizerOptions;
+ }
+
+ @Override
+ protected final void checkTargetValue(final float target) throws UDFArgumentException {
+ if(target < 0.f || target > 1.f) {
+ throw new UDFArgumentException("target must be in range 0 to 1: " + target);
+ }
+ }
+
+ @Override
+ protected void update(@Nonnull final FeatureValue[] features, final float target,
+ final float predicted) {
+ if(is_mini_batch) {
+ throw new UnsupportedOperationException(
+ this.getClass().getSimpleName() + " does not support `is_mini_batch` mode");
+ } else {
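+ // logisticLoss() returns (target - sigmoid(predicted)), i.e., the negative
+ // gradient of the log loss w.r.t. the prediction; negating it and scaling
+ // by x_i gives the gradient w.r.t. each weight.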
+ float loss = LossFunctions.logisticLoss(target, predicted);
+ for(FeatureValue f : features) {
+ Object feature = f.getFeature();
+ float xi = f.getValueAsFloat();
+ float weight = model.getWeight(feature);
+ float new_weight = optimizerImpl.computeUpdatedValue(feature, weight, -loss * xi);
+ model.setWeight(feature, new_weight);
+ }
+ optimizerImpl.proceedStep();
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/LogressUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/LogressUDTF.java b/core/src/main/java/hivemall/regression/LogressUDTF.java
index ca3da71..ea05da3 100644
--- a/core/src/main/java/hivemall/regression/LogressUDTF.java
+++ b/core/src/main/java/hivemall/regression/LogressUDTF.java
@@ -18,65 +18,12 @@
*/
package hivemall.regression;
-import hivemall.common.EtaEstimator;
-import hivemall.common.LossFunctions;
+@Deprecated
+public final class LogressUDTF extends GeneralRegressionUDTF {
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-
-@Description(
- name = "logress",
- value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
- + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
-public final class LogressUDTF extends RegressionBaseUDTF {
-
- private EtaEstimator etaEstimator;
-
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- final int numArgs = argOIs.length;
- if (numArgs != 2 && numArgs != 3) {
- throw new UDFArgumentException(
- "LogressUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
- }
-
- return super.initialize(argOIs);
- }
-
- @Override
- protected Options getOptions() {
- Options opts = super.getOptions();
- opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps");
- opts.addOption("power_t", true,
- "The exponent for inverse scaling learning rate [default 0.1]");
- opts.addOption("eta0", true, "The initial learning rate [default 0.1]");
- return opts;
- }
-
- @Override
- protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = super.processOptions(argOIs);
-
- this.etaEstimator = EtaEstimator.get(cl);
- return cl;
- }
-
- @Override
- protected void checkTargetValue(final float target) throws UDFArgumentException {
- if (target < 0.f || target > 1.f) {
- throw new UDFArgumentException("target must be in range 0 to 1: " + target);
- }
- }
-
- @Override
- protected float computeUpdate(final float target, final float predicted) {
- float eta = etaEstimator.eta(count);
- float gradient = LossFunctions.logisticLoss(target, predicted);
- return eta * gradient;
+ public LogressUDTF() {
+ optimizerOptions.put("optimizer", "SGD");
+ optimizerOptions.put("eta", "fixed");
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
index c089946..e1afe2f 100644
--- a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
@@ -18,7 +18,7 @@
*/
package hivemall.regression;
-import hivemall.common.LossFunctions;
+import hivemall.optimizer.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
index 561d4f7..7dc8538 100644
--- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
+++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
@@ -25,6 +25,7 @@ import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.optimizer.Optimizer;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.FloatAccumulator;
@@ -64,6 +65,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
private boolean parseFeature;
protected PredictionModel model;
+ protected Optimizer optimizerImpl;
protected int count;
// The accumulated delta of each weight values.
@@ -87,6 +89,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
if (preloadedModelFile != null) {
loadPredictionModel(model, preloadedModelFile, featureOutputOI);
}
+ this.optimizerImpl = createOptimizer();
this.count = 0;
this.sampled = 0;
@@ -235,7 +238,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
protected void update(@Nonnull final FeatureValue[] features, final float target,
final float predicted) {
- final float grad = computeUpdate(target, predicted);
+ final float grad = computeGradient(target, predicted);
if (is_mini_batch) {
accumulateUpdate(features, grad);
@@ -247,12 +250,9 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
}
}
- protected float computeUpdate(float target, float predicted) {
- throw new IllegalStateException();
- }
-
- protected IWeightValue getNewWeight(IWeightValue old_w, float delta) {
- throw new IllegalStateException();
+ // Compute a gradient by using a loss function in derived classes
+ protected float computeGradient(float target, float predicted) {
+ throw new UnsupportedOperationException();
}
protected final void accumulateUpdate(@Nonnull final FeatureValue[] features, final float coeff) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/test/java/hivemall/optimizer/OptimizerTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/optimizer/OptimizerTest.java b/core/src/test/java/hivemall/optimizer/OptimizerTest.java
new file mode 100644
index 0000000..cfcfa79
--- /dev/null
+++ b/core/src/test/java/hivemall/optimizer/OptimizerTest.java
@@ -0,0 +1,172 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public final class OptimizerTest {
+
+ @Test
+ public void testIllegalOptimizer() {
+ try {
+ final Map<String, String> emptyOptions = new HashMap<String, String>();
+ DenseOptimizerFactory.create(1024, emptyOptions);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ try {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "illegal");
+ DenseOptimizerFactory.create(1024, options);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ try {
+ final Map<String, String> emptyOptions = new HashMap<String, String>();
+ SparseOptimizerFactory.create(1024, emptyOptions);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ try {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "illegal");
+ SparseOptimizerFactory.create(1024, options);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ }
+
+ @Test
+ public void testOptimizerFactory() {
+ final Map<String, String> options = new HashMap<String, String>();
+ final String[] regTypes = new String[] {"NO", "L1", "L2"};
+ for(final String regType : regTypes) {
+ options.put("optimizer", "SGD");
+ options.put("regularization", regType);
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof Optimizer.SGD);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof Optimizer.SGD);
+ }
+ for(final String regType : regTypes) {
+ options.put("optimizer", "AdaDelta");
+ options.put("regularization", regType);
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdaDelta);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdaDelta);
+ }
+ for(final String regType : regTypes) {
+ options.put("optimizer", "AdaGrad");
+ options.put("regularization", regType);
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdaGrad);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdaGrad);
+ }
+ for(final String regType : regTypes) {
+ options.put("optimizer", "Adam");
+ options.put("regularization", regType);
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.Adam);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.Adam);
+ }
+
+ // We need special handling for `Optimizer#RDA`
+ options.put("optimizer", "AdaGrad");
+ options.put("regularization", "RDA");
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.RDA);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.RDA);
+
+ // `SGD`, `AdaDelta`, and `Adam` currently do not support `RDA`
+ for(final String optimizerType : new String[] {"SGD", "AdaDelta", "Adam"}) {
+ options.put("optimizer", optimizerType);
+ try {
+ DenseOptimizerFactory.create(8, options);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ try {
+ SparseOptimizerFactory.create(8, options);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ }
+ }
+
+ private void testUpdateWeights(Optimizer optimizer, int numUpdates, int initSize) {
+ final float[] weights = new float[initSize * 2];
+ final Random rnd = new Random();
+ try {
+ for(int i = 0; i < numUpdates; i++) {
+ int index = rnd.nextInt(initSize);
+ weights[index] = optimizer.computeUpdatedValue(index, weights[index], 0.1f);
+ }
+ for(int i = 0; i < numUpdates; i++) {
+ int index = rnd.nextInt(initSize * 2);
+ weights[index] = optimizer.computeUpdatedValue(index, weights[index], 0.1f);
+ }
+ } catch(Exception e) {
+ Assert.fail("failed to update weights: " + e.getMessage());
+ }
+ }
+
+ private void testOptimizer(final Map<String, String> options, int numUpdates, int initSize) {
+ final Map<String, String> testOptions = new HashMap<String, String>(options);
+ final String[] regTypes = new String[] {"NO", "L1", "L2", "RDA"};
+ for(final String regType : regTypes) {
+ if("RDA".equals(regType) && !"AdaGrad".equalsIgnoreCase(options.get("optimizer"))) {
+ continue; // the factories support RDA only on top of AdaGrad
+ }
+ testOptions.put("regularization", regType);
+ testUpdateWeights(DenseOptimizerFactory.create(initSize, testOptions), numUpdates, initSize);
+ testUpdateWeights(SparseOptimizerFactory.create(initSize, testOptions), numUpdates, initSize);
+ }
+ }
+
+ @Test
+ public void testSGDOptimizer() {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "SGD");
+ testOptimizer(options, 65536, 1024);
+ }
+
+ @Test
+ public void testAdaDeltaOptimizer() {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "AdaDelta");
+ testOptimizer(options, 65536, 1024);
+ }
+
+ @Test
+ public void testAdaGradOptimizer() {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "AdaGrad");
+ testOptimizer(options, 65536, 1024);
+ }
+
+ @Test
+ public void testAdamOptimizer() {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "Adam");
+ testOptimizer(options, 65536, 1024);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
----------------------------------------------------------------------
diff --git a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
index 0b1455c..38792d8 100644
--- a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
+++ b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
@@ -55,7 +55,7 @@ public class MixServerTest extends HivemallTestBase {
waitForState(server, ServerState.RUNNING);
- PredictionModel model = new DenseModel(16777216, false);
+ PredictionModel model = new DenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -93,7 +93,7 @@ public class MixServerTest extends HivemallTestBase {
waitForState(server, ServerState.RUNNING);
- PredictionModel model = new DenseModel(16777216, false);
+ PredictionModel model = new DenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -151,7 +151,7 @@ public class MixServerTest extends HivemallTestBase {
}
private static void invokeClient(String groupId, int serverPort) throws InterruptedException {
- PredictionModel model = new DenseModel(16777216, false);
+ PredictionModel model = new DenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -296,10 +296,10 @@ public class MixServerTest extends HivemallTestBase {
serverExec.shutdown();
}
- private static void invokeClient01(String groupId, int serverPort, boolean denseModel,
- boolean cancelMix) throws InterruptedException {
- PredictionModel model = denseModel ? new DenseModel(100, false) : new SparseModel(100,
- false);
+ private static void invokeClient01(String groupId, int serverPort, boolean denseModel, boolean cancelMix)
+ throws InterruptedException {
+ PredictionModel model = denseModel ? new DenseModel(100)
+ : new SparseModel(100, false);
model.configureClock();
MixClient client = null;
try {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index bab5a29..ccdace0 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -13,6 +13,9 @@ CREATE FUNCTION hivemall_version as 'hivemall.HivemallVersionUDF' USING JAR '${h
-- binary classification --
---------------------------
+DROP FUNCTION IF EXISTS train_classifier;
+CREATE FUNCTION train_classifier as 'hivemall.classifier.GeneralClassifierUDTF' USING JAR '${hivemall_jar}';
+
DROP FUNCTION IF EXISTS train_perceptron;
CREATE FUNCTION train_perceptron as 'hivemall.classifier.PerceptronUDTF' USING JAR '${hivemall_jar}';
@@ -45,7 +48,7 @@ CREATE FUNCTION train_adagrad_rda as 'hivemall.classifier.AdaGradRDAUDTF' USING
--------------------------------
-- Multiclass classification --
---------------------------------
+--------------------------------
DROP FUNCTION IF EXISTS train_multiclass_perceptron;
CREATE FUNCTION train_multiclass_perceptron as 'hivemall.classifier.multiclass.MulticlassPerceptronUDTF' USING JAR '${hivemall_jar}';
@@ -312,6 +315,13 @@ CREATE FUNCTION tf as 'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivem
-- Regression functions --
--------------------------
+DROP FUNCTION IF EXISTS train_regression;
+CREATE FUNCTION train_regression as 'hivemall.regression.GeneralRegressionUDTF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS train_logregr;
+CREATE FUNCTION train_logregr as 'hivemall.regression.LogressUDTF' USING JAR '${hivemall_jar}';
+
+-- alias for backward compatibility
DROP FUNCTION IF EXISTS logress;
CREATE FUNCTION logress as 'hivemall.regression.LogressUDTF' USING JAR '${hivemall_jar}';
@@ -599,3 +609,4 @@ CREATE FUNCTION xgboost_predict AS 'hivemall.xgboost.tools.XGBoostPredictUDTF' U
DROP FUNCTION xgboost_multiclass_predict;
CREATE FUNCTION xgboost_multiclass_predict AS 'hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF' USING JAR '${hivemall_jar}';
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 315b4d2..d60fd7f 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -9,6 +9,9 @@ create temporary function hivemall_version as 'hivemall.HivemallVersionUDF';
-- binary classification --
---------------------------
+drop temporary function train_classifier;
+create temporary function train_classifier as 'hivemall.classifier.GeneralClassifierUDTF';
+
drop temporary function train_perceptron;
create temporary function train_perceptron as 'hivemall.classifier.PerceptronUDTF';
@@ -308,6 +311,13 @@ create temporary function tf as 'hivemall.ftvec.text.TermFrequencyUDAF';
-- Regression functions --
--------------------------
+drop temporary function train_regression;
+create temporary function train_regression as 'hivemall.regression.GeneralRegressionUDTF';
+
+drop temporary function train_logregr;
+create temporary function train_logregr as 'hivemall.regression.LogressUDTF';
+
+-- alias for backward compatibility
drop temporary function logress;
create temporary function logress as 'hivemall.regression.LogressUDTF';
@@ -628,5 +638,3 @@ log(10, n_docs / max2(1,df_t)) + 1.0;
create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
-
-