You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/01 02:57:48 UTC
svn commit: r991409 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/classifier/sgd/
test/java/org/apache/mahout/classifier/sgd/
Author: tdunning
Date: Wed Sep 1 00:57:47 2010
New Revision: 991409
URL: http://svn.apache.org/viewvc?rev=991409&view=rev
Log:
MAHOUT-494 - Added serialization code based on GSON. Added tests for same.
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=991409&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java Wed Sep 1 00:57:47 2010
@@ -0,0 +1,351 @@
+package org.apache.mahout.classifier.sgd;
+
+import com.google.gson.*;
+import com.google.gson.reflect.TypeToken;
+import org.apache.mahout.ep.EvolutionaryProcess;
+import org.apache.mahout.ep.Mapping;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.stats.OnlineAuc;
+
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.Reader;
+import java.lang.reflect.Type;
+import java.util.List;
+
+/**
+ * Provides the ability to store SGD model-related objects as JSON.
+ */
+public class ModelSerializer {
+ // thread-local singleton json (de)serializer
+ private static final ThreadLocal<Gson> GSON;
+ static {
+ final GsonBuilder gb = new GsonBuilder();
+ gb.registerTypeAdapter(AdaptiveLogisticRegression.class, new AdaptiveLogisticRegressionTypeAdapter());
+ gb.registerTypeAdapter(Mapping.class, new MappingTypeAdapter());
+ gb.registerTypeAdapter(PriorFunction.class, new PriorTypeAdapter());
+ gb.registerTypeAdapter(CrossFoldLearner.class, new CrossFoldLearnerTypeAdapter());
+ gb.registerTypeAdapter(Vector.class, new VectorTypeAdapter());
+ gb.registerTypeAdapter(Matrix.class, new MatrixTypeAdapter());
+ gb.registerTypeAdapter(EvolutionaryProcess.class, new EvolutionaryProcessTypeAdapter());
+ gb.registerTypeAdapter(State.class, new StateTypeAdapter());
+ GSON = new ThreadLocal<Gson>() {
+ @Override
+ protected Gson initialValue() {
+ return gb.create();
+ }
+ };
+ }
+
+ // static class ... don't instantiate
+ private ModelSerializer() {
+ }
+
+ public static Gson gson() {
+ return GSON.get();
+ }
+
+ public static void writeJson(String path, AdaptiveLogisticRegression model) throws IOException {
+ FileWriter out = new FileWriter(path);
+ out.write(gson().toJson(model));
+ out.close();
+ }
+
+ /**
+ * Reads a model in JSON format.
+ *
+ * @param in Where to read the model from.
+ * @param clazz
+ * @return The LogisticModelParameters object that we read.
+ */
+ public static AdaptiveLogisticRegression loadJsonFrom(Reader in, Class<AdaptiveLogisticRegression> clazz) {
+ return gson().fromJson(in, clazz);
+ }
+
+ private static class MappingTypeAdapter implements JsonDeserializer<Mapping>, JsonSerializer<Mapping> {
+ @Override
+ public Mapping deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
+ JsonObject x = jsonElement.getAsJsonObject();
+ try {
+ return jsonDeserializationContext.deserialize(x.get("value"), (Class) Class.forName(x.get("class").getAsString()));
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException("Can't understand serialized data, found bad type: " + x.get("class").getAsString());
+ }
+ }
+
+ @Override
+ public JsonElement serialize(Mapping mapping, Type type, JsonSerializationContext jsonSerializationContext) {
+ JsonObject r = new JsonObject();
+ r.add("class", new JsonPrimitive(mapping.getClass().getName()));
+ r.add("value", jsonSerializationContext.serialize(mapping));
+ return r;
+ }
+ }
+
+ private static class PriorTypeAdapter implements JsonDeserializer<PriorFunction>, JsonSerializer<PriorFunction> {
+ @Override
+ public PriorFunction deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
+ JsonObject x = jsonElement.getAsJsonObject();
+ try {
+ return jsonDeserializationContext.deserialize(x.get("value"), (Class) Class.forName(x.get("class").getAsString()));
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException("Can't understand serialized data, found bad type: " + x.get("class").getAsString());
+ }
+ }
+
+ @Override
+ public JsonElement serialize(PriorFunction priorFunction, Type type, JsonSerializationContext jsonSerializationContext) {
+ JsonObject r = new JsonObject();
+ r.add("class", new JsonPrimitive(priorFunction.getClass().getName()));
+ r.add("value", jsonSerializationContext.serialize(priorFunction));
+ return r;
+ }
+ }
+
+ private static class CrossFoldLearnerTypeAdapter implements JsonDeserializer<CrossFoldLearner> {
+ @Override
+ public CrossFoldLearner deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
+ CrossFoldLearner r = new CrossFoldLearner();
+ JsonObject x = jsonElement.getAsJsonObject();
+ r.setRecord(x.get("record").getAsInt());
+ r.setAuc(jsonDeserializationContext.<OnlineAuc>deserialize(x.get("auc"), OnlineAuc.class));
+ r.setLogLikelihood(x.get("logLikelihood").getAsDouble());
+
+ JsonArray models = x.get("models").getAsJsonArray();
+ for (JsonElement model : models) {
+ r.addModel(jsonDeserializationContext.<OnlineLogisticRegression>deserialize(model, OnlineLogisticRegression.class));
+ }
+
+ r.setParameters(asArray(x, "parameters"));
+ r.setNumFeatures(x.get("numFeatures").getAsInt());
+ r.setPrior(jsonDeserializationContext.<PriorFunction>deserialize(x.get("prior"), PriorFunction.class));
+ return r;
+ }
+ }
+
+ /**
+ * Tells GSON how to (de)serialize a Mahout matrix. We assume on deserialization that the matrix
+ * is dense.
+ */
+ private static class MatrixTypeAdapter
+ implements JsonDeserializer<Matrix>, JsonSerializer<Matrix>, InstanceCreator<Matrix> {
+ @Override
+ public JsonElement serialize(Matrix m, Type type, JsonSerializationContext jsonSerializationContext) {
+ JsonObject r = new JsonObject();
+ r.add("rows", new JsonPrimitive(m.numRows()));
+ r.add("cols", new JsonPrimitive(m.numCols()));
+ JsonArray v = new JsonArray();
+ for (int row = 0; row < m.numRows(); row++) {
+ JsonArray rowData = new JsonArray();
+ for (int col = 0; col < m.numCols(); col++) {
+ rowData.add(new JsonPrimitive(m.get(row, col)));
+ }
+ v.add(rowData);
+ }
+ r.add("data", v);
+ return r;
+ }
+
+ @Override
+ public Matrix deserialize(JsonElement x, Type type, JsonDeserializationContext jsonDeserializationContext) {
+ JsonObject data = x.getAsJsonObject();
+ Matrix r = new DenseMatrix(data.get("rows").getAsInt(), data.get("cols").getAsInt());
+ int i = 0;
+ for (JsonElement row : data.get("data").getAsJsonArray()) {
+ int j = 0;
+ for (JsonElement element : row.getAsJsonArray()) {
+ r.set(i, j, element.getAsDouble());
+ j++;
+ }
+ i++;
+ }
+ return r;
+ }
+
+ @Override
+ public Matrix createInstance(Type type) {
+ return new DenseMatrix();
+ }
+ }
+
+
+ /**
+ * Tells GSON how to (de)serialize a Mahout matrix. We assume on deserialization that the
+ * matrix is dense.
+ */
+ private static class VectorTypeAdapter
+ implements JsonDeserializer<Vector>, JsonSerializer<Vector>, InstanceCreator<Vector> {
+ @Override
+ public JsonElement serialize(Vector m, Type type, JsonSerializationContext jsonSerializationContext) {
+ JsonObject r = new JsonObject();
+ JsonArray v = new JsonArray();
+ for (int i = 0; i < m.size(); i++) {
+ v.add(new JsonPrimitive(m.get(i)));
+ }
+ r.add("data", v);
+ return r;
+ }
+
+ @Override
+ public Vector deserialize(JsonElement x, Type type, JsonDeserializationContext jsonDeserializationContext) {
+ JsonArray data = x.getAsJsonObject().get("data").getAsJsonArray();
+ Vector r = new DenseVector(data.size());
+ int i = 0;
+ for (JsonElement v : data) {
+ r.set(i, v.getAsDouble());
+ i++;
+ }
+ return r;
+ }
+
+ @Override
+ public Vector createInstance(Type type) {
+ return new DenseVector();
+ }
+ }
+
+ private static class StateTypeAdapter implements JsonSerializer<State<AdaptiveLogisticRegression.Wrapper>>,
+ JsonDeserializer<State<AdaptiveLogisticRegression.Wrapper>> {
+ @Override
+ public State<AdaptiveLogisticRegression.Wrapper> deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
+ JsonObject v = (JsonObject) jsonElement;
+ double[] params = asArray(v, "params");
+ double omni = v.get("omni").getAsDouble();
+ State<AdaptiveLogisticRegression.Wrapper> r = new State<AdaptiveLogisticRegression.Wrapper>(params, omni);
+
+ double[] step = asArray(v, "step");
+ r.setId(v.get("id").getAsInt());
+ r.setStep(step);
+ r.setValue(v.get("value").getAsDouble());
+
+ Type mapListType = new TypeToken<List<Mapping>>() {
+ }.getType();
+ r.setMaps(jsonDeserializationContext.<List<Mapping>>deserialize(v.get("maps"), mapListType));
+
+ r.setPayload(jsonDeserializationContext.<AdaptiveLogisticRegression.Wrapper>deserialize(v.get("payload"), AdaptiveLogisticRegression.Wrapper.class));
+ return r;
+ }
+
+ @Override
+ public JsonElement serialize(State<AdaptiveLogisticRegression.Wrapper> state, Type type, JsonSerializationContext jsonSerializationContext) {
+ JsonObject r = new JsonObject();
+ r.add("id", new JsonPrimitive(state.getId()));
+ JsonArray v = new JsonArray();
+ for (double x : state.getParams()) {
+ v.add(new JsonPrimitive(x));
+ }
+ r.add("params", v);
+
+ v = new JsonArray();
+ for (Mapping mapping : state.getMaps()) {
+ v.add(jsonSerializationContext.serialize(mapping, Mapping.class));
+ }
+ r.add("maps", v);
+ r.add("omni", new JsonPrimitive(state.getOmni()));
+ r.add("step", jsonSerializationContext.serialize(state.getStep()));
+ r.add("value", new JsonPrimitive(state.getValue()));
+ r.add("payload", jsonSerializationContext.serialize(state.getPayload()));
+
+ return r;
+ }
+ }
+
+ private static class AdaptiveLogisticRegressionTypeAdapter implements JsonSerializer<AdaptiveLogisticRegression>,
+ JsonDeserializer<AdaptiveLogisticRegression> {
+
+ @Override
+ public AdaptiveLogisticRegression deserialize(JsonElement element, Type type, JsonDeserializationContext jdc) throws JsonParseException {
+ JsonObject x = element.getAsJsonObject();
+ AdaptiveLogisticRegression r = new AdaptiveLogisticRegression(x.get("numCategories").getAsInt(), x.get("numFeatures").getAsInt(), jdc.<PriorFunction>deserialize(x.get("prior"), PriorFunction.class));
+ Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {
+ }.getType();
+ r.setEvaluationInterval(x.get("evaluationInterval").getAsInt());
+ r.setRecord(x.get("record").getAsInt());
+
+ Type epType = new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() {
+ }.getType();
+ r.setEp(jdc.<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("ep"), epType));
+ r.setSeed(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("seed"), stateType));
+ r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"), stateType));
+
+ r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"), new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
+ }.getType()));
+ return r;
+ }
+
+ @Override
+ public JsonElement serialize(AdaptiveLogisticRegression x, Type type, JsonSerializationContext jsc) {
+ JsonObject r = new JsonObject();
+ r.add("ep", jsc.serialize(x.getEp(), new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() {
+ }.getType()));
+ r.add("buffer", jsc.serialize(x.getBuffer(), new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
+ }.getType()));
+ r.add("evaluationInterval", jsc.serialize(x.getEvaluationInterval()));
+ Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {
+ }.getType();
+ r.add("best", jsc.serialize(x.getBest(), stateType));
+ r.add("numFeatures", jsc.serialize(x.getNumFeatures()));
+ r.add("numCategories", jsc.serialize(x.getNumCategories()));
+ PriorFunction prior = x.getPrior();
+ JsonElement pf = jsc.serialize(prior, PriorFunction.class);
+ r.add("prior", pf);
+ r.add("record", jsc.serialize(x.getRecord()));
+ r.add("seed", jsc.serialize(x.getSeed(), stateType));
+ return r;
+ }
+ }
+
+ private static class EvolutionaryProcessTypeAdapter implements InstanceCreator<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>,
+ JsonDeserializer<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>,
+ JsonSerializer<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>> {
+ private static final Type STATE_TYPE = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {
+ }.getType();
+
+ @Override
+ public EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper> createInstance(Type type) {
+ return new EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>();
+ }
+
+ @Override
+ public EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper> deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
+ JsonObject x = (JsonObject) jsonElement;
+ int threadCount = x.get("threadCount").getAsInt();
+
+ EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper> r = new EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>();
+ r.setThreadCount(threadCount);
+
+ for (JsonElement element : x.get("population").getAsJsonArray()) {
+ State<AdaptiveLogisticRegression.Wrapper> state = jsonDeserializationContext.deserialize(element, STATE_TYPE);
+ r.add(state);
+ }
+ return r;
+ }
+
+ @Override
+ public JsonElement serialize(EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper> x, Type type, JsonSerializationContext jsc) {
+ JsonObject r = new JsonObject();
+ r.add("threadCount", new JsonPrimitive(x.getThreadCount()));
+ JsonArray v = new JsonArray();
+ for (State<AdaptiveLogisticRegression.Wrapper> state : x.getPopulation()) {
+ v.add(jsc.serialize(state, STATE_TYPE));
+ }
+ r.add("population", v);
+ return r;
+ }
+ }
+
+ public static double[] asArray(JsonObject v, String name) {
+ JsonArray x = v.get(name).getAsJsonArray();
+ double[] params = new double[x.size()];
+ int i = 0;
+ for (JsonElement element : x) {
+ params[i++] = element.getAsDouble();
+ }
+ return params;
+ }
+
+}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java?rev=991409&r1=991408&r2=991409&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java Wed Sep 1 00:57:47 2010
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.mahout.classifier.sgd;
import org.apache.mahout.common.RandomUtils;
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java?rev=991409&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java Wed Sep 1 00:57:47 2010
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import com.google.gson.Gson;
+import com.google.gson.reflect.TypeToken;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.ep.Mapping;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.UnaryFunction;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.junit.Test;
+
+import java.io.StringReader;
+import java.lang.reflect.Type;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Created by IntelliJ IDEA. User: tdunning Date: Aug 31, 2010 Time: 6:45:22 PM To change this
+ * template use File | Settings | File Templates.
+ */
+public class ModelSerializerTest {
+ @Test
+ public void testSoftLimitDeserialization() {
+ Mapping m = ModelSerializer.gson().fromJson(new StringReader("{\"min\":-18.420680743952367,\"max\":-2.3025850929940455,\"scale\":1.0}"), Mapping.SoftLimit.class);
+ assertTrue(m instanceof Mapping.SoftLimit);
+ assertEquals((-18.420680743952367 + -2.3025850929940455) / 2, m.apply(0), 1e-6);
+
+ String data = "{\"class\":\"org.apache.mahout.ep.Mapping$SoftLimit\",\"value\":{\"min\":-18.420680743952367,\"max\":-2.3025850929940455,\"scale\":1.0}}";
+ m = ModelSerializer.gson().fromJson(new StringReader(data), Mapping.class);
+ assertTrue(m instanceof Mapping.SoftLimit);
+ assertEquals((-18.420680743952367 + -2.3025850929940455) / 2, m.apply(0), 1e-6);
+ }
+
+ @Test
+ public void testMappingDeserialization() {
+ String data = "{\"class\":\"org.apache.mahout.ep.Mapping$LogLimit\",\"value\":{\"wrapped\":{\"class\":\"org.apache.mahout.ep.Mapping$SoftLimit\",\"value\":{\"min\":-18.420680743952367,\"max\":-2.3025850929940455,\"scale\":1.0}}}}";
+ Mapping m = ModelSerializer.gson().fromJson(new StringReader(data), Mapping.class);
+ assertTrue(m instanceof Mapping.LogLimit);
+ assertEquals(Math.sqrt(Math.exp(-18.420680743952367) * Math.exp(-2.3025850929940455)), m.apply(0), 1e-6);
+ }
+
+ @Test
+ public void onlineAucRoundtrip() {
+ RandomUtils.useTestSeed();
+ OnlineAuc auc1 = new OnlineAuc();
+ Random gen = new Random(2);
+ for (int i = 0; i < 10000; i++) {
+ auc1.addSample(0, gen.nextGaussian());
+ auc1.addSample(1, gen.nextGaussian() + 1);
+ }
+ assertEquals(0.76, auc1.auc(), 0.04);
+
+ Gson gson = ModelSerializer.gson();
+ String s = gson.toJson(auc1);
+
+ OnlineAuc auc2 = gson.fromJson(s, OnlineAuc.class);
+
+ assertEquals(auc1.auc(), auc2.auc(), 0);
+
+ for (int i = 0; i < 1000; i++) {
+ auc1.addSample(0, gen.nextGaussian());
+ auc1.addSample(1, gen.nextGaussian() + 1);
+
+ auc2.addSample(0, gen.nextGaussian());
+ auc2.addSample(1, gen.nextGaussian() + 1);
+ }
+
+ assertEquals(auc1.auc(), auc2.auc(), 0.01);
+ }
+
+ @Test
+ public void onlineLogisticRegressionRoundTrip() {
+ OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
+ train(olr, 100);
+ Gson gson = ModelSerializer.gson();
+ String s = gson.toJson(olr);
+ OnlineLogisticRegression olr2 = gson.fromJson(new StringReader(s), OnlineLogisticRegression.class);
+ assertEquals(0, olr.getBeta().minus(olr2.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1e-6);
+
+ train(olr, 100);
+ train(olr2, 100);
+
+ assertEquals(0, olr.getBeta().minus(olr2.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1e-6);
+ }
+
+ @Test
+ public void crossFoldLearnerRoundTrip() {
+ CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1());
+ train(learner, 100);
+ Gson gson = ModelSerializer.gson();
+ String s = gson.toJson(learner);
+ CrossFoldLearner olr2 = gson.fromJson(new StringReader(s), CrossFoldLearner.class);
+ double auc1 = learner.auc();
+ assertTrue(auc1 > 0.85);
+ assertEquals(auc1, olr2.auc(), 1e-6);
+
+ train(learner, 100);
+ train(olr2, 100);
+
+ assertEquals(learner.auc(), olr2.auc(), 0.02);
+ double auc2 = learner.auc();
+ assertTrue(auc2 > auc1);
+ }
+
+ @Test
+ public void adaptiveLogisticRegressionRoundTrip() {
+ AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1());
+ learner.setInterval(200);
+ train(learner, 1000);
+ Gson gson = ModelSerializer.gson();
+ String s = gson.toJson(learner);
+ AdaptiveLogisticRegression olr2 = gson.fromJson(new StringReader(s), AdaptiveLogisticRegression.class);
+ double auc1 = learner.auc();
+ assertTrue(auc1 > 0.85);
+ assertEquals(auc1, olr2.auc(), 1e-6);
+
+ train(learner, 1000);
+ train(olr2, 1000);
+
+ assertEquals(learner.auc(), olr2.auc(), 0.02);
+ double auc2 = learner.auc();
+ assertTrue(auc2 > auc1);
+ }
+
+ @Test
+ public void trainingExampleList() {
+ Random gen = new Random(1);
+ List<AdaptiveLogisticRegression.TrainingExample> x1 = Lists.newArrayList();
+ for (int i = 0; i < 10; i++) {
+ AdaptiveLogisticRegression.TrainingExample t = new AdaptiveLogisticRegression.TrainingExample(i, i % 3, randomVector(gen, 5));
+ x1.add(t);
+ }
+
+ Gson gson = ModelSerializer.gson();
+ Type listType = new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
+ }.getType();
+ String s = gson.toJson(x1, listType);
+
+ List<AdaptiveLogisticRegression.TrainingExample> x2 = gson.fromJson(new StringReader(s), listType);
+
+ assertEquals(x1.size(), x2.size());
+ Iterator<AdaptiveLogisticRegression.TrainingExample> it = x2.iterator();
+ for (AdaptiveLogisticRegression.TrainingExample example : x1) {
+ AdaptiveLogisticRegression.TrainingExample example2 = it.next();
+ assertEquals(example.getKey(), example2.getKey());
+ assertEquals(0, example.getInstance().minus(example2.getInstance()).maxValue(), 1e-6);
+ assertEquals(example.getActual(), example2.getActual());
+ }
+ }
+
+ private void train(OnlineLearner olr, int n) {
+ Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5});
+ final Random gen = new Random(1);
+ for (int i = 0; i < n; i++) {
+ Vector x = randomVector(gen, 5);
+
+ int target = gen.nextDouble() < beta.dot(x) ? 1 : 0;
+ olr.train(target, x);
+ }
+ }
+
+ private Vector randomVector(final Random gen, int n) {
+ Vector x = new DenseVector(n);
+ x.assign(new UnaryFunction() {
+ @Override
+ public double apply(double v) {
+ return gen.nextGaussian();
+ }
+ });
+ return x;
+ }
+}