You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/07/25 01:38:30 UTC

svn commit: r978948 [3/4] - in /mahout/trunk: conf/ core/ core/doc/ core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/evaluation/ core/src/main/java/org/apache/mahout/classifier/sgd/ core/src/test/java/org...

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java Sat Jul 24 23:38:28 2010
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Maps;
+import com.google.gson.*;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+
+import java.io.*;
+import java.lang.reflect.Type;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Encapsulates everything we need to know about a model and how it reads and vectorizes its input.
+ * This encapsulation allows us to coherently save and restore a model from a file.  This also
+ * allows us to keep command line arguments that affect learning in a coherent way.
+ * <p/>
+ * Instances are persisted as JSON using GSON.  The CSV record factory is not saved; it is rebuilt
+ * on demand from the saved parameters.  The logistic regression, when present, is saved along with
+ * the parameters.
+ */
+public class LogisticModelParameters {
+  private String targetVariable;
+  private Map<String, String> typeMap;
+  private int numFeatures;
+  private boolean useBias;
+  private int maxTargetCategories;
+  private List<String> targetCategories = null;
+  private double lambda;
+  private double learningRate;
+  // transient: not written to JSON, rebuilt lazily by getCsvRecordFactory()
+  private transient CsvRecordFactory csv = null;
+  private OnlineLogisticRegression lr = null;
+
+  public LogisticModelParameters() {
+  }
+
+  /**
+   * Returns a CsvRecordFactory compatible with this logistic model.  The reason that this is tied
+   * in here is so that we have access to the list of target categories when it comes time to save
+   * the model.  If the input isn't CSV, then calling setTargetCategories before calling saveTo will
+   * suffice.
+   *
+   * @return The record factory for this model; it is created on the first call and cached.
+   */
+  public CsvRecordFactory getCsvRecordFactory() {
+    if (csv == null) {
+      csv = new CsvRecordFactory(getTargetVariable(), getTypeMap())
+              .maxTargetValue(getMaxTargetCategories())
+              .includeBiasTerm(useBias());
+      if (targetCategories != null) {
+        csv.defineTargetCategories(targetCategories);
+      }
+    }
+    return csv;
+  }
+
+  /**
+   * Creates a logistic regression trainer using the parameters collected here.  If a regression
+   * already exists (created by an earlier call or read back by loadFrom) that instance is returned.
+   *
+   * @return The logistic regression associated with these parameters.
+   */
+  public OnlineLogisticRegression createRegression() {
+    if (lr == null) {
+      lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1())
+              .lambda(getLambda())
+              .learningRate(getLearningRate())
+              .alpha(1 - 1e-3);
+    }
+    return lr;
+  }
+
+  /**
+   * Saves an already trained model, together with its target categories, in JSON format.
+   * The model is closed as part of saving.
+   *
+   * @param out              Where to write the model.  It is not flushed or closed here.
+   * @param model            The trained regression to save.
+   * @param targetCategories The names of the target categories.
+   * @throws IOException If writing to {@code out} fails.
+   */
+  public static void saveModel(Writer out, OnlineLogisticRegression model, List<String> targetCategories) throws IOException {
+    LogisticModelParameters x = new LogisticModelParameters();
+    x.setTargetCategories(targetCategories);
+    x.setLambda(model.getLambda());
+    x.setLearningRate(model.currentLearningRate());
+    x.setNumFeatures(model.numFeatures());
+    x.setUseBias(true);
+    // attach the model itself; otherwise only the hyper-parameters would be written
+    x.lr = model;
+    x.saveTo(out);
+  }
+
+  /**
+   * Saves a model in JSON format.  This includes the current state of the logistic regression
+   * trainer and the dictionary for the target categories.  The regression, if any, is closed first.
+   * If a CSV record factory has been created, its target categories replace any set explicitly.
+   *
+   * @param out Where to write the model.  It is not flushed or closed here.
+   * @throws IOException If writing to {@code out} fails.
+   */
+  public void saveTo(Writer out) throws IOException {
+    if (lr != null) {
+      lr.close();
+    }
+    // csv only exists if getCsvRecordFactory() was called; without it, keep the categories
+    // supplied through setTargetCategories instead of failing with a NullPointerException
+    if (csv != null) {
+      targetCategories = csv.getTargetCategories();
+    }
+    GsonBuilder gb = new GsonBuilder();
+    gb.registerTypeAdapter(Matrix.class, new MatrixTypeAdapter());
+    Gson gson = gb.setPrettyPrinting().create();
+
+    String savedForm = gson.toJson(this);
+    out.write(savedForm);
+  }
+
+  /**
+   * Reads a model in JSON format.
+   *
+   * @param in Where to read the model from.  It is not closed here.
+   * @return The LogisticModelParameters object that we read.
+   */
+  public static LogisticModelParameters loadFrom(Reader in) {
+    GsonBuilder gb = new GsonBuilder();
+    gb.registerTypeAdapter(Matrix.class, new MatrixTypeAdapter());
+    return gb.create().fromJson(in, LogisticModelParameters.class);
+  }
+
+  /**
+   * Reads a model in JSON format from a File.
+   *
+   * @param in Where to read the model from.
+   * @return The LogisticModelParameters object that we read.
+   * @throws IOException If there is an error opening or closing the file.
+   */
+  public static LogisticModelParameters loadFrom(File in) throws IOException {
+    FileReader input = new FileReader(in);
+    try {
+      return loadFrom(input);
+    } finally {
+      // close even when parsing fails so the file handle is not leaked
+      input.close();
+    }
+  }
+
+  /**
+   * Sets the types of the predictors.  This will later be used when reading CSV data.  If you don't
+   * use the CSV data and convert to vectors on your own, you don't need to call this.
+   * <p/>
+   * The type list may be shorter than the predictor list; the last type is then repeated for the
+   * remaining predictors.
+   *
+   * @param predictorList The list of variable names.
+   * @param typeList      The list of types in the format preferred by CsvRecordFactory.
+   * @throws IllegalArgumentException If {@code typeList} is empty; the current type map is kept.
+   */
+  public void setTypeMap(List<?> predictorList, List<?> typeList) {
+    if (typeList.isEmpty()) {
+      throw new IllegalArgumentException("Must have at least one type specifier");
+    }
+    typeMap = Maps.newHashMap();
+    Iterator<?> iTypes = typeList.iterator();
+    String lastType = null;
+    for (Object x : predictorList) {
+      // type list can be short .. we just repeat last spec
+      if (iTypes.hasNext()) {
+        lastType = iTypes.next().toString();
+      }
+      typeMap.put(x.toString(), lastType);
+    }
+  }
+
+  /**
+   * Sets the target variable.  If you don't use the CSV record factory, then this is irrelevant.
+   *
+   * @param targetVariable The name of the column holding the target value.
+   */
+  public void setTargetVariable(String targetVariable) {
+    this.targetVariable = targetVariable;
+  }
+
+  /**
+   * Sets the number of target categories to be considered.
+   *
+   * @param maxTargetCategories The number of categories the regression is created with.
+   */
+  public void setMaxTargetCategories(int maxTargetCategories) {
+    this.maxTargetCategories = maxTargetCategories;
+  }
+
+  /**
+   * Sets the number of features the regression is created with.
+   *
+   * @param numFeatures The number of features.
+   */
+  public void setNumFeatures(int numFeatures) {
+    this.numFeatures = numFeatures;
+  }
+
+  /**
+   * Sets the target category names; this also sets the maximum number of target categories to
+   * the size of the list.
+   *
+   * @param targetCategories The target category names.
+   */
+  public void setTargetCategories(List<String> targetCategories) {
+    this.targetCategories = targetCategories;
+    maxTargetCategories = targetCategories.size();
+  }
+
+  public void setUseBias(boolean useBias) {
+    this.useBias = useBias;
+  }
+
+  public boolean useBias() {
+    return useBias;
+  }
+
+  public String getTargetVariable() {
+    return targetVariable;
+  }
+
+  public Map<String, String> getTypeMap() {
+    return typeMap;
+  }
+
+  public int getNumFeatures() {
+    return numFeatures;
+  }
+
+  public int getMaxTargetCategories() {
+    return maxTargetCategories;
+  }
+
+  public double getLambda() {
+    return lambda;
+  }
+
+  public void setLambda(double lambda) {
+    this.lambda = lambda;
+  }
+
+  public double getLearningRate() {
+    return learningRate;
+  }
+
+  public void setLearningRate(double learningRate) {
+    this.learningRate = learningRate;
+  }
+
+  /**
+   * Tells GSON how to (de)serialize a Mahout matrix.  We assume on deserialization that
+   * the matrix is dense.
+   * <p/>
+   * The JSON form is an object with integer fields {@code rows} and {@code cols} and a field
+   * {@code data} holding one array of doubles per row.
+   */
+  public static class MatrixTypeAdapter implements JsonDeserializer<Matrix>, JsonSerializer<Matrix>, InstanceCreator<Matrix> {
+    @Override
+    public JsonElement serialize(Matrix m, Type type, JsonSerializationContext jsonSerializationContext) {
+      JsonObject r = new JsonObject();
+      r.add("rows", new JsonPrimitive(m.numRows()));
+      r.add("cols", new JsonPrimitive(m.numCols()));
+      JsonArray v = new JsonArray();
+      for (int row = 0; row < m.numRows(); row++) {
+        JsonArray rowData = new JsonArray();
+        for (int col = 0; col < m.numCols(); col++) {
+          rowData.add(new JsonPrimitive(m.get(row, col)));
+        }
+        v.add(rowData);
+      }
+      r.add("data", v);
+      return r;
+    }
+
+    @Override
+    public Matrix deserialize(JsonElement x, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
+      JsonObject data = x.getAsJsonObject();
+      Matrix r = new DenseMatrix(data.get("rows").getAsInt(), data.get("cols").getAsInt());
+      int i = 0;
+      for (JsonElement row : data.get("data").getAsJsonArray()) {
+        int j = 0;
+        for (JsonElement element : row.getAsJsonArray()) {
+          r.set(i, j, element.getAsDouble());
+          j++;
+        }
+        i++;
+      }
+      return r;
+    }
+
+    @Override
+    public Matrix createInstance(Type type) {
+      return new DenseMatrix();
+    }
+  }
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java Sat Jul 24 23:38:28 2010
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+
+/**
+ * Uses the same logic as TrainLogistic and RunLogistic for finding an input, but instead
+ * of processing the input, this class just prints the input to standard out.
+ */
+public class PrintResourceOrFile {
+  /**
+   * Prints every line of the named resource or file to standard out.
+   *
+   * @param args A single argument naming a classpath resource or a file.
+   * @throws IOException If the input cannot be opened or read.
+   */
+  public static void main(String[] args) throws IOException {
+    if (args.length != 1) {
+      throw new IllegalArgumentException("Must have a single argument that names a file or resource.");
+    }
+    BufferedReader in = TrainLogistic.InputOpener.open(args[0]);
+    try {
+      String line = in.readLine();
+      while (line != null) {
+        System.out.println(line);
+        line = in.readLine();
+      }
+    } finally {
+      // release the underlying stream even if reading fails part way through
+      in.close();
+    }
+  }
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java Sat Jul 24 23:38:28 2010
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * Applies a logistic regression model saved by TrainLogistic to CSV data and reports the results.
+ * Depending on the command line flags it prints per-record scores, the AUC, and the confusion and
+ * entropy matrices.  If none of --auc, --scores or --confusion is given, the AUC and the matrices
+ * are printed.
+ */
+public class RunLogistic {
+  private static final Logger log = LoggerFactory.getLogger(RunLogistic.class);
+  private static String inputFile;
+  private static String modelFile;
+  private static boolean showAuc = false;
+  private static boolean showScores = false;
+  private static boolean showConfusion = false;
+
+  /**
+   * Command line entry point.  Run with --help for the list of options.
+   *
+   * @param args The command line arguments.
+   * @throws IOException If the model or the input data cannot be read.
+   */
+  public static void main(String[] args) throws IOException {
+    if (parseArgs(args)) {
+      if (!showAuc && !showConfusion && !showScores) {
+        // nothing was requested explicitly, so fall back to the summary output
+        showAuc = true;
+        showConfusion = true;
+      }
+
+      Auc collector = new Auc();
+      LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));
+
+      CsvRecordFactory csv = lmp.getCsvRecordFactory();
+      OnlineLogisticRegression lr = lmp.createRegression();
+      BufferedReader in = TrainLogistic.InputOpener.open(inputFile);
+      try {
+        // the first line holds the variable names
+        String line = in.readLine();
+        csv.firstLine(line);
+        line = in.readLine();
+        if (showScores) {
+          System.out.printf("\"%s\",\"%s\",\"%s\"\n", "target", "model-output", "log-likelihood");
+        }
+        while (line != null) {
+          Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+          int target = csv.processLine(line, v);
+          double score = lr.classifyScalar(v);
+          if (showScores) {
+            System.out.printf("%d,%.3f,%.6f\n", target, score, lr.logLikelihood(target, v));
+          }
+          collector.add(target, score);
+          line = in.readLine();
+        }
+      } finally {
+        // release the input even if a line fails to parse
+        in.close();
+      }
+
+      if (showAuc) {
+        System.out.printf("AUC = %.2f\n", collector.auc());
+      }
+      if (showConfusion) {
+        // NOTE: each bracketed pair printed below is a column of m, i.e. [m(0,j), m(1,j)]
+        Matrix m = collector.confusion();
+        System.out.printf("confusion: [[%.1f, %.1f], [%.1f, %.1f]]\n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+        m = collector.entropy();
+        System.out.printf("entropy: [[%.1f, %.1f], [%.1f, %.1f]]\n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+      }
+    }
+  }
+
+  /**
+   * Parses the command line into the static fields of this class.
+   *
+   * @param args The command line arguments.
+   * @return false if parsing failed or help was requested (the parser prints usage), true otherwise.
+   */
+  private static boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help").withDescription("print this list").create();
+
+    Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
+
+    Option auc = builder.withLongName("auc").withDescription("print AUC").create();
+    Option confusion = builder.withLongName("confusion").withDescription("print confusion matrix").create();
+
+    Option scores = builder.withLongName("scores").withDescription("print scores").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option inputFile = builder.withLongName("input")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+            .withDescription("where to get the data to classify")
+            .create();
+
+    Option modelFile = builder.withLongName("model")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+            .withDescription("where to get a model")
+            .create();
+
+    Group normalArgs = new GroupBuilder()
+            .withOption(help)
+            .withOption(quiet)
+            .withOption(auc)
+            .withOption(scores)
+            .withOption(confusion)
+            .withOption(inputFile)
+            .withOption(modelFile)
+            .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    RunLogistic.inputFile = getStringArgument(cmdLine, inputFile);
+    RunLogistic.modelFile = getStringArgument(cmdLine, modelFile);
+    RunLogistic.showAuc = getBooleanArgument(cmdLine, auc);
+    RunLogistic.showScores = getBooleanArgument(cmdLine, scores);
+    RunLogistic.showConfusion = getBooleanArgument(cmdLine, confusion);
+
+    return true;
+  }
+
+  /** Returns true if the given flag option was present on the command line. */
+  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+    return cmdLine.hasOption(option);
+  }
+
+  /** Returns the single string value of the given option. */
+  private static String getStringArgument(CommandLine cmdLine, Option option) {
+    return (String) cmdLine.getValue(option);
+  }
+
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java Sat Jul 24 23:38:28 2010
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.net.URL;
+import java.util.List;
+
+
+/**
+ * Train a logistic regression for the examples from Chapter 13 of Mahout in Action.
+ * <p/>
+ * Reads CSV training data from a classpath resource or a file and makes the requested number of
+ * passes over it.  It then writes the resulting model as JSON and prints a summary of the learned
+ * weights to standard out.
+ */
+public class TrainLogistic {
+  private static final Logger log = LoggerFactory.getLogger(TrainLogistic.class);
+  private static String inputFile;
+  private static String outputFile;
+  private static LogisticModelParameters lmp;
+
+  private static int passes;
+  private static boolean scores = false;
+
+
+  /**
+   * Command line entry point.  Run with --help for the list of options.
+   *
+   * @param args The command line arguments.
+   * @throws IOException If the training data cannot be read or the model cannot be written.
+   */
+  public static void main(String[] args) throws IOException {
+    if (parseArgs(args)) {
+      double logPEstimate = 0;
+      int samples = 0;
+
+      CsvRecordFactory csv = lmp.getCsvRecordFactory();
+      OnlineLogisticRegression lr = lmp.createRegression();
+      for (int pass = 0; pass < passes; pass++) {
+        BufferedReader in = InputOpener.open(inputFile);
+        try {
+          // read variable names
+          csv.firstLine(in.readLine());
+
+          String line = in.readLine();
+          while (line != null) {
+            // for each new line, get target and predictors
+            Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
+            int targetValue = csv.processLine(line, input);
+
+            // check performance while this is still news
+            double logP = lr.logLikelihood(targetValue, input);
+            if (!Double.isInfinite(logP)) {
+              // plain running mean over the first 20 samples, then an exponentially weighted average
+              if (samples < 20) {
+                logPEstimate = (samples * logPEstimate + logP) / (samples + 1);
+              } else {
+                logPEstimate = 0.95 * logPEstimate + 0.05 * logP;
+              }
+              samples++;
+            }
+            double p = lr.classifyScalar(input);
+            if (scores) {
+              System.out.printf("%10d %2d %10.2f %2.4f %10.4f %10.4f\n", samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate);
+            }
+
+            // now update model
+            lr.train(targetValue, input);
+
+            line = in.readLine();
+          }
+        } finally {
+          // close the input even if a line fails to parse
+          in.close();
+        }
+      }
+
+      FileWriter modelOutput = new FileWriter(outputFile);
+      try {
+        lmp.saveTo(modelOutput);
+      } finally {
+        modelOutput.close();
+      }
+
+      System.out.printf("%d\n", lmp.getNumFeatures());
+      // formula-style summary built from row 0 of the coefficient matrix
+      System.out.printf("%s ~ ", lmp.getTargetVariable());
+      String sep = "";
+      for (String v : csv.getPredictors()) {
+        double weight = predictorWeight(lr, 0, csv, v);
+        if (weight != 0) {
+          System.out.printf("%s%.3f*%s", sep, weight, v);
+          sep = " + ";
+        }
+      }
+      System.out.printf("\n");
+      // for every row: the non-zero traced weights, then the raw coefficients of that row
+      for (int row = 0; row < lr.getBeta().numRows(); row++) {
+        for (String key : csv.getTraceDictionary().keySet()) {
+          double weight = predictorWeight(lr, row, csv, key);
+          if (weight != 0) {
+            System.out.printf("%20s %.5f\n", key, weight);
+          }
+        }
+        for (int column = 0; column < lr.getBeta().numCols(); column++) {
+          System.out.printf("%15.9f ", lr.getBeta().get(row, column));
+        }
+        System.out.println();
+      }
+    }
+  }
+
+  /**
+   * Sums the coefficients in one row of the model's beta matrix over every column that the record
+   * factory's trace dictionary associates with a predictor.
+   *
+   * @param lr        The trained regression.
+   * @param row       The row of beta to read.
+   * @param csv       The record factory whose trace dictionary maps names to columns.
+   * @param predictor The dictionary key to look up.
+   * @return The summed weight, or 0 if the dictionary has no entry for {@code predictor}.
+   */
+  private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
+    double weight = 0;
+    if (csv.getTraceDictionary().get(predictor) == null) {
+      // nothing was traced for this name; report zero instead of failing with a NullPointerException
+      return weight;
+    }
+    for (Integer column : csv.getTraceDictionary().get(predictor)) {
+      weight += lr.getBeta().get(row, column);
+    }
+    return weight;
+  }
+
+  /**
+   * Parses the command line into the static fields of this class.
+   *
+   * @param args The command line arguments.
+   * @return false if parsing failed or help was requested (the parser prints usage), true otherwise.
+   */
+  private static boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help").withDescription("print this list").create();
+
+    Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
+    Option scores = builder.withLongName("scores").withDescription("output score diagnostics during training").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option inputFile = builder.withLongName("input")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+            .withDescription("where to get training data")
+            .create();
+
+    Option outputFile = builder.withLongName("output")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+            .withDescription("where to write the model")
+            .create();
+
+    Option predictors = builder.withLongName("predictors")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("p").create())
+            .withDescription("a list of predictor variables")
+            .create();
+
+    Option types = builder.withLongName("types")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("t").create())
+            .withDescription("a list of predictor variable types (numeric, word, or text)")
+            .create();
+
+    Option target = builder.withLongName("target")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("target").withMaximum(1).create())
+            .withDescription("the name of the target variable")
+            .create();
+
+    Option features = builder.withLongName("features")
+            .withArgument(
+                    argumentBuilder.withName("numFeatures")
+                            .withDefault("1000")
+                            .withMaximum(1).create())
+            .withDescription("the number of internal hashed features to use")
+            .create();
+
+    Option passes = builder.withLongName("passes")
+            .withArgument(
+                    argumentBuilder.withName("passes")
+                            .withDefault("2")
+                            .withMaximum(1).create())
+            .withDescription("the number of times to pass over the input data")
+            .create();
+
+    Option lambda = builder.withLongName("lambda")
+            .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create())
+            .withDescription("the amount of coefficient decay to use")
+            .create();
+
+    Option rate = builder.withLongName("rate")
+            .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create())
+            .withDescription("the learning rate")
+            .create();
+
+    Option noBias = builder.withLongName("noBias")
+            .withDescription("don't include a bias term")
+            .create();
+
+    Option targetCategories = builder.withLongName("categories")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("number").withMaximum(1).create())
+            .withDescription("the number of target categories to be considered")
+            .create();
+
+    // every option that is read below must be registered here, otherwise the parser rejects it
+    Group normalArgs = new GroupBuilder()
+            .withOption(help)
+            .withOption(quiet)
+            .withOption(scores)
+            .withOption(inputFile)
+            .withOption(outputFile)
+            .withOption(target)
+            .withOption(targetCategories)
+            .withOption(predictors)
+            .withOption(types)
+            .withOption(passes)
+            .withOption(lambda)
+            .withOption(rate)
+            .withOption(noBias)
+            .withOption(features)
+            .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    TrainLogistic.inputFile = getStringArgument(cmdLine, inputFile);
+    TrainLogistic.outputFile = getStringArgument(cmdLine, outputFile);
+
+    List<String> typeList = Lists.newArrayList();
+    for (Object x : cmdLine.getValues(types)) {
+      typeList.add(x.toString());
+    }
+
+    List<String> predictorList = Lists.newArrayList();
+    for (Object x : cmdLine.getValues(predictors)) {
+      predictorList.add(x.toString());
+    }
+
+    lmp = new LogisticModelParameters();
+    lmp.setTargetVariable(getStringArgument(cmdLine, target));
+    lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
+    lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
+    lmp.setUseBias(!getBooleanArgument(cmdLine, noBias));
+    lmp.setTypeMap(predictorList, typeList);
+
+    lmp.setLambda(getDoubleArgument(cmdLine, lambda));
+    lmp.setLearningRate(getDoubleArgument(cmdLine, rate));
+
+    TrainLogistic.scores = getBooleanArgument(cmdLine, scores);
+    TrainLogistic.passes = getIntegerArgument(cmdLine, passes);
+
+    return true;
+  }
+
+  /** Returns the single string value of the given option. */
+  private static String getStringArgument(CommandLine cmdLine, Option option) {
+    return (String) cmdLine.getValue(option);
+  }
+
+  /** Returns true if the given flag option was present on the command line. */
+  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+    return cmdLine.hasOption(option);
+  }
+
+  /** Parses the value of the given option as an int. */
+  private static int getIntegerArgument(CommandLine cmdLine, Option option) {
+    return Integer.parseInt((String) cmdLine.getValue(option));
+  }
+
+  /** Parses the value of the given option as a double. */
+  private static double getDoubleArgument(CommandLine cmdLine, Option op) {
+    return Double.parseDouble((String) cmdLine.getValue(op));
+  }
+
+  /**
+   * Opens an input by name, shared by TrainLogistic, RunLogistic and PrintResourceOrFile.
+   */
+  public static class InputOpener {
+    /**
+     * Opens {@code inputFile} as a classpath resource if one with that name exists, and otherwise
+     * as a file.
+     *
+     * @param inputFile The resource or file name.
+     * @return A reader over the input; the caller is responsible for closing it.
+     * @throws IOException If the input cannot be opened.
+     */
+    public static BufferedReader open(String inputFile) throws IOException {
+      InputStreamReader s;
+      try {
+        URL resource = Resources.getResource(inputFile);
+        s = new InputStreamReader(resource.openStream());
+      } catch (IllegalArgumentException e) {
+        // Resources.getResource signals "not found" this way; fall back to the file system
+        s = new FileReader(inputFile);
+      }
+
+      return new BufferedReader(s);
+    }
+  }
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TwentyNewsGroupTrain.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TwentyNewsGroupTrain.java?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TwentyNewsGroupTrain.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TwentyNewsGroupTrain.java Sat Jul 24 23:38:28 2010
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Splitter;
+import com.google.common.collect.ConcurrentHashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Multiset;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.analysis.tokenattributes.TermAttribute;
+import org.apache.lucene.util.Version;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+import java.io.*;
+import java.util.*;
+
+/**
+ * Simple training program that reads newsgroup articles, one per file, and trains an SGD model
+ * using that data.
+ * <p>
+ * {@code args[0]} names a directory holding one sub-directory per newsgroup; every file inside a
+ * sub-directory is treated as one article of that newsgroup. The learner is constructed with 20
+ * categories (presumably one per newsgroup). Progress is printed to standard output, each article
+ * whose {@code Lines:} header parses as an integer gets a row in {@code lineCounts.tsv}, and the
+ * trained model is written as JSON to a file named {@code model}; both files are created in the
+ * working directory.
+ */
+public class TwentyNewsGroupTrain {
+  private static final int FEATURES = 200000;
+  // NOTE(review): not referenced in this class; main makes a single pass over the articles.
+  private static final int PASSES = 1;
+  private static final Splitter ON_COLON = Splitter.on(":").trimResults();
+
+  /**
+   * Trains a logistic regression model on every article below {@code args[0]}, visiting the
+   * articles once in random order, and saves the result.
+   *
+   * @param args {@code args[0]} is the directory holding one sub-directory per newsgroup
+   * @throws IOException if an article cannot be read or an output file cannot be written
+   * @throws IllegalArgumentException if {@code args[0]} cannot be listed as a directory
+   */
+  public static void main(String[] args) throws IOException {
+    File base = new File(args[0]);
+
+    Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
+    RecordValueEncoder encoder = new StaticWordValueEncoder("body");
+    RecordValueEncoder bias = new ConstantValueEncoder("Intercept");
+    bias.setTraceDictionary(traceDictionary);
+    RecordValueEncoder lines = new ConstantValueEncoder("Lines");
+    RecordValueEncoder logLines = new ConstantValueEncoder("LogLines");
+    encoder.setProbes(2);
+    encoder.setTraceDictionary(traceDictionary);
+
+    OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression(20, FEATURES, new L1())
+            .alpha(1)
+            .stepOffset(1000)
+            .decayExponent(0.9)
+            .lambda(0)
+            .learningRate(10);
+
+    Dictionary newsGroups = new Dictionary();
+
+    File[] groupDirectories = base.listFiles();
+    if (groupDirectories == null) {
+      throw new IllegalArgumentException("Cannot list newsgroup directories in " + base);
+    }
+    List<File> files = Lists.newArrayList();
+    for (File newsgroup : groupDirectories) {
+      File[] articles = newsgroup.listFiles();
+      if (articles == null) {
+        // not a directory (e.g. a stray file beside the newsgroup folders): not a category
+        continue;
+      }
+      newsGroups.intern(newsgroup.getName());
+      files.addAll(Arrays.asList(articles));
+    }
+    System.out.printf("%d files\n", files.size());
+
+    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_30);
+    Random rand = new Random();
+    double averageLL = 0;
+    double averageCorrect = 0;
+    double averageLineCount = 0;
+
+    int k = 0;
+    double step = 0;
+    int[] bumps = new int[]{1, 2, 5};
+    PrintWriter lineCounts = new PrintWriter(new File("lineCounts.tsv"));
+    try {
+      for (File file : permute(files, rand)) {
+        String ng = file.getParentFile().getName();
+        int actual = newsGroups.intern(ng);
+
+        Multiset<String> words = ConcurrentHashMultiset.create();
+
+        // articles without a usable "Lines:" header fall back to the running average
+        double lineCount = averageLineCount;
+
+        BufferedReader reader = new BufferedReader(new FileReader(file));
+        try {
+          // read headers: every line up to the first empty line (or the end of the file)
+          String line = reader.readLine();
+          while (line != null && line.length() > 0) {
+            if (line.startsWith("Lines:")) {
+              String count = Lists.newArrayList(ON_COLON.split(line)).get(1);
+              try {
+                lineCount = Integer.parseInt(count);
+                averageLineCount = averageLineCount + (lineCount - averageLineCount) / Math.min(k + 1, 1000);
+                lineCounts.printf("%s\t%.1f\n", ng, lineCount);
+              } catch (NumberFormatException e) {
+                // ignore bogus data, use average value
+              }
+            }
+
+            // Only Summary: headers contribute words; From:, Subject: and Keywords: are
+            // candidate headers that are currently disabled.
+            boolean countHeader = line.startsWith("Summary:");
+
+            // consume this header line plus its continuation lines (those starting with a
+            // space); a file may end inside the headers, so readLine() can return null here
+            do {
+              if (countHeader) {
+                countWords(analyzer, words, new StringReader(line));
+              }
+              line = reader.readLine();
+            } while (line != null && line.startsWith(" "));
+          }
+
+          // the article body is currently not used:
+          //   countWords(analyzer, words, reader);
+        } finally {
+          reader.close();
+        }
+
+        // now encode words as vector
+        Vector v = new RandomAccessSparseVector(FEATURES);
+
+        // encode constant term
+        bias.addToVector(null, 1, v);
+
+        lines.addToVector(null, lineCount / 30, v);
+        logLines.addToVector(null, Math.log(lineCount + 1), v);
+
+        // and then all other words, with log-damped counts
+        for (String word : words.elementSet()) {
+          encoder.addToVector(word, Math.log(1 + words.count(word)), v);
+        }
+
+        // evaluate on this article before training on it, so the running averages measure
+        // performance on data the model has not yet seen
+        double ll = learningAlgorithm.logLikelihood(actual, v);
+        averageLL = (Math.min(k, 100) * averageLL + ll) / (Math.min(k, 100) + 1);
+        Vector p = new DenseVector(20);
+        learningAlgorithm.classifyFull(p, v);
+        int estimated = p.maxValueIndex();
+        boolean correct = estimated == actual;
+        averageCorrect = (Math.min(k, 500) * averageCorrect + (correct ? 1 : 0)) / (Math.min(k, 500) + 1);
+        learningAlgorithm.train(actual, v);
+
+        k++;
+        // report progress at increasingly sparse intervals (1, 2, 5, 10, 20, 50, ... articles)
+        if (k % (bumps[(int) Math.floor(step) % bumps.length] * Math.pow(10, Math.floor(step / bumps.length))) == 0) {
+          step += 0.25;
+          if (estimated == -1) {
+            System.out.printf("%d\n", estimated);
+          }
+          System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n",
+                  k, ll, averageLL, averageCorrect * 100, ng, newsGroups.values().get(estimated));
+        }
+      }
+    } finally {
+      // close only once every article has been processed: PrintWriter silently ignores
+      // writes made after close()
+      lineCounts.close();
+    }
+
+    learningAlgorithm.close();
+
+    GsonBuilder gb = new GsonBuilder();
+    gb.registerTypeAdapter(Matrix.class, new LogisticModelParameters.MatrixTypeAdapter());
+    Gson gson = gb.setPrettyPrinting().create();
+
+    Model x = new Model();
+    x.lr = learningAlgorithm;
+    x.targetCategories = newsGroups.values();
+
+    Writer output = new FileWriter("model");
+    try {
+      gson.toJson(x, output);
+    } finally {
+      output.close();
+    }
+  }
+
+  /**
+   * Debugging aid (not called from this class): prints every non-zero entry of {@code v} that is
+   * infinite or NaN.
+   */
+  private static void checkVector(Vector v) {
+    Iterator<Vector.Element> i = v.iterateNonZero();
+    while (i.hasNext()) {
+      Vector.Element element = i.next();
+      if (Double.isInfinite(element.get()) || Double.isNaN(element.get())) {
+        System.out.printf("Found invalid value at %d: %.0f\n", element.index(), element.get());
+      }
+    }
+  }
+
+  /** Holder serialized to JSON by {@link #main}: the trained learner and its category names. */
+  private static class Model {
+    OnlineLogisticRegression lr;
+    List<String> targetCategories;
+  }
+
+  /**
+   * Tokenizes {@code in} with {@code analyzer} and adds every resulting term to {@code words}.
+   *
+   * @throws IOException if reading from {@code in} fails
+   */
+  private static void countWords(Analyzer analyzer, Multiset<String> words, Reader in) throws IOException {
+    TokenStream ts = analyzer.tokenStream("body", in);
+    try {
+      TermAttribute termAtt = ts.addAttribute(TermAttribute.class);
+      while (ts.incrementToken()) {
+        char[] termBuffer = termAtt.termBuffer();
+        int termLen = termAtt.termLength();
+        words.add(new String(termBuffer, 0, termLen));
+      }
+      ts.end();
+    } finally {
+      ts.close();
+    }
+  }
+
+  /**
+   * Returns the elements of {@code values} in a uniformly random order (inside-out Fisher-Yates
+   * shuffle), leaving {@code values} itself untouched.
+   */
+  private static <T> List<T> permute(Iterable<T> values, Random rand) {
+    List<T> r = Lists.newArrayList();
+    for (T value : values) {
+      // pick a slot in [0, size]; move its current occupant (if any) to the end
+      int i = rand.nextInt(r.size() + 1);
+      if (i < r.size()) {
+        T t = r.get(i);
+        r.set(i, value);
+        r.add(t);
+      } else {
+        r.add(value);
+      }
+    }
+    return r;
+  }
+}

Added: mahout/trunk/examples/src/main/resources/donut-test.csv
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/resources/donut-test.csv?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/resources/donut-test.csv (added)
+++ mahout/trunk/examples/src/main/resources/donut-test.csv Sat Jul 24 23:38:28 2010
@@ -0,0 +1,41 @@
+"x","y","shape","color","xx","xy","yy","c","a","b"
+0.802415437065065,0.0978854028508067,21,2,0.643870533640319,0.07854475831082,0.00958155209126472,0.503141377562721,0.808363832523192,0.220502180491382
+0.97073650965467,0.989339149091393,23,2,0.942329371176533,0.96038763245370,0.978791951924881,0.67900343471543,1.38604520961670,0.989771844311643
+0.566630310611799,0.369259539060295,25,1,0.321069908904024,0.209233647314105,0.136352607187021,0.146740132271139,0.676330182744379,0.569352171215186
+0.377948862500489,0.500907538458705,24,1,0.142845342665413,0.189317434378387,0.250908362084759,0.122054511555201,0.62749797190921,0.79865886318828
+0.0133881184738129,0.269793515326455,25,2,0.000179241716268851,0.00361202754665705,0.0727885409122062,0.538317888266967,0.270125494221621,1.02283505301727
+0.395229484187439,0.385281964903697,25,1,0.156206345171069,0.152274792255611,0.148442192480054,0.155361155247979,0.551949760078871,0.717070128562224
+0.757145672803745,0.416044564917684,21,1,0.573269569845435,0.315006342020941,0.173093079997545,0.270503996498299,0.863922826323613,0.481737796145881
+0.589166145538911,0.971624446567148,24,2,0.347116747049177,0.572448230095344,0.944054065166917,0.479979395505718,1.13629697360157,1.05491161769044
+0.843438957352191,0.218833807157353,25,2,0.711389274779351,0.184572958142208,0.0478882351549814,0.443852166182378,0.871365313708512,0.269071728782402
+0.628562391968444,0.801476288354024,25,2,0.395090680597092,0.503777852913796,0.642364240793743,0.327744170151609,1.01855531091386,0.8833629703887
+0.262267543468624,0.247060472844169,22,2,0.0687842643570668,0.0647959433010369,0.0610388772419841,0.347124077652729,0.360309785599907,0.778002605819416
+0.738417695043609,0.562460686312988,21,1,0.545260692353516,0.415330923539883,0.316362023647678,0.246463657857698,0.928236347058869,0.620312280963368
+0.498857178725302,0.164454092038795,21,1,0.248858484765768,0.0820391043843046,0.0270451483883046,0.335547854098302,0.525265297877247,0.527436513434051
+0.499293045606464,0.733599063009024,25,1,0.249293545390979,0.366280910423824,0.538167585247717,0.233600132755117,0.88739006679064,0.888186376514393
+0.553942533675581,0.548312899889424,24,1,0.306852330614922,0.303733837011753,0.30064703618515,0.0724150069741539,0.779422457207946,0.706833997094728
+0.661088703200221,0.98143746308051,24,2,0.43703827349895,0.64881721974001,0.963219493937908,0.507672730364875,1.1833248782295,1.03830648704340
+0.492181566543877,0.376017479225993,23,1,0.242242694445585,0.185068871973329,0.141389144683470,0.124228794404457,0.619380205632255,0.63187712891139
+0.991064163157716,0.216620326042175,21,2,0.982208175495505,0.21468464215194,0.0469243656546183,0.566963889458783,1.01446170018888,0.21680455446021
+0.601602173643187,0.343355831922963,24,1,0.361925175332207,0.206563614817919,0.117893227315510,0.186709392055052,0.692689254029335,0.52594111396747
+0.0397100185509771,0.0602901463862509,25,2,0.00157688557331895,0.00239412283143915,0.00363490175127556,0.636562347604197,0.0721927096360464,0.962180726382856
+0.158290433697402,0.630195834673941,23,2,0.0250558614001118,0.0997539719848347,0.397146790040385,0.365672507948237,0.649771230080632,1.05148551299849
+0.967184047214687,0.497705311980098,25,2,0.935444981186582,0.48137263796116,0.247710577573207,0.467189682639721,1.08772954302059,0.498785990511377
+0.538070349488407,0.0130743277259171,24,2,0.289519700998577,0.00703490808881019,0.000170938045484685,0.488411672495383,0.538229169633216,0.462114639529248
+0.758642012253404,0.673675778554752,25,2,0.575537702755893,0.511078748249156,0.453839054611352,0.311542880770993,1.01458206044028,0.715606548922268
+0.986405614530668,0.981674374546856,21,2,0.972996036377624,0.9683291146939,0.96368457764196,0.684544100071034,1.39164672744903,0.981768498658543
+0.51937106740661,0.462004136526957,23,1,0.269746305659081,0.239951581534275,0.213447822168019,0.0426488439882434,0.695121664046734,0.666672328069706
+0.534244359936565,0.692785677267238,21,1,0.28541703612403,0.370116840724856,0.479951994626626,0.195803456422130,0.87485371963012,0.83479357381183
+0.0795328004751354,0.536029864801094,22,2,0.00632546635141770,0.0426319562859392,0.287328015958679,0.422008076977050,0.541898036820671,1.06517035321108
+0.330987347057089,0.804738595616072,23,2,0.10955262391189,0.266358292837412,0.647604207274128,0.348469350894533,0.870147591610767,1.04650950166343
+0.9804020607844,0.74571731640026,25,2,0.961188200790297,0.731102793761427,0.556094315979205,0.539595348001485,1.23178022259229,0.745974795285138
+0.362560331821442,0.805498170899227,21,2,0.131449994210474,0.292041684122788,0.648827303322001,0.334990738397057,0.883333061496328,1.02720817456326
+0.47635925677605,0.961423690896481,21,2,0.226918141516230,0.457983074842334,0.924335513417013,0.462028903057712,1.07296488988841,1.09477629741475
+0.850710266502574,0.635807712096721,24,2,0.723707957532881,0.540888148202193,0.404251446761667,0.376086992190972,1.06205433208219,0.65309943445803
+0.136131341336295,0.714137809583917,25,2,0.0185317420940189,0.0972165379176223,0.509992811077315,0.422203034393551,0.726996941651981,1.12083088398685
+0.930458213202655,0.865616530412808,24,2,0.865752486516278,0.805420010206583,0.749291977723908,0.564774043865972,1.27084399681479,0.868405457050378
+0.374636142514646,0.197784703457728,21,2,0.140352239278254,0.0740972983518064,0.0391187889218614,0.327185241457712,0.423640210792266,0.655895375171089
+0.482126326300204,0.841961156809703,22,1,0.232445794511731,0.405931639420132,0.708898589576332,0.342427950053959,0.970229036922758,0.988479504839456
+0.660344187868759,0.746531683253124,24,2,0.436054446452051,0.492967858096082,0.557309554100743,0.294088642131774,0.996676477375078,0.82016804669243
+0.0772640188224614,0.437956433976069,22,2,0.00596972860459766,0.0338382741581451,0.191805838061035,0.427264688298837,0.444719649515999,1.02139489377063
+0.998469967395067,0.464829172473401,25,2,0.996942275789907,0.464117968683793,0.216066159582307,0.499709210945471,1.10136662168971,0.464831690595724

Added: mahout/trunk/examples/src/main/resources/donut.csv
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/resources/donut.csv?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/resources/donut.csv (added)
+++ mahout/trunk/examples/src/main/resources/donut.csv Sat Jul 24 23:38:28 2010
@@ -0,0 +1,41 @@
+"x","y","shape","color","k","k0","xx","xy","yy","a","b","c","bias"
+0.923307513352484,0.0135197141207755,21,2,4,8,0.852496764213146,0.0124828536260896,0.000182782669907495,0.923406490600458,0.0778750292332978,0.644866125183976,1
+0.711011884035543,0.909141522599384,22,2,3,9,0.505537899239772,0.64641042683833,0.826538308114327,1.15415605849213,0.953966686673604,0.46035073663368,1
+0.75118898646906,0.836567111080512,23,2,3,9,0.564284893392414,0.62842000028592,0.699844531341594,1.12433510339845,0.872783737128441,0.419968245447719,1
+0.308209649519995,0.418023289414123,24,1,5,1,0.094993188057238,0.128838811521522,0.174743470492603,0.519361780024138,0.808280495564412,0.208575453051705,1
+0.849057961953804,0.500220163026825,25,1,5,2,0.720899422757147,0.424715912147755,0.250220211498583,0.985454024425153,0.52249756970547,0.349058031386046,1
+0.0738831346388906,0.486534863477573,21,2,6,1,0.00545871758406844,0.0359467208248278,0.236716173379140,0.492112681164801,1.04613986717142,0.42632955896436,1
+0.612888508243486,0.0204555552918464,22,2,4,10,0.375632323536926,0.0125369747681119,0.000418429742297785,0.613229772009826,0.387651566219268,0.492652707029903,1
+0.207169560948387,0.932857288978994,23,2,1,4,0.0429192269835473,0.193259634985281,0.870222721601238,0.955584610897845,1.22425602987611,0.522604151014326,1
+0.309267645236105,0.506309477845207,24,1,5,1,0.0956464763898851,0.156585139973909,0.256349287355886,0.593292308854389,0.856423069092351,0.190836685845410,1
+0.78758287569508,0.171928803203627,25,2,4,10,0.620286786088131,0.135408181241926,0.0295595133710317,0.806130448165285,0.273277419610556,0.436273561610666,1
+0.930236018029973,0.0790199618786573,21,2,4,8,0.86533904924026,0.0735072146828825,0.00624415437530446,0.93358620577618,0.105409523078414,0.601936228937031,1
+0.238834470743313,0.623727766098455,22,1,5,1,0.0570419044152386,0.148967690904034,0.389036326202168,0.667890882268509,0.984077887735915,0.288991338582386,1
+0.83537525916472,0.802311758277938,23,2,3,7,0.697851823624524,0.670231393002335,0.643704157471036,1.15825557675997,0.819027144096042,0.451518508649315,1
+0.656760312616825,0.320640653371811,24,1,5,3,0.43133410822855,0.210584055746134,0.102810428594702,0.730851925374252,0.469706197095164,0.238209090579297,1
+0.180789119331166,0.114329558331519,25,2,2,5,0.0326847056685386,0.0206695401642766,0.0130712479082803,0.213906413126907,0.82715035810576,0.500636870310341,1
+0.990028728265315,0.061085847672075,21,2,4,8,0.980156882790638,0.0604767440857932,0.00373148078581595,0.991911469626425,0.06189432159595,0.657855445853466,1
+0.751934139290825,0.972332585137337,22,2,3,9,0.565404949831033,0.731130065509666,0.945430656119858,1.22916052895905,1.00347761677540,0.535321288127727,1
+0.136412925552577,0.552212274167687,23,2,6,1,0.0186084862578129,0.0753288918452558,0.304938395741448,0.5688118159807,1.02504684326820,0.3673168690368,1
+0.5729476721026,0.0981996888294816,24,2,4,10,0.328269034967789,0.0562632831160512,0.0096431788862070,0.581302170866406,0.43819729534628,0.408368525870829,1
+0.446335297077894,0.339370004367083,25,1,5,3,0.199215197417612,0.151472811718508,0.115171999864114,0.560702414192882,0.649397107420365,0.169357302283512,1
+0.922843366628513,0.912627586396411,21,2,3,7,0.851639879330248,0.842212314308118,0.832889111451739,1.29789405992245,0.915883320912091,0.590811338548155,1
+0.166969822719693,0.398156099021435,22,2,6,1,0.0278789216990458,0.0664800532683736,0.158528279187967,0.431749002184154,0.923291695753637,0.348254618269284,1
+0.350683249300346,0.84422400011681,23,2,1,6,0.122978741339848,0.296055215498298,0.712714162373228,0.914162405545687,1.06504760696993,0.375214144584023,1
+0.47748578293249,0.792779305484146,24,1,5,6,0.227992672902653,0.378540847371773,0.628499027203925,0.9254683679665,0.949484141121692,0.29364368150863,1
+0.384564548265189,0.153326370986179,25,2,2,5,0.147889891782409,0.0589638865954405,0.0235089760397912,0.414003463538894,0.634247405427742,0.365387395199715,1
+0.563622857443988,0.467359990812838,21,1,5,3,0.317670725433326,0.263414773476928,0.218425361012576,0.73218582781006,0.639414084578942,0.071506910079209,1
+0.343304847599939,0.854578266385943,22,2,1,6,0.117858218385617,0.293380861503846,0.730304013379203,0.920957236664559,1.07775346743350,0.387658506651072,1
+0.666085948701948,0.710089378990233,23,1,5,2,0.443670491058174,0.472980557667886,0.504226926154735,0.973600234805286,0.784681795257806,0.267809801016930,1
+0.190568120684475,0.0772022884339094,24,2,2,5,0.0363162086212125,0.0147122950193909,0.00596019333943254,0.205612261211838,0.813105258002736,0.523933195018469,1
+0.353534662164748,0.427994541125372,25,1,5,1,0.124986757351942,0.151310905505115,0.183179327233118,0.555127088678854,0.775304301713569,0.163208092002022,1
+0.127048352966085,0.927507144864649,21,2,1,4,0.0161412839913949,0.117838255119330,0.860269503774972,0.936168140755905,1.27370093893119,0.567322915045421,1
+0.960906301159412,0.891004979610443,22,2,3,7,0.923340919607862,0.856172299272088,0.793889873690606,1.31043152942016,0.891862204031343,0.604416671286136,1
+0.306814440060407,0.902291874401271,23,2,1,6,0.094135100629581,0.276836176215481,0.81413062661056,0.953029761990747,1.13782109627099,0.446272800849954,1
+0.087350245565176,0.671402548439801,24,2,6,4,0.00763006540029655,0.0586471774793016,0.450781382051459,0.677060889028273,1.13300968942079,0.446831795474291,1
+0.27015240653418,0.371201378758997,25,1,5,1,0.0729823227562089,0.100280945780549,0.137790463592580,0.459099974241765,0.81882108746687,0.263474858488646,1
+0.871842501685023,0.569787061074749,21,2,3,2,0.7601093477444,0.496764576755166,0.324657294968199,1.04152131169391,0.584021951079369,0.378334613738721,1
+0.686449621338397,0.169308491749689,22,2,4,10,0.471213082635629,0.116221750050949,0.0286653653785545,0.707020825728764,0.356341416814533,0.379631841296403,1
+0.67132937326096,0.571220482233912,23,1,5,2,0.450683127402953,0.383477088331915,0.326292839323543,0.881462402332905,0.659027480614106,0.185542747720368,1
+0.548616112209857,0.405350996181369,24,1,5,3,0.300979638576258,0.222382087605415,0.164309430105228,0.682121007359754,0.606676886210257,0.106404700508298,1
+0.677980388281867,0.993355110753328,25,2,3,9,0.459657406894831,0.673475283690318,0.986754376059756,1.20266860895036,1.04424662144096,0.524477152905055,1

Added: mahout/trunk/examples/src/main/resources/test-data.csv
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/resources/test-data.csv?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/resources/test-data.csv (added)
+++ mahout/trunk/examples/src/main/resources/test-data.csv Sat Jul 24 23:38:28 2010
@@ -0,0 +1,61 @@
+"V1","V2","V3","V4","V5","V6","V7","V8","y"
+1,-0.212887381184450,-0.955959589855826,-0.00326541907490505,0.0560086232868742,0.091264583618544,0.0172194710825328,-0.0237399208336878,1
+1,3.14702017427074,2.12881054220556,-0.00566925018709358,-0.055626039510634,-0.0630510476335515,-0.00155145331201058,0.108559859662683,0
+1,-2.16541417186635,-2.71847685293678,-0.00833554984263851,0.0433655514274994,-0.102555485096075,-0.156155728366877,-0.0241458595902909,1
+1,-4.33686585982661,-2.6857484867589,-0.0115524101901378,0.122387581992154,0.081766215557828,-0.0206167352421607,-0.0424490760296281,1
+1,2.34100936064648,2.10958510331364,-0.0129315842415535,0.173866353524092,-0.0299915285951044,0.108136400830407,-0.0063355720943443,0
+1,1.30317270786224,3.37038662087804,-0.0230504278644102,-0.131884713919903,0.086455020204179,0.17337860146005,-0.0524355492943794,0
+1,1.94943481762617,3.54806480367192,-0.029538920288902,-0.0720379027720258,0.214306548234308,-0.082665692089578,0.226607475768828,0
+1,3.14635496849369,1.76134258264267,-0.0318247859223975,-0.187198080297378,-0.08576487890296,0.153638925055934,-0.0691201521844938,0
+1,-1.26105438936697,-1.95583819596755,-0.0367826492102569,-0.0936093811581598,-0.0317225362744449,-0.0840334569992295,-0.0627566339884115,1
+1,2.40442001058194,3.23077413487565,-0.0452264569747572,0.0371989606630366,-0.17352653795031,0.102543062447842,-0.0551882772900301,0
+1,-2.20940227045733,-0.175769402031962,-0.0465958462590872,0.130789407148096,-0.140283147466875,0.0708851428212228,0.0605244763586474,1
+1,-1.64710385829030,-2.57691366099069,-0.0553070134425288,-0.0349011715152424,-0.0826092377112715,0.106766133325393,-0.0585587032435851,1
+1,-2.6523724984616,-4.16903830585265,-0.0568310036349303,-0.0291979248790545,-0.255996825268056,0.0401827924643623,0.0179311252387879,1
+1,2.34337447158977,0.28996735916551,-0.0625800583342644,0.0899232083837452,0.0255207970332586,-0.0343458209061299,0.0755898049986344,0
+1,3.67556867120403,1.36097809464341,-0.0956707962851342,0.0537771695881714,-0.0373171704803031,0.0463473815328367,-0.228499359561800,0
+1,1.96533061882493,2.92646586187099,-0.103334098736041,-0.0194013528907574,0.0253359438067293,0.00748464018133427,-0.239745502177878,0
+1,-1.95041601303593,-0.860607985906108,-0.103721968898869,-0.00972933741506002,0.0227857854969761,-0.0287381002832544,-0.130156656165122,1
+1,-1.51543545229533,-1.35683836829949,-0.106483722717291,0.103877046729912,0.00840497101030744,0.0258430051020969,0.168907472637671,1
+1,1.45074382041585,1.88231080047069,-0.107681637419817,-0.00626324733854461,-0.144385489192821,0.00088239451623517,-0.00299885969569744,0
+1,3.87956616310254,4.31276421460554,-0.129963535661731,-0.0640782960295875,-0.0324909886960640,0.0428280701443882,0.0329254937199428,0
+1,-2.88187391546093,-3.16731558128991,-0.136390769151814,-0.155408895734766,0.105626409419800,-0.0918345772196075,0.197828194781600,1
+1,-2.65024496288248,-1.81147577507541,-0.145438998990911,0.0691687502404964,0.0749439097959056,-0.0674149410216342,0.123896965825847,1
+1,-1.37426198993006,-2.08894064826135,-0.153236566384176,0.0213513951854753,-0.134553043562400,0.00287304090325258,0.0122158739075685,1
+1,1.65698424179346,2.49004336804714,-0.153862461770005,0.105220938080375,-0.0946233303225818,-0.122426312548592,-0.00538234276442917,0
+1,2.93315586503758,2.75229115279104,-0.168877592929163,-0.0349207806558679,0.0189964813847077,0.202397029441612,0.0426299706123943,0
+1,-3.84306960373604,-2.35606387141237,-0.179511886850707,-0.0916819865200809,0.0265829433229566,0.101658708455140,-0.0855390303406673,1
+1,2.28101644492271,1.37963780647481,-0.180898801743387,-0.0789829066843624,-0.0779025366072777,0.0442621459868237,-0.136195159617836,0
+1,1.70008372335953,2.71018350574622,-0.188985514267118,-0.195856534813112,-0.106263419324547,-0.0311178988395261,-0.121173036989233,0
+1,-2.05613043162767,-1.73770126734937,0.00630625444849072,-0.134595964087825,0.0708994966210059,0.0739139562742148,-0.00416084523004362,1
+1,2.39375626983328,3.2468518382106,0.00951905535238045,-0.140380515724865,0.0630970962358967,0.00183192220061040,-0.0773483294293499,0
+1,4.26863682432937,3.49421800345979,0.0109175198048448,-0.109995560295421,-0.111585866731122,0.154763193427948,-0.0186987535307691,0
+1,1.54495296452702,3.17243560853872,0.0117478311845783,0.115838636637105,-0.1715332868224,0.0927292648278796,-0.0885962242970987,0
+1,2.16883227993245,1.63879588167162,0.0158863105366749,-0.00488771308802354,0.0280782748001184,0.131946735985038,0.066416828384239,0
+1,1.86427271422921,3.32026821853873,0.0162473257475520,0.0355005599857545,-0.0988825269654524,0.0527023072810735,0.100841323212596,0
+1,-3.03828333997027,-1.43214405751321,0.0247204684728272,0.146197859364444,0.0141171187314724,-0.201738256450160,0.044002672456105,1
+1,2.08595761680696,0.225336429607513,0.0335964287149376,0.0576493862055925,0.121452048491972,0.0640240734436852,0.224720096669846,0
+1,-1.85256114614442,-2.22817393781734,0.0346230650580488,0.160185441442375,0.0114059982858295,0.00496408500928602,-0.094156048483371,1
+1,2.33572915427688,1.03334367238243,0.0357824515834720,-0.172284120406131,0.0329286256184980,-0.101030665525296,-0.00238851979619332,0
+1,-2.00334039609229,-2.98875026257892,0.0375804284421083,0.142856636546252,-0.0862220203147005,-0.0441603903572752,0.0147126239348866,1
+1,2.38346139581192,1.21051372282823,0.0405425233313353,-0.145245065311593,-0.0216697981922324,-0.0128934036902430,-0.0325085994141851,0
+1,-1.15629168023471,-1.37784639006639,0.0429948703549178,-0.00491267793152886,0.0263522850749959,-0.0442602193050815,0.0582704866256344,1
+1,2.13230915550664,1.32833684701498,0.0434112538719301,-0.0296522957829338,0.00247091583877657,-0.123872403365319,-0.136549696313901,0
+1,-1.88291252343724,-1.99980946454726,0.0472833199907535,-0.0365284873908706,-0.0209054390489622,-0.0891896486647233,0.0542966824787834,1
+1,-1.34787394136153,-2.57763619051754,0.0493154843443071,0.0384664637019124,-0.00780509859650452,-0.118550134827935,0.00573215142098708,1
+1,-1.81748193199251,-2.72113041015796,0.0551479875680516,-0.255723061179778,-0.217672946803948,0.145106553357089,0.0632886151091758,1
+1,-3.13049595715861,-0.0285946551309455,0.0724437318718333,-0.0360911974267016,-0.121364676014540,0.038351368519738,-0.0125375424386282,1
+1,-2.3836883021805,-1.40162632998805,0.0746620557343183,0.069222624188286,0.04657285528431,0.0932835769596473,0.00836816351062604,1
+1,-2.43800450243598,-0.965440038635416,0.0763675021411913,-0.122575769653323,0.045866930905471,-0.0493852614669876,0.128116802512532,1
+1,1.09024638837653,2.21814920469686,0.0769910502309598,-0.270152593833931,-0.252735856082821,0.0661674666715274,-0.000429289775969046,0
+1,3.17642151475607,1.18015379683312,0.0776648965451875,-0.117234850817615,0.0759455286430382,0.119280079276134,0.117056969569811,0
+1,-3.5501372839931,-4.02435741321994,0.0833451415432366,-0.0185864612285970,0.0553371588028254,0.0269699189958747,-0.0930023774668385,1
+1,-2.85922019599943,-2.07644295605507,0.0903467736346066,0.124804691516462,0.0673015037344841,0.0234043567104492,0.0866115903248345,1
+1,0.513249476607372,5.0165612245778,0.0934321220365115,-0.0387550539552360,0.070129320868753,0.0635055975927393,-0.00773489793089484,0
+1,1.30094323285406,2.74698316868320,0.094239413405751,-0.105600040230387,-0.0134676903839459,0.00834379403909127,0.0978349326557826,0
+1,1.62511731278249,3.01296963021698,0.104352029985773,-0.0065839083200722,0.068460830526483,-0.1202220553,0.121998460927858,0
+1,1.82917662184333,2.89388269168932,0.110781239485760,-0.262387884050666,-0.00517657837760664,-0.0224028641246511,-0.108606003593092,0
+1,-3.17279743572930,-2.86698187406046,0.110873139279243,-0.093614374710967,0.0925974010859032,-0.00747619041107016,-0.066394213442664,1
+1,-3.20104938765970,-1.68043245593876,0.123227179211642,-0.00179275501686146,-0.175893752209014,-0.0835732816974749,0.0560957582079696,1
+1,-1.89923900052239,-2.92427973445236,0.147975477003611,0.00819675018680998,0.00470753628896422,-0.0122227288860826,0.209903875101594,1
+1,0.148491843864120,-1.54734877494689,0.162479731968606,0.112962938668545,-0.0100535803565242,0.0422099301034027,0.0752974779385111,1

Added: mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java?rev=978948&view=auto
==============================================================================
--- mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java (added)
+++ mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/LogisticModelParametersTest.java Sat Jul 24 23:38:28 2010
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.StringReader;
+import java.io.StringWriter;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests JSON persistence (saveTo / loadFrom) of LogisticModelParameters and its CSV target-category mapping. */
+public class LogisticModelParametersTest {
+  @Test
+  public void testSaveTo() throws IOException {
+    LogisticModelParameters lmp = new LogisticModelParameters();
+    lmp.setTargetVariable("target");
+    lmp.setMaxTargetCategories(3);
+    // Declare types for predictor columns x, y and z, and enable the bias term.
+    lmp.setTypeMap(ImmutableList.of("x", "y", "z"), ImmutableList.of("n", "word", "numeric"));
+    lmp.setUseBias(true);
+    // Learning parameters; both values are expected verbatim in the saved JSON.
+    lmp.setLambda(123.4);
+    lmp.setLearningRate(5.2);
+    // Serialized as "numFeatures".
+    lmp.setNumFeatures(214);
+    // Header: "target" is the second column; q and r are not in the type map.
+    lmp.getCsvRecordFactory().firstLine("x,target,q,z,r,y");
+    Vector v = new DenseVector(20);
+    assertEquals(0, lmp.getCsvRecordFactory().processLine("3,t_1,foo,5,r,cat", v));
+    assertEquals(1, lmp.getCsvRecordFactory().processLine("3,t_2,foo,5,r,dog", v));
+    assertEquals(2, lmp.getCsvRecordFactory().processLine("3,t_3,foo,5,r,pig", v));
+    assertEquals(2, lmp.getCsvRecordFactory().processLine("3,t_4,foo,5,r,pig", v));
+    // The 3-category cap held: t_4 was mapped to index 2 instead of becoming a fourth category.
+    assertEquals(3, lmp.getMaxTargetCategories());
+    assertEquals("[t_1, t_2, t_3]", lmp.getCsvRecordFactory().getTargetCategories().toString());
+    // Exact-text comparison. NOTE(review): depends on typeMap iteration order and the writer's line wrapping -- confirm both are deterministic.
+    StringWriter s = new StringWriter();
+    lmp.saveTo(s);
+    s.close();
+    assertEquals("{\"targetVariable\":\"target\",\"typeMap\":{\"z\":\"numeric\",\"y\":\"word\",\"x\":\"n\"},\n" +
+            "  \"numFeatures\":214,\"useBias\":true,\"maxTargetCategories\":3,\n" +
+            "  \"targetCategories\":[\"t_1\",\"t_2\",\"t_3\"],\"lambda\":123.4,\"learningRate\":5.2}", s.toString().trim());
+  }
+  // Same setup as testSaveTo, plus a regression whose coefficients must appear in the saved "lr" section.
+  @Test
+  public void testSaveWithRegression() throws IOException {
+    LogisticModelParameters lmp = new LogisticModelParameters();
+    lmp.setTargetVariable("target");
+    lmp.setMaxTargetCategories(3);
+
+    lmp.setTypeMap(ImmutableList.of("x", "y", "z"), ImmutableList.of("n", "word", "numeric"));
+    lmp.setUseBias(true);
+
+    lmp.setLambda(123.4);
+    lmp.setLearningRate(5.2);
+
+    lmp.setNumFeatures(214);
+    // Edits made through getBeta() after createRegression() are expected to show up in the saved JSON.
+    OnlineLogisticRegression lr = lmp.createRegression();
+    lr.getBeta().set(0, 4, 5.0);
+    lr.getBeta().set(1, 3, 7.0);
+    lmp.getCsvRecordFactory().firstLine("x,target,q,z,r,y");
+    Vector v = new DenseVector(20);
+    assertEquals(0, lmp.getCsvRecordFactory().processLine("3,t_1,foo,5,r,cat", v));
+    assertEquals(1, lmp.getCsvRecordFactory().processLine("3,t_2,foo,5,r,dog", v));
+    assertEquals(2, lmp.getCsvRecordFactory().processLine("3,t_3,foo,5,r,pig", v));
+    assertEquals(2, lmp.getCsvRecordFactory().processLine("3,t_4,foo,5,r,pig", v));
+
+    assertEquals(3, lmp.getMaxTargetCategories());
+    assertEquals("[t_1, t_2, t_3]", lmp.getCsvRecordFactory().getTargetCategories().toString());
+    // Expected "lr": a 2 x 214 beta (numCategories 3, numFeatures 214) holding the two edits; mu_0 and lambda echo the configured values.
+    StringWriter s = new StringWriter();
+    lmp.saveTo(s);
+    s.close();
+    assertEquals("{\"targetVariable\":\"target\",\"typeMap\":{\"z\":\"numeric\",\"y\":\"word\",\"x\":\"n\"},\n" +
+            "  \"numFeatures\":214,\"useBias\":true,\"maxTargetCategories\":3,\n" +
+            "  \"targetCategories\":[\"t_1\",\"t_2\",\"t_3\"],\"lambda\":123.4,\"learningRate\":5.2,\n" +
+            "  \"lr\":{\"beta\":{\"rows\":2,\"cols\":214,\"data\":[[0.0,0.0,0.0,0.0,5.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0],[\n" +
+            "          0.0,0.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+            "          0.0,0.0,0.0,0.0,0.0,0.0]]},\"numCategories\":3,\"step\":1,\"mu_0\":5.2,\n" +
+            "    \"decayFactor\":0.999,\"stepOffset\":10,\"decayExponent\":-0.5,\"lambda\":\n" +
+            "    123.4,\"sealed\":true}}", s.toString().trim());
+  }
+  // Loads the JSON expected by testSaveWithRegression and checks that parameters, coefficients and target categories are restored.
+  @Test
+  public void testLoadFrom() {
+    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new StringReader(
+            "{\"targetVariable\":\"target\",\"typeMap\":{\"z\":\"numeric\",\"y\":\"word\",\"x\":\"n\"},\n" +
+                    "  \"numFeatures\":214,\"useBias\":true,\"maxTargetCategories\":3,\n" +
+                    "  \"targetCategories\":[\"t_1\",\"t_2\",\"t_3\"],\"lambda\":123.4,\"learningRate\":5.2,\n" +
+                    "  \"lr\":{\"beta\":{\"rows\":2,\"cols\":214,\"data\":[[0.0,0.0,0.0,0.0,5.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0],[\n" +
+                    "          0.0,0.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,\n" +
+                    "          0.0,0.0,0.0,0.0,0.0,0.0]]},\"numCategories\":3,\"step\":1,\"mu_0\":5.2,\n" +
+                    "    \"decayFactor\":0.999,\"stepOffset\":10,\"decayExponent\":-0.5,\"lambda\":\n" +
+                    "    123.4,\"sealed\":true}}"));
+    // NOTE(review): createRegression() is called repeatedly; these checks assume each call reflects the loaded coefficients.
+    assertEquals(5.0, lmp.createRegression().getBeta().get(0, 4), 0);
+    assertEquals(7.0, lmp.createRegression().getBeta().get(1, 3), 0);
+
+    assertEquals(123.4, lmp.getLambda(), 0.0);
+    assertEquals(true, lmp.useBias());
+    assertEquals(5.2, lmp.getLearningRate(), 0.0);
+    assertEquals(214, lmp.getNumFeatures());
+    // Target categories come from the JSON, so t_3 maps to index 2 even though it is processed first.
+    lmp.getCsvRecordFactory().firstLine("x,target,q,z,r,y");
+    Vector v = new DenseVector(20);
+    assertEquals(2, lmp.getCsvRecordFactory().processLine("3,t_3,foo,5,r,pig", v));
+    assertEquals(1, lmp.getCsvRecordFactory().processLine("3,t_2,foo,5,r,dog", v));
+    assertEquals(2, lmp.getCsvRecordFactory().processLine("3,t_4,foo,5,r,pig", v));
+    assertEquals(0, lmp.getCsvRecordFactory().processLine("3,t_1,foo,5,r,cat", v));
+
+    assertEquals(3, lmp.getMaxTargetCategories());
+    assertEquals("[t_1, t_2, t_3]", lmp.getCsvRecordFactory().getTargetCategories().toString());
+
+    assertEquals(214, lmp.createRegression().getBeta().numCols());
+    assertEquals(2, lmp.createRegression().getBeta().numRows());
+  }
+}

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/NegativeBinomial.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/NegativeBinomial.java?rev=978948&r1=978947&r2=978948&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/NegativeBinomial.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/NegativeBinomial.java Sat Jul 24 23:38:28 2010
@@ -50,7 +50,8 @@ public class NegativeBinomial extends Ab
    */
   public NegativeBinomial(int n, double p, RandomEngine randomGenerator) {
     setRandomGenerator(randomGenerator);
-    setNandP(n, p);
+    this.n = n; // fields assigned directly; this change replaces the former setNandP(n, p) call
+    this.p = p; // NOTE(review): presumably avoids calling a possibly-overridable setter during construction -- confirm
     this.gamma = new Gamma(n, 1, randomGenerator);
     this.poisson = new Poisson(0.0, randomGenerator);
   }