You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:54:38 UTC
[10/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
new file mode 100644
index 0000000..f56814b
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
@@ -0,0 +1,334 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Does cross-fold validation of log-likelihood and AUC on several online logistic regression
+ * models. Each record is passed to all but one of the models for training and to the remaining
+ * model for evaluation. In order to maintain proper segregation between the different folds across
+ * training data iterations, data should either be passed to this learner in the same order each
+ * time the training data is traversed or a tracking key such as the file offset of the training
+ * record should be passed with each training example.
+ */
+public class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Writable {
+  // number of training examples seen since the last reset; also used as the default tracking key
+  private int record;
+  // minimum score to be used for computing log likelihood
+  private static final double MIN_SCORE = 1.0e-50;
+  // number of slots in the parameters array; readFields() always reads exactly this many
+  private static final int NUM_PARAMETERS = 4;
+  private OnlineAuc auc = new GlobalOnlineAuc();
+  // running average of the log of the held-out score for the correct category
+  private double logLikelihood;
+  // one model per fold
+  private final List<OnlineLogisticRegression> models = new ArrayList<>();
+
+  // lambda, learningRate, perTermOffset, perTermExponent
+  private double[] parameters = new double[NUM_PARAMETERS];
+  private int numFeatures;
+  private PriorFunction prior;
+  // running average of whether the held-out model's top category matched the actual one
+  private double percentCorrect;
+
+  // caps the divisor of the running averages so that they can track recent examples
+  private int windowSize = Integer.MAX_VALUE;
+
+  /**
+   * Creates a learner with no models. Exists so that the state can be filled in afterwards,
+   * for instance by {@link #readFields(DataInput)} or {@link #addModel(OnlineLogisticRegression)}.
+   */
+  public CrossFoldLearner() {
+  }
+
+  /**
+   * Creates a learner holding one logistic regression model per fold.
+   *
+   * @param folds         The number of models (folds) to create.
+   * @param numCategories The number of target categories, passed to each model.
+   * @param numFeatures   The number of features, passed to each model.
+   * @param prior         The prior, passed to each model.
+   */
+  public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
+    this.numFeatures = numFeatures;
+    this.prior = prior;
+    for (int i = 0; i < folds; i++) {
+      OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
+      model.alpha(1).stepOffset(0).decayExponent(0);
+      models.add(model);
+    }
+  }
+
+  // -------- builder-like configuration methods
+  // Each of these forwards the setting to every model and returns this learner for chaining.
+
+  public CrossFoldLearner lambda(double v) {
+    for (OnlineLogisticRegression model : models) {
+      model.lambda(v);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner learningRate(double x) {
+    for (OnlineLogisticRegression model : models) {
+      model.learningRate(x);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner stepOffset(int x) {
+    for (OnlineLogisticRegression model : models) {
+      model.stepOffset(x);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner decayExponent(double x) {
+    for (OnlineLogisticRegression model : models) {
+      model.decayExponent(x);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner alpha(double alpha) {
+    for (OnlineLogisticRegression model : models) {
+      model.alpha(alpha);
+    }
+    return this;
+  }
+
+  // -------- training methods
+
+  /**
+   * Trains using the current record count as the tracking key and no group key.
+   */
+  @Override
+  public void train(int actual, Vector instance) {
+    train(record, null, actual, instance);
+  }
+
+  /**
+   * Trains with an explicit tracking key and no group key.
+   */
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  /**
+   * Passes one example to the models. The model whose index equals the tracking key modulo the
+   * number of models only scores the example, and that score updates the log-likelihood,
+   * percent-correct and (for two categories) AUC estimates. Every other model is trained on it.
+   *
+   * @param trackingKey Selects which model is held out for this example.
+   * @param groupKey    Passed through to the AUC evaluator and to the trained models.
+   * @param actual      The target category of the example.
+   * @param instance    The feature vector of the example.
+   */
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    record++;
+    if (models.isEmpty()) {
+      // no folds to evaluate or train; only the record count advances
+      return;
+    }
+
+    // both values are the same for every model, so they are computed once per example
+    long heldOut = mod(trackingKey, models.size());
+    int averagingWindow = Math.min(record, windowSize);
+
+    int k = 0;
+    for (OnlineLogisticRegression model : models) {
+      if (k == heldOut) {
+        Vector v = model.classifyFull(instance);
+        // floor the score so that the logarithm stays finite
+        double score = Math.max(v.get(actual), MIN_SCORE);
+        logLikelihood += (Math.log(score) - logLikelihood) / averagingWindow;
+
+        int correct = v.maxValueIndex() == actual ? 1 : 0;
+        percentCorrect += (correct - percentCorrect) / averagingWindow;
+        if (numCategories() == 2) {
+          auc.addSample(actual, groupKey, v.get(1));
+        }
+      } else {
+        model.train(trackingKey, groupKey, actual, instance);
+      }
+      k++;
+    }
+  }
+
+  /**
+   * Remainder of x divided by y, shifted into the range [0, y) when x is negative.
+   */
+  private static long mod(long x, int y) {
+    long r = x % y;
+    return r < 0 ? r + y : r;
+  }
+
+  /**
+   * Closes every model.
+   */
+  @Override
+  public void close() {
+    for (OnlineLogisticRegression m : models) {
+      m.close();
+    }
+  }
+
+  /**
+   * Resets the record count to zero.
+   */
+  public void resetLineCounter() {
+    record = 0;
+  }
+
+  /**
+   * Returns true only if every model reports itself valid (true when there are no models).
+   */
+  public boolean validModel() {
+    boolean r = true;
+    for (OnlineLogisticRegression model : models) {
+      // non-short-circuit on purpose: validModel() is still called on every model
+      r &= model.validModel();
+    }
+    return r;
+  }
+
+  // -------- classification methods
+
+  /**
+   * Returns the average over all models of {@code classify(instance)}.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector r = new DenseVector(numCategories() - 1);
+    DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size());
+    for (OnlineLogisticRegression model : models) {
+      r.assign(model.classify(instance), scale);
+    }
+    return r;
+  }
+
+  /**
+   * Returns the average over all models of {@code classifyNoLink(instance)}.
+   */
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    Vector r = new DenseVector(numCategories() - 1);
+    DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size());
+    for (OnlineLogisticRegression model : models) {
+      r.assign(model.classifyNoLink(instance), scale);
+    }
+    return r;
+  }
+
+  /**
+   * Returns the average over all models of {@code classifyScalar(instance)}.
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    double r = 0;
+    int n = 0;
+    for (OnlineLogisticRegression model : models) {
+      n++;
+      r += model.classifyScalar(instance);
+    }
+    return r / n;
+  }
+
+  // -------- status reporting methods
+
+  /**
+   * Returns the number of categories reported by the first model; there must be at least one model.
+   */
+  @Override
+  public int numCategories() {
+    return models.get(0).numCategories();
+  }
+
+  public double auc() {
+    return auc.auc();
+  }
+
+  public double logLikelihood() {
+    return logLikelihood;
+  }
+
+  public double percentCorrect() {
+    return percentCorrect;
+  }
+
+  // -------- evolutionary optimization
+
+  /**
+   * Creates a new learner with the same number of features and prior whose models are copies of
+   * this learner's models. Each of this learner's models is closed before it is copied. Only the
+   * models are copied: the record count, evaluation statistics, AUC evaluator, parameters and
+   * window size of the result are those of a freshly constructed learner.
+   */
+  public CrossFoldLearner copy() {
+    CrossFoldLearner r = new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
+    r.models.clear();
+    for (OnlineLogisticRegression model : models) {
+      model.close();
+      OnlineLogisticRegression newModel =
+          new OnlineLogisticRegression(model.numCategories(), model.numFeatures(), model.prior);
+      newModel.copyFrom(model);
+      r.models.add(newModel);
+    }
+    return r;
+  }
+
+  public int getRecord() {
+    return record;
+  }
+
+  public void setRecord(int record) {
+    this.record = record;
+  }
+
+  public OnlineAuc getAucEvaluator() {
+    return auc;
+  }
+
+  public void setAucEvaluator(OnlineAuc auc) {
+    this.auc = auc;
+  }
+
+  public double getLogLikelihood() {
+    return logLikelihood;
+  }
+
+  public void setLogLikelihood(double logLikelihood) {
+    this.logLikelihood = logLikelihood;
+  }
+
+  /**
+   * Returns the live list of models, not a copy.
+   */
+  public List<OnlineLogisticRegression> getModels() {
+    return models;
+  }
+
+  public void addModel(OnlineLogisticRegression model) {
+    models.add(model);
+  }
+
+  public double[] getParameters() {
+    return parameters;
+  }
+
+  /**
+   * Replaces the parameter array. NOTE(review): {@link #write(DataOutput)} emits every element of
+   * this array while {@link #readFields(DataInput)} always reads four, so an array of any other
+   * length will not round-trip.
+   */
+  public void setParameters(double[] parameters) {
+    this.parameters = parameters;
+  }
+
+  public int getNumFeatures() {
+    return numFeatures;
+  }
+
+  public void setNumFeatures(int numFeatures) {
+    this.numFeatures = numFeatures;
+  }
+
+  /**
+   * Sets the averaging window for the running statistics and for the AUC evaluator.
+   */
+  public void setWindowSize(int windowSize) {
+    this.windowSize = windowSize;
+    auc.setWindowSize(windowSize);
+  }
+
+  public PriorFunction getPrior() {
+    return prior;
+  }
+
+  public void setPrior(PriorFunction prior) {
+    this.prior = prior;
+  }
+
+  /**
+   * Serializes the learner. The field order here must match {@link #readFields(DataInput)}.
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(record);
+    PolymorphicWritable.write(out, auc);
+    out.writeDouble(logLikelihood);
+    out.writeInt(models.size());
+    for (OnlineLogisticRegression model : models) {
+      model.write(out);
+    }
+
+    for (double x : parameters) {
+      out.writeDouble(x);
+    }
+    out.writeInt(numFeatures);
+    PolymorphicWritable.write(out, prior);
+    out.writeDouble(percentCorrect);
+    out.writeInt(windowSize);
+  }
+
+  /**
+   * Restores the learner from the format produced by {@link #write(DataOutput)}, replacing any
+   * models this instance held before the call.
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    record = in.readInt();
+    auc = PolymorphicWritable.read(in, OnlineAuc.class);
+    logLikelihood = in.readDouble();
+    int n = in.readInt();
+    // the instance may be reused for several records, so drop models from a previous read
+    models.clear();
+    for (int i = 0; i < n; i++) {
+      OnlineLogisticRegression olr = new OnlineLogisticRegression();
+      olr.readFields(in);
+      models.add(olr);
+    }
+    parameters = new double[NUM_PARAMETERS];
+    for (int i = 0; i < NUM_PARAMETERS; i++) {
+      parameters[i] = in.readDouble();
+    }
+    numFeatures = in.readInt();
+    prior = PolymorphicWritable.read(in, PriorFunction.class);
+    percentCorrect = in.readDouble();
+    windowSize = in.readInt();
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
new file mode 100644
index 0000000..dbf3198
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+
+import org.apache.commons.csv.CSVUtils;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+import org.apache.mahout.vectorizer.encoders.TextValueEncoder;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+/**
+ * Converts CSV data lines to vectors.
+ *
+ * Use of this class proceeds in a few steps.
+ * <ul>
+ * <li> At construction time, you tell the class about the target variable and provide
+ * a dictionary of the types of the predictor values. At this point,
+ * the class cannot yet decode inputs because it doesn't know the fields that are in the
+ * data records, nor their order.
+ * <li> Optionally, you tell the parser object about the possible values of the target
+ * variable. If you don't do this then you probably should set the number of distinct
+ * values so that the target variable values will be taken from a restricted range.
+ * <li> Later, when you get a list of the fields, typically from the first line of a CSV
+ * file, you tell the factory about these fields and it builds internal data structures
+ * that allow it to decode inputs. The most important internal state is the field numbers
+ * for various fields. After this point, you can use the factory for decoding data.
+ * <li> To encode data as a vector, you present a line of input to the factory and it
+ * mutates a vector that you provide. The factory also retains trace information so
+ * that it can approximately reverse engineer vectors later.
+ * <li> After converting data, you can ask for an explanation of the data in terms of
+ * terms and weights. In order to explain a vector accurately, the factory needs to
+ * have seen the particular values of categorical fields (typically during encoding vectors)
+ * and needs to have a reasonably small number of collisions in the vector encoding.
+ * </ul>
+ */
+public class CsvRecordFactory implements RecordFactory {
+  private static final String INTERCEPT_TERM = "Intercept Term";
+
+  // maps the type names accepted in the type map to the encoder class used for that type
+  private static final Map<String, Class<? extends FeatureVectorEncoder>> TYPE_DICTIONARY =
+      ImmutableMap.<String, Class<? extends FeatureVectorEncoder>>builder()
+          .put("continuous", ContinuousValueEncoder.class)
+          .put("numeric", ContinuousValueEncoder.class)
+          .put("n", ContinuousValueEncoder.class)
+          .put("word", StaticWordValueEncoder.class)
+          .put("w", StaticWordValueEncoder.class)
+          .put("text", TextValueEncoder.class)
+          .put("t", TextValueEncoder.class)
+          .build();
+
+  private final Map<String, Set<Integer>> traceDictionary = new TreeMap<>();
+
+  // column number of the target variable, set by firstLine
+  private int target;
+  private final Dictionary targetDictionary;
+
+  // name of the column that identifies a line of the CSV file, and its column number (-1 if unset)
+  private String idName;
+  private int id = -1;
+
+  // sorted predictor column numbers; -1 stands for the intercept term
+  private List<Integer> predictors;
+  private Map<Integer, FeatureVectorEncoder> predictorEncoders;
+  private int maxTargetValue = Integer.MAX_VALUE;
+  private final String targetName;
+  private final Map<String, String> typeMap;
+  private List<String> variableNames;
+  private boolean includeBiasTerm;
+  private static final String CANNOT_CONSTRUCT_CONVERTER =
+      "Unable to construct type converter... shouldn't be possible";
+
+  /**
+   * Parse a single line of CSV-formatted text.
+   *
+   * Separated to make changing this functionality for the entire class easier
+   * in the future.
+   * @param line - CSV formatted text
+   * @return The fields of the line. If the line cannot be parsed, a one-element list holding the
+   *         whole line is returned instead of propagating the error.
+   */
+  private List<String> parseCsvLine(String line) {
+    try {
+      return Arrays.asList(CSVUtils.parseLine(line));
+    }
+    catch (IOException e) {
+      // best effort: treat the unparsable line as a single field
+      List<String> list = new ArrayList<>();
+      list.add(line);
+      return list;
+    }
+  }
+
+  private List<String> parseCsvLine(CharSequence line) {
+    return parseCsvLine(line.toString());
+  }
+
+  /**
+   * Construct a parser for CSV lines that encodes the parsed data in vector form.
+   * @param targetName The name of the target variable.
+   * @param typeMap A map describing the types of the predictor variables.
+   */
+  public CsvRecordFactory(String targetName, Map<String, String> typeMap) {
+    this.targetName = targetName;
+    this.typeMap = typeMap;
+    targetDictionary = new Dictionary();
+  }
+
+  /**
+   * Construct a parser for CSV lines that also knows which column identifies each line.
+   * @param targetName The name of the target variable.
+   * @param idName The name of the id column.
+   * @param typeMap A map describing the types of the predictor variables.
+   */
+  public CsvRecordFactory(String targetName, String idName, Map<String, String> typeMap) {
+    this(targetName, typeMap);
+    this.idName = idName;
+  }
+
+  /**
+   * Defines the values and thus the encoding of values of the target variables. Note
+   * that any values of the target variable not present in this list will be given the
+   * value of the last member of the list.
+   * @param values The values the target variable can have.
+   */
+  @Override
+  public void defineTargetCategories(List<String> values) {
+    Preconditions.checkArgument(
+        values.size() <= maxTargetValue,
+        "Must have less than or equal to " + maxTargetValue + " categories for target variable, but found "
+            + values.size());
+    // an explicit category list fixes the maximum unless one was already set
+    if (maxTargetValue == Integer.MAX_VALUE) {
+      maxTargetValue = values.size();
+    }
+
+    for (String value : values) {
+      targetDictionary.intern(value);
+    }
+  }
+
+  /**
+   * Defines the number of target variable categories, but allows this parser to
+   * pick encodings for them as they appear.
+   * @param max The number of categories that will be expected. Once this many have been
+   * seen, all others will get the encoding max-1.
+   */
+  @Override
+  public CsvRecordFactory maxTargetValue(int max) {
+    maxTargetValue = max;
+    return this;
+  }
+
+  @Override
+  public boolean usesFirstLineAsSchema() {
+    return true;
+  }
+
+  /**
+   * Processes the first line of a file (which should contain the variable names). The target and
+   * predictor column numbers are set from the names on this line.
+   *
+   * @param line Header line for the file.
+   * @throws IllegalArgumentException if the target variable, the id variable (when one was given)
+   *         or any predictor named in the type map is missing from the header, or if a predictor
+   *         has a type that is not recognized.
+   */
+  @Override
+  public void firstLine(String line) {
+    // read variable names, build map of name -> column
+    final Map<String, Integer> vars = new HashMap<>();
+    variableNames = parseCsvLine(line);
+    int column = 0;
+    for (String name : variableNames) {
+      vars.put(name, column++);
+    }
+
+    // record target column; check first so that a missing column gives a useful message
+    // rather than a NullPointerException from unboxing
+    Integer targetColumn = vars.get(targetName);
+    Preconditions.checkArgument(targetColumn != null,
+        "Can't find target variable %s, only know about %s", targetName, vars);
+    target = targetColumn;
+
+    // record id column
+    if (idName != null) {
+      Integer idColumn = vars.get(idName);
+      Preconditions.checkArgument(idColumn != null,
+          "Can't find id variable %s, only know about %s", idName, vars);
+      id = idColumn;
+    }
+
+    // create list of predictor column numbers
+    predictors = new ArrayList<>(Collections2.transform(typeMap.keySet(), new Function<String, Integer>() {
+      @Override
+      public Integer apply(String from) {
+        Integer r = vars.get(from);
+        Preconditions.checkArgument(r != null, "Can't find variable %s, only know about %s", from, vars);
+        return r;
+      }
+    }));
+
+    if (includeBiasTerm) {
+      // -1 is the pseudo-column of the intercept term
+      predictors.add(-1);
+    }
+    Collections.sort(predictors);
+
+    // and map from column number to type encoder for each column that is a predictor
+    predictorEncoders = new HashMap<>();
+    for (Integer predictor : predictors) {
+      String name;
+      Class<? extends FeatureVectorEncoder> c;
+      if (predictor == -1) {
+        name = INTERCEPT_TERM;
+        c = ConstantValueEncoder.class;
+      } else {
+        name = variableNames.get(predictor);
+        c = TYPE_DICTIONARY.get(typeMap.get(name));
+      }
+      try {
+        Preconditions.checkArgument(c != null, "Invalid type of variable %s, wanted one of %s",
+            typeMap.get(name), TYPE_DICTIONARY.keySet());
+        // getConstructor throws NoSuchMethodException rather than returning null
+        Constructor<? extends FeatureVectorEncoder> constructor = c.getConstructor(String.class);
+        FeatureVectorEncoder encoder = constructor.newInstance(name);
+        predictorEncoders.put(predictor, encoder);
+        encoder.setTraceDictionary(traceDictionary);
+      } catch (InstantiationException | IllegalAccessException
+          | InvocationTargetException | NoSuchMethodException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      }
+    }
+  }
+
+
+  /**
+   * Decodes a single line of CSV data and records the target and predictor variables in a record.
+   * As a side effect, features are added into the featureVector. Returns the value of the target
+   * variable.
+   *
+   * @param line The raw data.
+   * @param featureVector Where to fill in the features. Should be zeroed before calling
+   * processLine.
+   * @return The value of the target variable.
+   */
+  @Override
+  public int processLine(String line, Vector featureVector) {
+    List<String> values = parseCsvLine(line);
+    int targetValue = encodeTarget(values);
+    encodePredictors(values, featureVector);
+    return targetValue;
+  }
+
+  /**
+   * Decodes a single line of CSV data and records the target (if returnTarget is true)
+   * and predictor variables in a record. As a side effect, features are added into the featureVector.
+   * Returns the value of the target variable. When used during classify against production data without
+   * target value, the method will be called with returnTarget = false.
+   * @param line The raw data.
+   * @param featureVector Where to fill in the features. Should be zeroed before calling
+   * processLine.
+   * @param returnTarget whether process and return target value, -1 will be returned if false.
+   * @return The value of the target variable.
+   */
+  public int processLine(CharSequence line, Vector featureVector, boolean returnTarget) {
+    List<String> values = parseCsvLine(line);
+    int targetValue = returnTarget ? encodeTarget(values) : -1;
+    encodePredictors(values, featureVector);
+    return targetValue;
+  }
+
+  /**
+   * Looks up the code of the target field, capping it at maxTargetValue - 1.
+   */
+  private int encodeTarget(List<String> values) {
+    int targetValue = targetDictionary.intern(values.get(target));
+    return targetValue >= maxTargetValue ? maxTargetValue - 1 : targetValue;
+  }
+
+  /**
+   * Adds every predictor field to the vector; the intercept term is encoded with a null value.
+   */
+  private void encodePredictors(List<String> values, Vector featureVector) {
+    for (Integer predictor : predictors) {
+      String value = predictor >= 0 ? values.get(predictor) : null;
+      predictorEncoders.get(predictor).addToVector(value, featureVector);
+    }
+  }
+
+  /**
+   * Extract the raw target string from a line read from a CSV file.
+   * @param line the line of content read from CSV file
+   * @return the raw target value in the corresponding column of CSV line
+   */
+  public String getTargetString(CharSequence line) {
+    List<String> values = parseCsvLine(line);
+    return values.get(target);
+  }
+
+  /**
+   * Extract the corresponding raw target label according to a code
+   * @param code the integer code encoded during training process
+   * @return the raw target label, or null if no known label has that code
+   */
+  public String getTargetLabel(int code) {
+    for (String key : targetDictionary.values()) {
+      if (targetDictionary.intern(key) == code) {
+        return key;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Extract the id column value from the CSV record
+   * @param line the line of content read from CSV file
+   * @return the id value of the CSV record
+   * @throws IndexOutOfBoundsException if no id column has been resolved (the id column number
+   *         is then still -1) or the line has too few fields
+   */
+  public String getIdString(CharSequence line) {
+    List<String> values = parseCsvLine(line);
+    return values.get(id);
+  }
+
+  /**
+   * Returns a list of the names of the predictor variables.
+   *
+   * @return A list of variable names.
+   */
+  @Override
+  public Iterable<String> getPredictors() {
+    return Lists.transform(predictors, new Function<Integer, String>() {
+      @Override
+      public String apply(Integer v) {
+        if (v >= 0) {
+          return variableNames.get(v);
+        } else {
+          return INTERCEPT_TERM;
+        }
+      }
+    });
+  }
+
+  @Override
+  public Map<String, Set<Integer>> getTraceDictionary() {
+    return traceDictionary;
+  }
+
+  /**
+   * Controls whether an intercept term is added to the predictors. Only takes effect if called
+   * before {@link #firstLine(String)}, which is where the predictor list is built.
+   */
+  @Override
+  public CsvRecordFactory includeBiasTerm(boolean useBias) {
+    includeBiasTerm = useBias;
+    return this;
+  }
+
+  /**
+   * Returns the known target values, truncated to at most maxTargetValue entries.
+   */
+  @Override
+  public List<String> getTargetCategories() {
+    List<String> r = targetDictionary.values();
+    if (r.size() > maxTargetValue) {
+      r.subList(maxTargetValue, r.size()).clear();
+    }
+    return r;
+  }
+
+  public String getIdName() {
+    return idName;
+  }
+
+  public void setIdName(String idName) {
+    this.idName = idName;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
new file mode 100644
index 0000000..f81d8ce
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Implements the basic logistic training law.
+ */
+public class DefaultGradient implements Gradient {
+  /**
+   * Computes the default logistic regression gradient: a vector that is 1 at the position of the
+   * actual category (nothing is set when the actual category is 0) minus the scores the
+   * classifier currently assigns to the instance.
+   *
+   * @param groupKey   A grouping key to allow per-something AUC loss to be used for training.
+   *                   Not used by this implementation.
+   * @param actual     The target variable value.
+   * @param instance   The current feature vector to use for gradient computation
+   * @param classifier The classifier that can compute scores
+   * @return The gradient to be applied to beta
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    // scores from the model as it currently stands
+    Vector predicted = classifier.classify(instance);
+
+    // start from an empty vector of the same kind and mark the observed category;
+    // category 0 has no slot of its own, hence the shift by one
+    Vector gradient = predicted.like();
+    if (actual != 0) {
+      gradient.setQuick(actual - 1, 1);
+    }
+
+    // observed minus predicted
+    gradient.assign(predicted, Functions.MINUS);
+    return gradient;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
new file mode 100644
index 0000000..8128370
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements a linear combination of L1 and L2 priors. This can give an
+ * interesting mixture of sparsity and load-sharing between redundant predictors.
+ */
+public class ElasticBandPrior implements PriorFunction {
+ private double alphaByLambda;
+ private L1 l1;
+ private L2 l2;
+
+ // Exists for Writable
+ public ElasticBandPrior() {
+ this(0.0);
+ }
+
+ public ElasticBandPrior(double alphaByLambda) {
+ this.alphaByLambda = alphaByLambda;
+ l1 = new L1();
+ l2 = new L2(1);
+ }
+
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ oldValue *= Math.pow(1 - alphaByLambda * learningRate, generations);
+ double newValue = oldValue - Math.signum(oldValue) * learningRate * generations;
+ if (newValue * oldValue < 0.0) {
+ // don't allow the value to change sign
+ return 0.0;
+ } else {
+ return newValue;
+ }
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return l1.logP(betaIJ) + alphaByLambda * l2.logP(betaIJ);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(alphaByLambda);
+ l1.write(out);
+ l2.write(out);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ alphaByLambda = in.readDouble();
+ l1 = new L1();
+ l1.readFields(in);
+ l2 = new L2();
+ l2.readFields(in);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
new file mode 100644
index 0000000..524fc06
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Provides the ability to inject a gradient into the SGD logistic regression.
+ * Typical uses of this are to use a ranking score such as AUC instead of a
+ * normal loss function.
+ */
+public interface Gradient {
+ /**
+ * Computes the gradient for a single training example.
+ *
+ * @param groupKey A grouping key for the example; implementations may ignore it.
+ * @param actual The target variable value.
+ * @param instance The feature vector of the example.
+ * @param classifier The classifier used to score the example.
+ * @return The gradient for this example.
+ */
+ Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
new file mode 100644
index 0000000..90ef7a8
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Random;
+
+/**
+ * Online gradient machine learner that tries to minimize the label ranking hinge loss.
+ * Implements a gradient machine with one sigmoid hidden layer.
+ * It tries to minimize the ranking loss of some given set of labels,
+ * so this can be used for multi-class, multi-label
+ * or auto-encoding of sparse data (e.g. text).
+ */
+public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  /** Format version written by {@link #write} and required by {@link #readFields}. */
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm
+  private double learningRate = 0.1;
+
+  // the regularization term, a positive number that controls the size of the weight vector
+  private double regularization = 0.1;
+
+  // the sparsity term, a positive number that controls the sparsity of the hidden layer. (0 - 1)
+  // NOTE(review): stored, copied and serialized, but not referenced by any update rule in this class.
+  private double sparsity = 0.1;
+
+  // the sparsity learning rate.
+  // NOTE(review): stored, copied and serialized, but not referenced by any update rule in this class.
+  private double sparsityLearningRate = 0.1;
+
+  // the number of features
+  private int numFeatures = 10;
+  // the number of hidden nodes
+  private int numHidden = 100;
+  // the number of output nodes
+  private int numOutput = 2;
+
+  // coefficients for the input to hidden layer.
+  // There are numHidden Vectors of dimension numFeatures.
+  private Vector[] hiddenWeights;
+
+  // coefficients for the hidden to output layer.
+  // There are numOutput Vectors of dimension numHidden.
+  private Vector[] outputWeights;
+
+  // hidden unit bias
+  private Vector hiddenBias;
+
+  // output unit bias
+  private Vector outputBias;
+
+  // random source used by train() when sampling candidate "bad" labels
+  private final Random rnd;
+
+  /**
+   * Creates a machine whose weights and biases are all zero. Call {@link #initWeights(Random)}
+   * to randomize the weights.
+   *
+   * @param numFeatures dimension of the input vectors.
+   * @param numHidden   number of hidden units.
+   * @param numOutput   number of output units (categories).
+   */
+  public GradientMachine(int numFeatures, int numHidden, int numOutput) {
+    this.numFeatures = numFeatures;
+    this.numHidden = numHidden;
+    this.numOutput = numOutput;
+    // Arrays are allocated with the declared element type (Vector) so that any Vector
+    // implementation can be stored without risking an ArrayStoreException.
+    hiddenWeights = new Vector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = new DenseVector(numFeatures);
+      hiddenWeights[i].assign(0);
+    }
+    hiddenBias = new DenseVector(numHidden);
+    hiddenBias.assign(0);
+    outputWeights = new Vector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = new DenseVector(numHidden);
+      outputWeights[i].assign(0);
+    }
+    outputBias = new DenseVector(numOutput);
+    outputBias.assign(0);
+    rnd = RandomUtils.getRandom();
+  }
+
+  /**
+   * Initialize weights. Every weight is drawn from {@code gen} and scaled into the range
+   * {@code [-1/sqrt(fanIn), 1/sqrt(fanIn))}, where fanIn is the number of inputs of the layer.
+   * Biases are left untouched.
+   *
+   * @param gen random number generator.
+   */
+  public void initWeights(Random gen) {
+    double hiddenFanIn = 1.0 / Math.sqrt(numFeatures);
+    for (int i = 0; i < numHidden; i++) {
+      for (int j = 0; j < numFeatures; j++) {
+        double val = (2.0 * gen.nextDouble() - 1.0) * hiddenFanIn;
+        hiddenWeights[i].setQuick(j, val);
+      }
+    }
+    double outputFanIn = 1.0 / Math.sqrt(numHidden);
+    for (int i = 0; i < numOutput; i++) {
+      for (int j = 0; j < numHidden; j++) {
+        double val = (2.0 * gen.nextDouble() - 1.0) * outputFanIn;
+        outputWeights[i].setQuick(j, val);
+      }
+    }
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param regularization A positive value that controls the weight vector size.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine regularization(double regularization) {
+    this.regularization = regularization;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param sparsity A value between zero and one that controls the fraction of hidden units
+   *                 that are activated on average.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine sparsity(double sparsity) {
+    this.sparsity = sparsity;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param sparsityLearningRate New value of initial learning rate for sparsity.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
+    this.sparsityLearningRate = sparsityLearningRate;
+    return this;
+  }
+
+  /**
+   * Overwrites the configuration, weights and biases of this machine with copies of those of
+   * {@code other}. The random number generator of this machine is kept.
+   *
+   * @param other the machine to copy from.
+   */
+  public void copyFrom(GradientMachine other) {
+    numFeatures = other.numFeatures;
+    numHidden = other.numHidden;
+    numOutput = other.numOutput;
+    learningRate = other.learningRate;
+    regularization = other.regularization;
+    sparsity = other.sparsity;
+    sparsityLearningRate = other.sparsityLearningRate;
+    // clone() is declared to return Vector, so the arrays must be able to hold any Vector.
+    hiddenWeights = new Vector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = other.hiddenWeights[i].clone();
+    }
+    hiddenBias = other.hiddenBias.clone();
+    outputWeights = new Vector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = other.outputWeights[i].clone();
+    }
+    outputBias = other.outputBias.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numOutput;
+  }
+
+  public int numFeatures() {
+    return numFeatures;
+  }
+
+  public int numHidden() {
+    return numHidden;
+  }
+
+  /**
+   * Feeds forward from input to hidden unit. Each pre-activation is the dot product of the
+   * unit's weight vector with the input plus the unit's bias; it is passed through
+   * min(40) / max(-40) before the sigmoid is applied.
+   *
+   * @param input the input feature vector.
+   * @return Hidden unit activations.
+   */
+  public DenseVector inputToHidden(Vector input) {
+    DenseVector activations = new DenseVector(numHidden);
+    for (int i = 0; i < numHidden; i++) {
+      activations.setQuick(i, hiddenWeights[i].dot(input));
+    }
+    activations.assign(hiddenBias, Functions.PLUS);
+    activations.assign(Functions.min(40.0)).assign(Functions.max(-40));
+    activations.assign(Functions.SIGMOID);
+    return activations;
+  }
+
+  /**
+   * Feeds forward from hidden to output. The output is linear: weights dot activation plus bias.
+   *
+   * @param hiddenActivation the hidden unit activations.
+   * @return Output unit activations.
+   */
+  public DenseVector hiddenToOutput(Vector hiddenActivation) {
+    DenseVector activations = new DenseVector(numOutput);
+    for (int i = 0; i < numOutput; i++) {
+      activations.setQuick(i, outputWeights[i].dot(hiddenActivation));
+    }
+    activations.assign(outputBias, Functions.PLUS);
+    return activations;
+  }
+
+  /**
+   * Updates using ranking loss. For every good label, up to {@code numTrials} labels outside
+   * {@code goodLabels} are sampled and the highest scoring one is used as the violating label;
+   * weights are only changed when the hinge loss {@code 1 - goodScore + badScore} is not negative.
+   *
+   * @param hiddenActivation the hidden unit's activation
+   * @param goodLabels       the labels you want ranked above others.
+   * @param numTrials        how many times you want to search for the highest scoring bad label.
+   * @param gen              Random number generator.
+   */
+  public void updateRanking(Vector hiddenActivation,
+                            Collection<Integer> goodLabels,
+                            int numTrials,
+                            Random gen) {
+    // All the labels are good, do nothing.
+    // This check also guarantees that the rejection-sampling loop below can terminate.
+    if (goodLabels.size() >= numOutput) {
+      return;
+    }
+    for (Integer good : goodLabels) {
+      // Scores used for ranking are computed without the output bias.
+      double goodScore = outputWeights[good].dot(hiddenActivation);
+      int highestBad = -1;
+      double highestBadScore = Double.NEGATIVE_INFINITY;
+      for (int i = 0; i < numTrials; i++) {
+        int bad = gen.nextInt(numOutput);
+        while (goodLabels.contains(bad)) {
+          bad = gen.nextInt(numOutput);
+        }
+        double badScore = outputWeights[bad].dot(hiddenActivation);
+        if (badScore > highestBadScore) {
+          highestBadScore = badScore;
+          highestBad = bad;
+        }
+      }
+      int bad = highestBad;
+      // With numTrials <= 0 the bad score stays at -infinity, the loss is negative and the
+      // label is skipped, so the -1 sentinel is never used as an index.
+      double loss = 1.0 - goodScore + highestBadScore;
+      if (loss < 0.0) {
+        continue;
+      }
+      // Note from the loss above the gradient dloss/dy , y being the label is -1 for good
+      // and +1 for bad.
+      // dy / dw is just w since y = x' * w + b.
+      // Hence by the chain rule, dloss / dw = dloss / dy * dy / dw = -w.
+      // For the regularization part, 0.5 * lambda * w' w, the gradient is lambda * w.
+      // dy / db = 1.
+      // NOTE(review): for y = x' * w + b the derivative dy/dw is x (the hidden activation),
+      // not w. The update below nevertheless rescales the existing weight vectors; confirm
+      // that this is intended before changing it, since it alters what the model learns.
+      Vector gradGood = outputWeights[good].clone();
+      gradGood.assign(Functions.NEGATE);
+      // propHidden = w_bad - w_good: the signal propagated back to the hidden layer.
+      Vector propHidden = gradGood.clone();
+      Vector gradBad = outputWeights[bad].clone();
+      propHidden.assign(gradBad, Functions.PLUS);
+      gradGood.assign(Functions.mult(-learningRate * (1.0 - regularization)));
+      outputWeights[good].assign(gradGood, Functions.PLUS);
+      gradBad.assign(Functions.mult(-learningRate * (1.0 + regularization)));
+      outputWeights[bad].assign(gradBad, Functions.PLUS);
+      outputBias.setQuick(good, outputBias.get(good) + learningRate);
+      outputBias.setQuick(bad, outputBias.get(bad) - learningRate);
+      // Gradient of sigmoid is s * (1 -s).
+      Vector gradSig = hiddenActivation.clone();
+      gradSig.assign(Functions.SIGMOIDGRADIENT);
+      // Multiply by the change caused by the ranking loss.
+      for (int i = 0; i < numHidden; i++) {
+        gradSig.setQuick(i, gradSig.get(i) * propHidden.get(i));
+      }
+      // NOTE(review): the input vector is not available here, so every weight of hidden unit i
+      // receives the same gradSig(i) term, and hiddenBias is never updated -- confirm intent.
+      for (int i = 0; i < numHidden; i++) {
+        for (int j = 0; j < numFeatures; j++) {
+          double v = hiddenWeights[i].get(j);
+          v -= learningRate * (gradSig.get(i) + regularization * v);
+          hiddenWeights[i].setQuick(j, v);
+        }
+      }
+    }
+  }
+
+  /**
+   * Scores the instance, sets the highest scoring output to 1 and all others to 0, and returns
+   * a view of that vector without its first component.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Find the max value's index.
+    int max = result.maxValueIndex();
+    result.assign(0);
+    result.setQuick(max, 1.0);
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  /** Returns the raw output activations for the instance. */
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    DenseVector hidden = inputToHidden(instance);
+    return hiddenToOutput(hidden);
+  }
+
+  /**
+   * Returns 0 if output 0 scores strictly higher than output 1, otherwise 1. Only the first
+   * two outputs are examined.
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    Vector output = classifyNoLink(instance);
+    if (output.get(0) > output.get(1)) {
+      return 0;
+    }
+    return 1;
+  }
+
+  /**
+   * Closes this learner and returns a new machine holding copies of its configuration,
+   * weights and biases.
+   */
+  public GradientMachine copy() {
+    close();
+    GradientMachine r = new GradientMachine(numFeatures(), numHidden(), numCategories());
+    r.copyFrom(this);
+    return r;
+  }
+
+  /**
+   * Serializes the version, hyper-parameters, layer sizes, then hidden bias and weights followed
+   * by output bias and weights. {@link #readFields} expects exactly this order.
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(learningRate);
+    out.writeDouble(regularization);
+    out.writeDouble(sparsity);
+    out.writeDouble(sparsityLearningRate);
+    out.writeInt(numFeatures);
+    out.writeInt(numHidden);
+    out.writeInt(numOutput);
+    VectorWritable.writeVector(out, hiddenBias);
+    for (int i = 0; i < numHidden; i++) {
+      VectorWritable.writeVector(out, hiddenWeights[i]);
+    }
+    VectorWritable.writeVector(out, outputBias);
+    for (int i = 0; i < numOutput; i++) {
+      VectorWritable.writeVector(out, outputWeights[i]);
+    }
+  }
+
+  /**
+   * Restores the state written by {@link #write}.
+   *
+   * @throws IOException if the stored version differs from {@link #WRITABLE_VERSION} or the
+   *                     underlying stream fails.
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version != WRITABLE_VERSION) {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+    learningRate = in.readDouble();
+    regularization = in.readDouble();
+    sparsity = in.readDouble();
+    sparsityLearningRate = in.readDouble();
+    numFeatures = in.readInt();
+    numHidden = in.readInt();
+    numOutput = in.readInt();
+    // readVector is declared to return Vector, so the arrays must be able to hold any Vector.
+    hiddenWeights = new Vector[numHidden];
+    hiddenBias = VectorWritable.readVector(in);
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = VectorWritable.readVector(in);
+    }
+    outputWeights = new Vector[numOutput];
+    outputBias = VectorWritable.readVector(in);
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = VectorWritable.readVector(in);
+    }
+  }
+
+  @Override
+  public void close() {
+    // This is an online classifier, nothing to do.
+  }
+
+  /**
+   * Performs one ranking update that treats {@code actual} as the only good label and samples
+   * two candidate bad labels. {@code trackingKey} and {@code groupKey} are not used.
+   */
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    // Only the hidden activation is needed: updateRanking computes the scores it uses itself.
+    Vector hiddenActivation = inputToHidden(instance);
+    Collection<Integer> goodLabels = new HashSet<>();
+    goodLabels.add(actual);
+    updateRanking(hiddenActivation, goodLabels, 2, rnd);
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
new file mode 100644
index 0000000..28a05f2
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements the Laplacian or bi-exponential prior. This prior has a strong tendency to set coefficients to zero
+ * and thus is useful as an alternative to variable selection. This version implements truncation which prevents
+ * a coefficient from changing sign. If a correction would change the sign, the coefficient is truncated to zero.
+ *
+ * Note that it doesn't matter to have a scale for this distribution because after taking the derivative of the logP,
+ * the lambda coefficient used to combine the prior with the observations has the same effect. If we had a scale here,
+ * then it would be the same effect as just changing lambda.
+ */
+public class L1 implements PriorFunction {
+
+  /**
+   * Moves {@code oldValue} toward zero by {@code learningRate * generations}, truncating at zero
+   * instead of letting the value change sign.
+   *
+   * @param oldValue     the coefficient before aging.
+   * @param generations  multiplier applied to the per-step shrinkage.
+   * @param learningRate the per-step shrinkage.
+   * @return the aged coefficient, or 0 if the step would have crossed zero.
+   */
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    double shrinkage = Math.signum(oldValue) * learningRate * generations;
+    double aged = oldValue - shrinkage;
+    // a negative product means old and new value have opposite signs
+    boolean crossedZero = aged * oldValue < 0;
+    return crossedZero ? 0 : aged;
+  }
+
+  /**
+   * Returns the negated absolute value of the coefficient.
+   */
+  @Override
+  public double logP(double betaIJ) {
+    return -Math.abs(betaIJ);
+  }
+
+  /** Writes nothing. */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    // stateless class has nothing to serialize
+  }
+
+  /** Reads nothing. */
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    // stateless class has nothing to serialize
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
new file mode 100644
index 0000000..3dfb9fc
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements the Gaussian prior. This prior has a tendency to decrease large coefficients toward zero, but
+ * doesn't tend to set them to exactly zero.
+ */
+public class L2 implements PriorFunction {
+
+ private static final double HALF_LOG_2PI = Math.log(2.0 * Math.PI) / 2.0;
+
+ private double s2;
+ private double s;
+
+ public L2(double scale) {
+ s = scale;
+ s2 = scale * scale;
+ }
+
+ public L2() {
+ s = 1.0;
+ s2 = 1.0;
+ }
+
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ return oldValue * Math.pow(1.0 - learningRate / s2, generations);
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return -betaIJ * betaIJ / s2 / 2.0 - Math.log(s) - HALF_LOG_2PI;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(s2);
+ out.writeDouble(s);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ s2 = in.readDouble();
+ s = in.readDouble();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
new file mode 100644
index 0000000..a290b22
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+
+import java.util.Random;
+
+/**
+ * <p>Provides a stochastic mixture of ranking updates and normal logistic updates. This uses a
+ * combination of AUC driven learning to improve ranking performance and traditional log-loss driven
+ * learning to improve log-likelihood.</p>
+ *
+ * <p>See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf</p>
+ *
+ * <p>This implementation only makes sense for the binomial case.</p>
+ */
+public class MixedGradient implements Gradient {
+
+  // probability of choosing the ranking update for a given example
+  private final double alpha;
+  private final RankingGradient rank;
+  // the plain gradient obtained from the ranking gradient
+  private final Gradient basic;
+  private final Random random = RandomUtils.getRandom();
+  // whether a normal update has already recorded an example of category 0 / category 1
+  private boolean hasZero;
+  private boolean hasOne;
+
+  /**
+   * @param alpha  probability with which {@link #apply} performs a ranking update instead of a
+   *               normal update.
+   * @param window passed to the underlying {@link RankingGradient}.
+   */
+  public MixedGradient(double alpha, int window) {
+    this.alpha = alpha;
+    this.rank = new RankingGradient(window);
+    this.basic = this.rank.getBaseGradient();
+  }
+
+  /**
+   * Randomly applies either a ranking update or a normal update. Normal updates also add the
+   * example to the history kept by the ranking gradient.
+   *
+   * @throws IllegalStateException if a ranking update is selected before normal updates have
+   *                               recorded at least one example of category 0 and one of
+   *                               category 1.
+   */
+  @Override
+  public Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    if (random.nextDouble() < alpha) {
+      // one option is to apply a ranking update relative to our recent history
+      if (!hasZero || !hasOne) {
+        throw new IllegalStateException(
+            "Ranking update selected before both categories were seen by normal updates"
+                + " (seen category 0: " + hasZero + ", seen category 1: " + hasOne + ")");
+      }
+      return rank.apply(groupKey, actual, instance, classifier);
+    } else {
+      hasZero |= actual == 0;
+      hasOne |= actual == 1;
+      // the other option is a normal update, but we have to update our history on the way
+      rank.addToHistory(actual, instance);
+      return basic.apply(groupKey, actual, instance, classifier);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
new file mode 100644
index 0000000..bcd2ebc
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Ordering;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Queue;
+import java.util.Set;
+
+/**
+ * Uses sample data to reverse engineer a feature-hashed model.
+ *
+ * The result gives approximate weights for features and interactions
+ * in the original space.
+ *
+ * The idea is that the hashed encoders have the option of having a trace dictionary. This
+ * tells us where each feature is hashed to, or each feature/value combination in the case
+ * of word-like values. Using this dictionary, we can put values into a synthetic feature
+ * vector in just the locations specified by a single feature or interaction. Then we can
+ * push this through a linear part of a model to see the contribution of that input. For
+ * any generalized linear model like logistic regression, there is a linear part of the
+ * model that allows this.
+ *
+ * What the ModelDissector does is to accept a trace dictionary and a model in an update
+ * method. It figures out the weights for the elements in the trace dictionary and stashes
+ * them. Then in a summary method, the biggest weights are returned. This update/flush
+ * style is used so that the trace dictionary doesn't have to grow to enormous levels,
+ * but instead can be cleared between updates.
+ */
+public class ModelDissector {
+  // feature name -> model output obtained when only that feature's locations are set to 1
+  private final Map<String,Vector> weightMap;
+
+  public ModelDissector() {
+    weightMap = new HashMap<>();
+  }
+
+  /**
+   * Probes a model to determine the effect of a particular variable. This is done
+   * with the aid of a trace dictionary which has recorded the locations in the feature
+   * vector that are modified by various variable values. We can set these locations to
+   * 1 and then look at the resulting score. This tells us the weight the model places
+   * on that variable. Variables that were already probed by an earlier call are skipped.
+   * @param features A feature vector to use (destructively)
+   * @param traceDictionary A trace dictionary containing variables and what locations
+   * in the feature vector are affected by them
+   * @param learner The model that we are probing to find weights on features
+   */
+  public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
+    // zero out feature vector
+    features.assign(0);
+    for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) {
+      // get a feature and locations where it is stored in the feature vector
+      String key = entry.getKey();
+      Set<Integer> value = entry.getValue();
+
+      // if we haven't looked at this feature yet
+      if (!weightMap.containsKey(key)) {
+        // put probe values in the feature vector
+        for (Integer where : value) {
+          features.set(where, 1);
+        }
+
+        // see what the model says
+        Vector v = learner.classifyNoLink(features);
+        weightMap.put(key, v);
+
+        // and zero out those locations again
+        for (Integer where : value) {
+          features.set(where, 0);
+        }
+      }
+    }
+  }
+
+  /**
+   * Returns the n most important features with their
+   * weights, most important category and the top few
+   * categories that they affect.
+   * @param n How many results to return.
+   * @return A list of the top variables, most important first.
+   */
+  public List<Weight> summary(int n) {
+    // min-heap on importance: polling drops the least important entry seen so far
+    Queue<Weight> pq = new PriorityQueue<>();
+    for (Map.Entry<String, Vector> entry : weightMap.entrySet()) {
+      pq.add(new Weight(entry.getKey(), entry.getValue()));
+      while (pq.size() > n) {
+        pq.poll();
+      }
+    }
+    List<Weight> r = new ArrayList<>(pq);
+    Collections.sort(r, Ordering.natural().reverse());
+    return r;
+  }
+
+  /** A (category index, weight) pair ordered by absolute weight. */
+  private static final class Category implements Comparable<Category> {
+    private final int index;
+    private final double weight;
+
+    private Category(int index, double weight) {
+      this.index = index;
+      this.weight = weight;
+    }
+
+    @Override
+    public int compareTo(Category o) {
+      int r = Double.compare(Math.abs(weight), Math.abs(o.weight));
+      if (r != 0) {
+        return r;
+      }
+      // ties on magnitude: the category with the larger index orders first
+      return Integer.compare(o.index, index);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (!(o instanceof Category)) {
+        return false;
+      }
+      Category other = (Category) o;
+      // Double.compare rather than == so that equals stays reflexive when the weight is NaN
+      return index == other.index && Double.compare(weight, other.weight) == 0;
+    }
+
+    @Override
+    public int hashCode() {
+      return RandomUtils.hashDouble(weight) ^ index;
+    }
+
+  }
+
+  /** The weights a model assigns to one feature, reduced to its top categories. */
+  public static class Weight implements Comparable<Weight> {
+    private final String feature;
+    // weight (with sign) of the category with the largest absolute weight
+    private final double value;
+    // index of that category
+    private final int maxIndex;
+    // the retained categories, largest absolute weight first
+    private final List<Category> categories;
+
+    /** Keeps the three categories with the largest absolute weight. */
+    public Weight(String feature, Vector weights) {
+      this(feature, weights, 3);
+    }
+
+    /**
+     * @param feature the feature name.
+     * @param weights the per-category weights of the feature.
+     * @param n       how many of the categories with the largest absolute weight to keep.
+     * @throws IllegalArgumentException if {@code n} is less than 1 or {@code weights} has no
+     *                                  elements, since no top category would exist.
+     */
+    public Weight(String feature, Vector weights, int n) {
+      if (n < 1) {
+        throw new IllegalArgumentException("n must be at least 1, got " + n);
+      }
+      this.feature = feature;
+      // pick out the weight with the largest abs value, but don't forget the sign
+      Queue<Category> biggest = new PriorityQueue<>(n + 1, Ordering.natural());
+      for (Vector.Element element : weights.all()) {
+        biggest.add(new Category(element.index(), element.get()));
+        while (biggest.size() > n) {
+          biggest.poll();
+        }
+      }
+      if (biggest.isEmpty()) {
+        throw new IllegalArgumentException("No weights available for feature " + feature);
+      }
+      categories = new ArrayList<>(biggest);
+      Collections.sort(categories, Ordering.natural().reverse());
+      value = categories.get(0).weight;
+      maxIndex = categories.get(0).index;
+    }
+
+    @Override
+    public int compareTo(Weight other) {
+      int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
+      if (r == 0) {
+        return feature.compareTo(other.feature);
+      }
+      return r;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (!(o instanceof Weight)) {
+        return false;
+      }
+      Weight other = (Weight) o;
+      // Double.compare rather than == so that equals stays reflexive when the value is NaN
+      return feature.equals(other.feature)
+          && Double.compare(value, other.value) == 0
+          && maxIndex == other.maxIndex
+          && categories.equals(other.categories);
+    }
+
+    @Override
+    public int hashCode() {
+      return feature.hashCode() ^ RandomUtils.hashDouble(value) ^ maxIndex ^ categories.hashCode();
+    }
+
+    public String getFeature() {
+      return feature;
+    }
+
+    /** Returns the weight of the category with the largest absolute weight. */
+    public double getWeight() {
+      return value;
+    }
+
+    /** Returns the weight of the n-th retained category (0 is the largest in magnitude). */
+    public double getWeight(int n) {
+      return categories.get(n).weight;
+    }
+
+    /** Returns the index of the n-th retained category, widened to a double. */
+    public double getCategory(int n) {
+      return categories.get(n).index;
+    }
+
+    /** Returns the index of the category with the largest absolute weight. */
+    public int getMaxImpact() {
+      return maxIndex;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
new file mode 100644
index 0000000..f89b245
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Provides the ability to store SGD model-related objects as binary files.
+ */
+public final class ModelSerializer {
+
+  // static class ... don't instantiate
+  private ModelSerializer() {
+  }
+
+  /**
+   * Writes the model to the file at {@code path}.
+   *
+   * @throws IOException if the file cannot be opened or written.
+   */
+  public static void writeBinary(String path, CrossFoldLearner model) throws IOException {
+    writeModel(path, model);
+  }
+
+  /**
+   * Writes the model to the file at {@code path}.
+   *
+   * @throws IOException if the file cannot be opened or written.
+   */
+  public static void writeBinary(String path, OnlineLogisticRegression model) throws IOException {
+    writeModel(path, model);
+  }
+
+  /**
+   * Writes the model to the file at {@code path}.
+   *
+   * @throws IOException if the file cannot be opened or written.
+   */
+  public static void writeBinary(String path, AdaptiveLogisticRegression model) throws IOException {
+    writeModel(path, model);
+  }
+
+  /**
+   * Reads an object from the stream and closes the stream afterwards, whether or not the read
+   * succeeded. A failure while closing does not replace a failure raised by the read itself.
+   *
+   * @param in    the stream to read from; it is closed by this method.
+   * @param clazz the class token handed to {@link PolymorphicWritable#read}.
+   * @return the object that was read.
+   * @throws IOException if reading or closing the stream fails.
+   */
+  public static <T extends Writable> T readBinary(InputStream in, Class<T> clazz) throws IOException {
+    try (InputStream stream = in) {
+      DataInput dataIn = new DataInputStream(stream);
+      return PolymorphicWritable.read(dataIn, clazz);
+    }
+  }
+
+  // Shared implementation of the writeBinary overloads: opens the file, writes, always closes.
+  private static void writeModel(String path, Writable model) throws IOException {
+    try (DataOutputStream out = new DataOutputStream(new FileOutputStream(path))) {
+      PolymorphicWritable.write(out, model);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
new file mode 100644
index 0000000..7a9ca83
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * On-line logistic regression learner that adds a concrete family of learning rate
+ * annealing schedules on top of {@link AbstractOnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression implements Writable {
+  public static final int WRITABLE_VERSION = 1;
+
+  // exponential annealing: initial learning rate and the factor applied as decayFactor^steps
+  private double mu0 = 1;
+  private double decayFactor = 1 - 1.0e-3;
+
+  // polynomial annealing of the form (steps + stepOffset)^forgettingExponent
+  private int stepOffset = 10;
+  // -1 equals even weighting of all examples, 0 means only use exponential annealing
+  private double forgettingExponent = -0.5;
+
+  // numerator of the per-term learning rate and starting value of every update count
+  private int perTermAnnealingOffset = 20;
+
+  /**
+   * No-argument constructor intended for deserialization rather than normal use; the
+   * model state is expected to be filled in afterwards by {@link #readFields(DataInput)}.
+   */
+  public OnlineLogisticRegression() {
+  }
+
+  /**
+   * Creates an untrained model.
+   *
+   * @param numCategories Number of target categories.
+   * @param numFeatures   Number of features in each input vector.
+   * @param prior         Prior used for regularization.
+   */
+  public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+    this.numCategories = numCategories;
+    this.prior = prior;
+
+    this.updateSteps = new DenseVector(numFeatures);
+    this.updateCounts = new DenseVector(numFeatures).assign(perTermAnnealingOffset);
+    // one row fewer than the number of categories
+    this.beta = new DenseMatrix(numCategories - 1, numFeatures);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param alpha New value of decayFactor, the exponential decay rate for the learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression alpha(double alpha) {
+    decayFactor = alpha;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option; overridden only to narrow the return type.
+   *
+   * @param lambda Value handed unchanged to the superclass implementation.
+   * @return This, so other configurations can be chained.
+   */
+  @Override
+  public OnlineLogisticRegression lambda(double lambda) {
+    super.lambda(lambda);
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression learningRate(double learningRate) {
+    mu0 = learningRate;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param stepOffset New offset added to the step count in the polynomial annealing term.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression stepOffset(int stepOffset) {
+    this.stepOffset = stepOffset;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param decayExponent New exponent of the polynomial annealing term. A positive value
+   *                      is negated before it is stored.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression decayExponent(double decayExponent) {
+    forgettingExponent = decayExponent > 0 ? -decayExponent : decayExponent;
+    return this;
+  }
+
+  /**
+   * Learning rate multiplier for a single term.
+   *
+   * @param j Index of the term.
+   * @return sqrt(perTermAnnealingOffset / updateCounts[j]).
+   */
+  @Override
+  public double perTermLearningRate(int j) {
+    double timesUpdated = updateCounts.get(j);
+    return Math.sqrt(perTermAnnealingOffset / timesUpdated);
+  }
+
+  /**
+   * Learning rate at the current step: the initial rate scaled by the exponential and
+   * the polynomial annealing terms.
+   */
+  @Override
+  public double currentLearningRate() {
+    double exponentialTerm = Math.pow(decayFactor, getStep());
+    double polynomialTerm = Math.pow(getStep() + stepOffset, forgettingExponent);
+    return mu0 * exponentialTerm * polynomialTerm;
+  }
+
+  /**
+   * Copies the state handled by the superclass and then the annealing parameters
+   * of {@code other} into this model.
+   *
+   * @param other Model to copy from.
+   */
+  public void copyFrom(OnlineLogisticRegression other) {
+    super.copyFrom(other);
+
+    this.mu0 = other.mu0;
+    this.decayFactor = other.decayFactor;
+    this.stepOffset = other.stepOffset;
+    this.forgettingExponent = other.forgettingExponent;
+    this.perTermAnnealingOffset = other.perTermAnnealingOffset;
+  }
+
+  /**
+   * Closes this model, then builds a new model of the same shape and prior and copies
+   * this model's state into it.
+   *
+   * @return The newly created copy.
+   */
+  public OnlineLogisticRegression copy() {
+    close();
+    OnlineLogisticRegression duplicate = new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
+    duplicate.copyFrom(this);
+    return duplicate;
+  }
+
+  /**
+   * Serializes the model. The field order here must stay in step with
+   * {@link #readFields(DataInput)}.
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+
+    // scalar configuration and progress
+    out.writeDouble(mu0);
+    out.writeDouble(getLambda());
+    out.writeDouble(decayFactor);
+    out.writeInt(stepOffset);
+    out.writeInt(step);
+    out.writeDouble(forgettingExponent);
+    out.writeInt(perTermAnnealingOffset);
+    out.writeInt(numCategories);
+
+    // learned state
+    MatrixWritable.writeMatrix(out, beta);
+    PolymorphicWritable.write(out, prior);
+    VectorWritable.writeVector(out, updateCounts);
+    VectorWritable.writeVector(out, updateSteps);
+  }
+
+  /**
+   * Restores the model from the layout produced by {@link #write(DataOutput)}.
+   *
+   * @throws IOException If the stored version differs from {@link #WRITABLE_VERSION}
+   *                     or the stream cannot be read.
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version != WRITABLE_VERSION) {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+
+    mu0 = in.readDouble();
+    lambda(in.readDouble());
+    decayFactor = in.readDouble();
+    stepOffset = in.readInt();
+    step = in.readInt();
+    forgettingExponent = in.readDouble();
+    perTermAnnealingOffset = in.readInt();
+    numCategories = in.readInt();
+
+    beta = MatrixWritable.readMatrix(in);
+    prior = PolymorphicWritable.read(in, PriorFunction.class);
+    updateCounts = VectorWritable.readVector(in);
+    updateSteps = VectorWritable.readVector(in);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
new file mode 100644
index 0000000..c51361c
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Online passive aggressive learner that tries to minimize the label ranking hinge loss.
+ * Implements a multi-class linear classifier minimizing rank loss,
+ * based on "Online passive aggressive algorithms" by Crammer et al., 2006.
+ * Note: It is better to use classifyNoLink because the loss function is based
+ * on ensuring that the score of the good label is larger than the next
+ * highest label by some margin. The conversion to probability is just done
+ * by exponentiating and dividing by the sum and is empirical at best.
+ * Your features should be pre-normalized in some sensible range, for example,
+ * by subtracting the mean and dividing by the standard deviation, if they are very
+ * different in magnitude from each other.
+ */
+public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);
+
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm
+  private double learningRate = 0.1;
+
+  // loss statistics, reset each time the running average is logged
+  private int lossCount = 0;
+  private double lossSum = 0;
+
+  // coefficients for the classification. This is a dense matrix
+  // that is numCategories x numFeatures
+  private Matrix weights;
+
+  // number of categories we are classifying.
+  private int numCategories;
+
+  /**
+   * Creates a learner whose weights are all zero.
+   *
+   * @param numCategories Number of target categories (rows of the weight matrix).
+   * @param numFeatures   Number of features per instance (columns of the weight matrix).
+   */
+  public PassiveAggressive(int numCategories, int numFeatures) {
+    this.numCategories = numCategories;
+    weights = new DenseMatrix(numCategories, numFeatures);
+    weights.assign(0.0);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public PassiveAggressive learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Copies the learning rate, category count and weights of {@code other} into this learner.
+   * The weight matrix is cloned rather than shared, so training either learner afterwards
+   * does not change the other. The loss statistics are not copied.
+   *
+   * @param other Learner to copy from.
+   */
+  public void copyFrom(PassiveAggressive other) {
+    learningRate = other.learningRate;
+    numCategories = other.numCategories;
+    // clone: assigning the reference would make both learners update the same matrix
+    weights = other.weights.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numCategories;
+  }
+
+  /**
+   * Converts the raw scores to values that sum to one by exponentiating and normalizing,
+   * and returns them for every category except category 0.
+   *
+   * @param instance Feature vector to classify.
+   * @return A view of the normalized scores for categories 1 .. n-1.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Subtract the largest score before exponentiating so that exp() cannot overflow.
+    double max = result.maxValue();
+    result.assign(Functions.minus(max)).assign(Functions.EXP);
+    result = result.divide(result.norm(1));
+
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  /**
+   * Computes the raw linear score of every category.
+   *
+   * @param instance Feature vector to classify.
+   * @return One score per row of the weight matrix.
+   */
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    int rows = weights.numRows();
+    Vector result = new DenseVector(rows);
+    for (int i = 0; i < rows; i++) {
+      result.setQuick(i, weights.viewRow(i).dot(instance));
+    }
+    return result;
+  }
+
+  /**
+   * Scores an instance using only categories 0 and 1.
+   *
+   * @param instance Feature vector to classify.
+   * @return exp(s1) / (exp(s0) + exp(s1)), where s0 and s1 are the raw scores of
+   *         categories 0 and 1.
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    double v1 = weights.viewRow(0).dot(instance);
+    double v2 = weights.viewRow(1).dot(instance);
+    // Algebraically the same as exp(v2) / (exp(v1) + exp(v2)), but exponentiating only the
+    // difference avoids the Infinity / Infinity = NaN result that large scores produce.
+    return 1.0 / (1.0 + Math.exp(v1 - v2));
+  }
+
+  public int numFeatures() {
+    return weights.numCols();
+  }
+
+  /**
+   * Creates a new learner of the same shape and copies this learner's state into it
+   * with {@link #copyFrom(PassiveAggressive)}.
+   *
+   * @return The newly created copy.
+   */
+  public PassiveAggressive copy() {
+    close();
+    PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures());
+    r.copyFrom(this);
+    return r;
+  }
+
+  /**
+   * Serializes the version, learning rate, category count and weights, in that order.
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(learningRate);
+    out.writeInt(numCategories);
+    MatrixWritable.writeMatrix(out, weights);
+  }
+
+  /**
+   * Restores the state written by {@link #write(DataOutput)}.
+   *
+   * @throws IOException If the stored version differs from {@link #WRITABLE_VERSION}
+   *                     or the stream cannot be read.
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version != WRITABLE_VERSION) {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+    learningRate = in.readDouble();
+    numCategories = in.readInt();
+    weights = MatrixWritable.readMatrix(in);
+  }
+
+  @Override
+  public void close() {
+    // This is an online classifier, nothing to do.
+  }
+
+  /**
+   * Performs one passive aggressive update. The loss is
+   * {@code 1 - score(actual) + score(best other category)}; when it is non-negative the
+   * weights of {@code actual} are moved towards the instance and those of the best other
+   * category are moved away from it by the same amount.
+   *
+   * @param trackingKey Not used by this learner.
+   * @param groupKey    Not used by this learner.
+   * @param actual      Index of the correct category.
+   * @param instance    Feature vector of the training example.
+   */
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    // Report and reset the running average loss roughly every thousand examples.
+    if (lossCount > 1000) {
+      log.info("Avg. Loss = {}", lossSum / lossCount);
+      lossCount = 0;
+      lossSum = 0;
+    }
+
+    Vector result = classifyNoLink(instance);
+    double myScore = result.get(actual);
+
+    // Find the highest score that is not actual.
+    int otherIndex = result.maxValueIndex();
+    double otherValue = result.get(otherIndex);
+    if (otherIndex == actual) {
+      // the correct category currently wins, so mask it out and take the runner-up
+      result.setQuick(otherIndex, Double.NEGATIVE_INFINITY);
+      otherIndex = result.maxValueIndex();
+      otherValue = result.get(otherIndex);
+    }
+
+    double loss = 1.0 - myScore + otherValue;
+    lossCount++;
+    if (loss >= 0) {
+      lossSum += loss;
+      // step size: loss scaled by the squared norm of the instance plus 0.5 / learningRate
+      double tau = loss / (instance.dot(instance) + 0.5 / learningRate);
+      weights.viewRow(actual).assign(instance, Functions.plusMult(tau));
+      weights.viewRow(otherIndex).assign(instance, Functions.plusMult(-tau));
+    }
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
new file mode 100644
index 0000000..90062a6
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.ClassUtils;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Helpers that serialize a {@link Writable} together with the name of its concrete class,
+ * so that the matching class can be instantiated again when the data is read back.
+ */
+public final class PolymorphicWritable {
+
+  // static helpers only
+  private PolymorphicWritable() {
+  }
+
+  /**
+   * Writes the class name of {@code value} as a UTF string, followed by the value itself.
+   *
+   * @param dataOutput Destination of the serialized form.
+   * @param value      Object to serialize.
+   * @throws IOException If writing fails.
+   */
+  public static <T extends Writable> void write(DataOutput dataOutput, T value) throws IOException {
+    String concreteType = value.getClass().getName();
+    dataOutput.writeUTF(concreteType);
+    value.write(dataOutput);
+  }
+
+  /**
+   * Reads a class name, instantiates that class as a subtype of {@code clazz} and lets the
+   * new object populate itself from the remaining input.
+   *
+   * @param dataInput Source positioned at data produced by {@link #write}.
+   * @param clazz     Type the stored class is instantiated as.
+   * @return The deserialized object.
+   * @throws IOException If reading fails.
+   */
+  public static <T extends Writable> T read(DataInput dataInput, Class<? extends T> clazz) throws IOException {
+    // NOTE(review): the class name comes straight from the stream; confirm that callers only
+    // ever read data from trusted sources.
+    T restored = ClassUtils.instantiateAs(dataInput.readUTF(), clazz);
+    restored.readFields(dataInput);
+    return restored;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
new file mode 100644
index 0000000..857f061
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * A prior is used to regularize the learning algorithm. This allows a trade-off to
+ * be made between the complexity of the model being learned and the accuracy with which
+ * the model fits the training data. There are different definitions of complexity
+ * which can be approximated using different priors. For large sparse systems, such
+ * as text classification, the L1 prior is often used because it favors sparse models.
+ * <p>
+ * The interface extends {@link Writable}, so every implementation must also provide the
+ * methods used to serialize it.
+ */
+public interface PriorFunction extends Writable {
+ /**
+ * Applies the regularization to a coefficient.
+ * @param oldValue The previous value of the coefficient.
+ * @param generations The number of generations to apply; presumably the number of
+ * updates since the coefficient was last regularized (confirm against the callers).
+ * @param learningRate The learning rate with lambda baked in.
+ * @return The new coefficient value after regularization.
+ */
+ double age(double oldValue, double generations, double learningRate);
+
+ /**
+ * Returns the log of the probability of a particular coefficient value according to the prior.
+ * @param betaIJ The coefficient value to evaluate.
+ * @return The log probability.
+ */
+ double logP(double betaIJ);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
new file mode 100644
index 0000000..a04fc8b
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * Uses the difference between this instance and recent history to get a
+ * gradient that optimizes ranking performance. Essentially this is the
+ * same as directly optimizing AUC. It isn't expected that this would
+ * be used alone, but rather that a MixedGradient would use it and a
+ * DefaultGradient together to combine both ranking and log-likelihood
+ * goals.
+ * <p>
+ * Only binary targets (0 and 1) are supported by {@link #apply}.
+ */
+public class RankingGradient implements Gradient {
+
+  private static final Gradient BASIC = new DefaultGradient();
+
+  // maximum number of recent instances remembered for each target value
+  private final int window;
+
+  // history.get(k) holds the most recent instances whose target was k, oldest first
+  private final List<Deque<Vector>> history = new ArrayList<>();
+
+  /**
+   * @param window Maximum number of recent instances to remember per target value.
+   * @throws IllegalArgumentException If {@code window} is not positive.
+   */
+  public RankingGradient(int window) {
+    if (window <= 0) {
+      throw new IllegalArgumentException("window must be positive, got " + window);
+    }
+    this.window = window;
+  }
+
+  /**
+   * Records the instance and returns the average of the basic gradients computed on the
+   * differences between it and every remembered instance with the opposite target. While
+   * no instance with the opposite target has been seen, the basic gradient of the instance
+   * itself is returned instead.
+   *
+   * @param groupKey   Passed through to the basic gradient.
+   * @param actual     Target of the instance; must be 0 or 1.
+   * @param instance   Feature vector of the training example.
+   * @param classifier Classifier the gradient is computed for.
+   * @return The gradient; never null.
+   * @throws IllegalArgumentException If {@code actual} is neither 0 nor 1.
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    // checked up front so that an unsupported target does not modify the history
+    if (actual != 0 && actual != 1) {
+      throw new IllegalArgumentException("RankingGradient only supports targets 0 and 1, got " + actual);
+    }
+    addToHistory(actual, instance);
+
+    // the opposite side may not exist yet (e.g. the very first example has target 0)
+    int opposite = 1 - actual;
+    if (opposite >= history.size() || history.get(opposite).isEmpty()) {
+      return BASIC.apply(groupKey, actual, instance, classifier);
+    }
+
+    // now compute average gradient versus saved vectors from the other side
+    Deque<Vector> otherSide = history.get(opposite);
+    double weight = 1.0 / otherSide.size();
+
+    Vector r = null;
+    for (Vector other : otherSide) {
+      Vector g = BASIC.apply(groupKey, actual, instance.minus(other), classifier);
+
+      if (r == null) {
+        // the first term is scaled like all the others so that the result is a true mean
+        r = g.assign(Functions.mult(weight));
+      } else {
+        r.assign(g, Functions.plusMult(weight));
+      }
+    }
+    return r;
+  }
+
+  /**
+   * Remembers an instance under its target value, discarding the oldest remembered
+   * instances so that at most {@code window} are kept for that target.
+   *
+   * @param actual   Target of the instance; used as the index into the history.
+   * @param instance Feature vector to remember. The vector is stored as given, not copied.
+   */
+  public void addToHistory(int actual, Vector instance) {
+    while (history.size() <= actual) {
+      history.add(new ArrayDeque<Vector>(window));
+    }
+    Deque<Vector> ourSide = history.get(actual);
+    // make room first so that exactly the newest 'window' instances remain after the add
+    while (ourSide.size() >= window) {
+      ourSide.pollFirst();
+    }
+    // save this instance
+    ourSide.add(instance);
+  }
+
+  public Gradient getBaseGradient() {
+    return BASIC;
+  }
+}