You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2014/01/08 21:39:09 UTC
svn commit: r1556629 - in /opennlp/addons/mahout-addon: ./ src/ src/main/
src/main/java/ src/main/java/opennlp/ src/main/java/opennlp/addons/
src/main/java/opennlp/addons/mahout/
Author: joern
Date: Wed Jan 8 20:39:08 2014
New Revision: 1556629
URL: http://svn.apache.org/r1556629
Log:
OPENNLP-574 Initial work to integrate Mahout's Logistic Regression Classifiers
Added:
opennlp/addons/mahout-addon/
opennlp/addons/mahout-addon/pom.xml
opennlp/addons/mahout-addon/src/
opennlp/addons/mahout-addon/src/main/
opennlp/addons/mahout-addon/src/main/java/
opennlp/addons/mahout-addon/src/main/java/SimpleTest.java
opennlp/addons/mahout-addon/src/main/java/opennlp/
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java
Added: opennlp/addons/mahout-addon/pom.xml
URL: http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/pom.xml?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/pom.xml (added)
+++ opennlp/addons/mahout-addon/pom.xml Wed Jan 8 20:39:08 2014
@@ -0,0 +1,86 @@
+<?xml version="1.0" encoding="UTF-8"?>
+
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <!-- Parent is the OpenNLP POM, looked up through the relative path given below. -->
+ <parent>
+ <groupId>org.apache.opennlp</groupId>
+ <artifactId>opennlp</artifactId>
+ <version>1.6.0-SNAPSHOT</version>
+ <relativePath>../opennlp/pom.xml</relativePath>
+ </parent>
+ <!-- Coordinates of this add-on; no groupId or version is declared here, so Maven takes them from the parent. -->
+ <artifactId>mahout-addon</artifactId>
+ <packaging>jar</packaging>
+ <name>Apache OpenNLP Mahout Addon</name>
+ <!-- Libraries the add-on sources are compiled against. -->
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.opennlp</groupId>
+ <artifactId>opennlp-tools</artifactId>
+ <version>1.6.0-SNAPSHOT</version>
+ </dependency>
+ <!-- Mahout core, pinned to release 0.8. -->
+ <dependency>
+ <groupId>org.apache.mahout</groupId>
+ <artifactId>mahout-core</artifactId>
+ <version>0.8</version>
+ </dependency>
+ <!-- Test-scoped JUnit; no version is declared here (presumably managed by the parent POM, to be confirmed). -->
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <!-- Build: copy dependency jars during the package phase; Surefire is configured with skipTests=true and a 512 MB heap argument. -->
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-dependency-plugin</artifactId>
+ <version>2.1</version>
+ <executions>
+ <execution>
+ <id>copy-dependencies</id>
+ <phase>package</phase>
+ <goals>
+ <goal>copy-dependencies</goal>
+ </goals>
+ <configuration>
+ <excludeScope>provided</excludeScope>
+ <stripVersion>true</stripVersion>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <configuration>
+ <skipTests>true</skipTests>
+ <argLine>-Xmx512m</argLine>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
Added: opennlp/addons/mahout-addon/src/main/java/SimpleTest.java
URL: http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/SimpleTest.java?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/src/main/java/SimpleTest.java (added)
+++ opennlp/addons/mahout-addon/src/main/java/SimpleTest.java Wed Jan 8 20:39:08 2014
@@ -0,0 +1,51 @@
+import org.apache.mahout.classifier.sgd.PassiveAggressive;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+
+/**
+ * Stand-alone demo that trains Mahout's {@link PassiveAggressive} learner on two
+ * hand-built feature vectors and prints the scores it assigns to a third vector.
+ * <p>
+ * The vectors are filled by index here. The usual Mahout approach is to add
+ * features to a vector through hashed feature encoders instead; in that case the
+ * vector must be large enough to avoid feature collisions.
+ */
+public class SimpleTest {
+
+  /** Number of categories the learner distinguishes. */
+  private static final int NUM_CATEGORIES = 2;
+
+  /** Cardinality of every feature vector used by this demo. */
+  private static final int NUM_FEATURES = 3;
+
+  /**
+   * Trains the learner with one example per category, classifies a vector equal
+   * to the first training example and prints the resulting score vector.
+   *
+   * @param args ignored
+   */
+  public static void main(String[] args) {
+
+    // NOTE: Looks like we need to store the cardinality of the vector in the model ?!
+
+    // Prepare data in vector format ...
+    Vector firstExample = vectorOf(1, 0, 1);
+    Vector secondExample = vectorOf(0, 1, 1);
+
+    // do the training
+    PassiveAggressive pa = new PassiveAggressive(NUM_CATEGORIES, NUM_FEATURES);
+    pa.train(0, firstExample);
+    pa.train(1, secondExample);
+
+    // The query has the same cardinality as the learner's feature space.
+    Vector result = pa.classifyFull(vectorOf(1, 0, 1));
+
+    System.out.println(result);
+  }
+
+  /**
+   * Builds a sparse vector whose cardinality equals the number of given values
+   * and whose i-th element is the i-th value.
+   *
+   * @param values the element values, in index order
+   * @return a new vector holding the values
+   */
+  private static Vector vectorOf(double... values) {
+    Vector vector = new RandomAccessSparseVector(values.length);
+    for (int i = 0; i < values.length; i++) {
+      vector.set(i, values[i]);
+    }
+    return vector;
+  }
+}
Added: opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java (added)
+++ opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java Wed Jan 8 20:39:08 2014
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
+import org.apache.mahout.classifier.sgd.L1;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Base class for event trainers that feed indexed OpenNLP training events to a
+ * Mahout {@code OnlineLearner}.
+ * <p>
+ * It reads the number of training passes from the training parameters and
+ * offers helpers to run one pass over the data and to build the predicate
+ * label to index map needed by the resulting model.
+ */
+abstract class AbstractOnlineLearnerTrainer extends AbstractEventTrainer {
+
+  /** Training parameter holding the number of passes over the data. */
+  private static final String ITERATIONS_KEY = "Iterations";
+
+  /** Number of passes used when the parameter is absent. */
+  private static final int DEFAULT_ITERATIONS = 20;
+
+  /** Number of passes subclasses should make over the training events. */
+  protected final int iterations;
+
+  /**
+   * Initializes the trainer and reads the {@code "Iterations"} parameter.
+   *
+   * @param trainParams the training parameters; {@code "Iterations"} is optional
+   *     and defaults to 20
+   * @param reportMap passed through to the super class
+   * @throws NumberFormatException if {@code "Iterations"} is set but is not an integer
+   */
+  public AbstractOnlineLearnerTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+
+    // TODO: Extract parameters here, used by all implementations, e.g. learningRate
+
+    String iterationsValue = trainParams.get(ITERATIONS_KEY);
+    iterations = iterationsValue != null ? Integer.parseInt(iterationsValue) : DEFAULT_ITERATIONS;
+  }
+
+  /**
+   * Runs a single pass over all indexed events and trains the learner with each.
+   * <p>
+   * Every event becomes a sparse vector with one dimension per predicate label.
+   * Each predicate present in the event is set to the number of times the event
+   * was seen, as reported by the indexer.
+   *
+   * @param indexer supplies contexts, outcomes and event counts
+   * @param pa the learner to train
+   */
+  protected void trainOnlineLearner(DataIndexer indexer, org.apache.mahout.classifier.OnlineLearner pa) {
+    int cardinality = indexer.getPredLabels().length;
+    int[] outcomes = indexer.getOutcomeList();
+
+    // Fetch the indexer arrays once instead of on every loop iteration.
+    int[][] contexts = indexer.getContexts();
+    int[] timesSeen = indexer.getNumTimesEventsSeen();
+
+    for (int i = 0; i < contexts.length; i++) {
+      Vector vector = new RandomAccessSparseVector(cardinality);
+
+      for (int feature : contexts[i]) {
+        vector.set(feature, timesSeen[i]);
+      }
+
+      pa.train(outcomes[i], vector);
+    }
+  }
+
+  /**
+   * Builds the lookup from predicate label to its position in
+   * {@code indexer.getPredLabels()}.
+   *
+   * @param indexer supplies the predicate labels
+   * @return a new, mutable map from predicate label to index
+   */
+  protected Map<String, Integer> createPrepMap(DataIndexer indexer) {
+    String[] predLabels = indexer.getPredLabels();
+
+    // Presized so that filling the map does not trigger a rehash.
+    Map<String, Integer> predMap =
+        new HashMap<String, Integer>((int) (predLabels.length / 0.75f) + 1);
+
+    for (int i = 0; i < predLabels.length; i++) {
+      predMap.put(predLabels[i], i);
+    }
+
+    return predMap;
+  }
+
+  /**
+   * Always answers {@code true}.
+   */
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}
Added: opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java (added)
+++ opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java Wed Jan 8 20:39:08 2014
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.Map;
+
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
+import org.apache.mahout.classifier.sgd.L1;
+
+/**
+ * Event trainer backed by Mahout's {@link AdaptiveLogisticRegression}.
+ * <p>
+ * The learner uses an {@link L1} prior and is sized from the indexed outcome and
+ * predicate labels. After the configured number of passes, the best learner
+ * reported by Mahout is wrapped in a {@link VectorClassifierModel}.
+ */
+public class AdaptiveLogisticRegressionTrainer extends AbstractOnlineLearnerTrainer {
+
+  // TODO: Make these parameters configurable ...
+  // what are good values ?!
+  private static final int INTERVAL = 800;
+  private static final int AVERAGING_WINDOW = 500;
+
+  public AdaptiveLogisticRegressionTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Trains an adaptive logistic regression learner on the indexed events.
+   *
+   * @param indexer the indexed training events
+   * @return a model wrapping the best learner found
+   * @throws IllegalStateException if Mahout reports no best learner after training
+   */
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+
+    // TODO: Lets use the predMap here as well for encoding
+    int numberOfOutcomes = indexer.getOutcomeLabels().length;
+    int numberOfFeatures = indexer.getPredLabels().length;
+
+    AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(numberOfOutcomes,
+        numberOfFeatures, new L1());
+
+    pa.setInterval(INTERVAL);
+    pa.setAveragingWindow(AVERAGING_WINDOW);
+
+    for (int k = 0; k < iterations; k++) {
+      trainOnlineLearner(indexer, pa);
+
+      // What should be reported at the end of every iteration ?!
+      System.out.println("Iteration " + (k + 1));
+    }
+
+    pa.close();
+
+    // Fail with a clear message instead of a bare NullPointerException below.
+    if (pa.getBest() == null) {
+      throw new IllegalStateException(
+          "AdaptiveLogisticRegression did not report a best learner after training");
+    }
+
+    return new VectorClassifierModel(pa.getBest().getPayload().getLearner(),
+        indexer.getOutcomeLabels(), createPrepMap(indexer));
+  }
+
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}
Added: opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java (added)
+++ opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java Wed Jan 8 20:39:08 2014
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
+import org.apache.mahout.classifier.sgd.L1;
+import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
+import org.apache.mahout.classifier.sgd.PassiveAggressive;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Event trainer that currently fits Mahout's {@link AdaptiveLogisticRegression}
+ * with an {@link L1} prior and wraps the best learner found in a
+ * {@link VectorClassifierModel}.
+ */
+public class LogisticRegressionTrainer extends AbstractOnlineLearnerTrainer {
+
+  private static final int INTERVAL = 800;
+  private static final int AVERAGING_WINDOW = 500;
+
+  public LogisticRegressionTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Trains an adaptive logistic regression learner on the indexed events.
+   *
+   * @param indexer the indexed training events
+   * @return a model wrapping the best learner found
+   * @throws IllegalStateException if Mahout reports no best learner after training
+   */
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+
+    // TODO: Lets use the predMap here as well for encoding
+
+    int cardinality = indexer.getPredLabels().length;
+
+    AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(indexer.getOutcomeLabels().length,
+        cardinality, new L1());
+
+    pa.setInterval(INTERVAL);
+    pa.setAveragingWindow(AVERAGING_WINDOW);
+
+    // TODO: Should we do both ?! AdaptiveLogisticRegression ?!
+
+    for (int k = 0; k < iterations; k++) {
+      trainOnlineLearner(indexer, pa);
+
+      // What should be reported at the end of every iteration ?!
+      System.out.println("Iteration " + (k + 1));
+    }
+
+    pa.close();
+
+    // The predicate map comes from the shared helper of the super class.
+    Map<String, Integer> predMap = createPrepMap(indexer);
+
+    // Fail with a clear message instead of a bare NullPointerException below.
+    if (pa.getBest() == null) {
+      throw new IllegalStateException(
+          "AdaptiveLogisticRegression did not report a best learner after training");
+    }
+
+    return new VectorClassifierModel(pa.getBest().getPayload().getLearner(), indexer.getOutcomeLabels(), predMap);
+  }
+
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}
Added: opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java (added)
+++ opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java Wed Jan 8 20:39:08 2014
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.Map;
+
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
+import org.apache.mahout.classifier.sgd.L1;
+import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
+
+/**
+ * Event trainer that fits Mahout's {@link OnlineLogisticRegression} with an
+ * {@link L1} prior and exposes the trained learner as an OpenNLP model.
+ */
+public class OnlineLogisticRegressionTrainer extends AbstractOnlineLearnerTrainer {
+
+  public OnlineLogisticRegressionTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Trains the learner for the configured number of passes over the indexed
+   * events and wraps it in a {@link VectorClassifierModel}.
+   *
+   * @param indexer the indexed training events
+   * @return the model wrapping the trained learner
+   */
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+
+    // TODO: Lets use the predMap here as well for encoding
+    final int categoryCount = indexer.getOutcomeLabels().length;
+    final int featureCount = indexer.getPredLabels().length;
+
+    // TODO: Make these parameters configurable ...
+    final OnlineLogisticRegression learner =
+        new OnlineLogisticRegression(categoryCount, featureCount, new L1());
+
+    learner.alpha(1)
+        .stepOffset(250)
+        .decayExponent(0.9)
+        .lambda(3.0e-5)
+        .learningRate(3000);
+
+    int pass = 0;
+    while (pass < iterations) {
+      trainOnlineLearner(indexer, learner);
+      pass++;
+
+      // What should be reported at the end of every iteration ?!
+      System.out.println("Iteration " + pass);
+    }
+
+    learner.close();
+
+    final String[] labels = indexer.getOutcomeLabels();
+    final Map<String, Integer> predicateIndex = createPrepMap(indexer);
+    return new VectorClassifierModel(learner, labels, predicateIndex);
+  }
+
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}
Added: opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java (added)
+++ opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java Wed Jan 8 20:39:08 2014
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.Map;
+
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.PassiveAggressive;
+
+/**
+ * Event trainer that fits Mahout's {@link PassiveAggressive} classifier to the
+ * indexed training events and exposes it as an OpenNLP {@link MaxentModel}.
+ */
+public class PassiveAggressiveTrainer extends AbstractOnlineLearnerTrainer {
+
+  public PassiveAggressiveTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Trains the learner for the configured number of passes over the indexed
+   * events and wraps it in a {@link VectorClassifierModel}.
+   *
+   * @param indexer the indexed training events
+   * @return the model wrapping the trained learner
+   */
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+
+    // TODO: Lets use the predMap here as well for encoding
+    final int categoryCount = indexer.getOutcomeLabels().length;
+    final int featureCount = indexer.getPredLabels().length;
+    final PassiveAggressive learner = new PassiveAggressive(categoryCount, featureCount);
+
+    int pass = 0;
+    while (pass < iterations) {
+      trainOnlineLearner(indexer, learner);
+      pass++;
+
+      // What should be reported at the end of every iteration ?!
+      System.out.println("Iteration " + pass);
+    }
+
+    learner.close();
+
+    final String[] labels = indexer.getOutcomeLabels();
+    final Map<String, Integer> predicateIndex = createPrepMap(indexer);
+    return new VectorClassifierModel(learner, labels, predicateIndex);
+  }
+
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}
Added: opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java
URL: http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java (added)
+++ opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java Wed Jan 8 20:39:08 2014
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.addons.mahout;
+
+import java.util.Map;
+
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+// TODO: Would be nice to have an abstract maxent model impl ..
+
+/**
+ * Adapts a Mahout {@link AbstractVectorClassifier} to the OpenNLP
+ * {@link MaxentModel} interface.
+ * <p>
+ * A context is turned into a sparse vector with one dimension per known
+ * predicate. Predicates missing from the predicate map are ignored. The scores
+ * returned by {@code classifyFull} are handed back unchanged.
+ */
+public class VectorClassifierModel implements MaxentModel {
+
+  private final AbstractVectorClassifier classifier;
+  private final String[] outcomeLabels;
+  private final Map<String, Integer> predMap;
+
+  /**
+   * Creates a model around a trained classifier.
+   *
+   * @param pa the classifier used for evaluation
+   * @param outcomeLabels outcome names indexed by category id; must not be null
+   * @param predMap maps each predicate name to its vector index; must not be null
+   */
+  public VectorClassifierModel(AbstractVectorClassifier pa, String outcomeLabels[],
+      Map<String, Integer> predMap) {
+    this.classifier = pa;
+    // Defensive copies, so later changes by the caller cannot alter the model.
+    this.outcomeLabels = outcomeLabels.clone();
+    this.predMap = new java.util.HashMap<String, Integer>(predMap);
+  }
+
+  /**
+   * Scores a context in which every listed predicate counts once per occurrence.
+   *
+   * @param features the predicate names of the context
+   * @return a new array with one score per category
+   */
+  @Override
+  public double[] eval(String[] features) {
+    return eval(features, (float[]) null);
+  }
+
+  /**
+   * Scores a context and, when possible, stores the scores in the given array.
+   *
+   * @param context the predicate names of the context
+   * @param probs array to receive the scores; it is filled and returned only if
+   *     its length equals the number of scores, otherwise it is left untouched
+   * @return {@code probs} when it was filled, a new array otherwise
+   */
+  @Override
+  public double[] eval(String[] context, double[] probs) {
+    double[] scores = eval(context);
+
+    if (probs == null || probs.length != scores.length) {
+      return scores;
+    }
+
+    System.arraycopy(scores, 0, probs, 0, scores.length);
+    return probs;
+  }
+
+  /**
+   * Scores a context whose predicates carry individual values.
+   *
+   * @param context the predicate names of the context
+   * @param values the value of each predicate, parallel to {@code context};
+   *     {@code null} means every predicate has the value 1
+   * @return a new array with one score per category
+   * @throws IllegalArgumentException if {@code values} is not null and its
+   *     length differs from the length of {@code context}
+   */
+  @Override
+  public double[] eval(String[] context, float[] values) {
+    if (values != null && values.length != context.length) {
+      throw new IllegalArgumentException("Expected " + context.length
+          + " values, one per context predicate, but got " + values.length);
+    }
+
+    Vector vector = new RandomAccessSparseVector(predMap.size());
+
+    for (int i = 0; i < context.length; i++) {
+      Integer featureId = predMap.get(context[i]);
+
+      // Predicates unknown to the model are skipped.
+      if (featureId != null) {
+        double weight = values != null ? values[i] : 1;
+        vector.set(featureId, vector.get(featureId) + weight);
+      }
+    }
+
+    Vector resultVector = classifier.classifyFull(vector);
+
+    double[] outcomes = new double[classifier.numCategories()];
+
+    for (int i = 0; i < outcomes.length; i++) {
+      outcomes[i] = resultVector.get(i);
+    }
+
+    return outcomes;
+  }
+
+  /**
+   * Returns the label of the highest score; the first one wins on ties.
+   *
+   * @param ocs scores as returned by one of the {@code eval} methods
+   * @return the label at the index of the highest score
+   */
+  @Override
+  public String getBestOutcome(double[] ocs) {
+    int best = 0;
+    for (int i = 1; i < ocs.length; i++) {
+      if (ocs[i] > ocs[best]) {
+        best = i;
+      }
+    }
+    return outcomeLabels[best];
+  }
+
+  /**
+   * Formats every outcome label together with its score, for example
+   * {@code "a[0.2500]  b[0.7500]"}.
+   *
+   * @param outcomes scores as returned by one of the {@code eval} methods
+   * @return the labels paired with their scores, separated by two spaces
+   * @throws IllegalArgumentException if the number of scores differs from the
+   *     number of outcome labels
+   */
+  @Override
+  public String getAllOutcomes(double[] outcomes) {
+    if (outcomes.length != outcomeLabels.length) {
+      throw new IllegalArgumentException("Expected " + outcomeLabels.length
+          + " scores, one per outcome, but got " + outcomes.length);
+    }
+
+    StringBuilder text = new StringBuilder();
+
+    for (int i = 0; i < outcomes.length; i++) {
+      if (i > 0) {
+        text.append("  ");
+      }
+      // Locale.ROOT keeps the decimal separator independent of the platform.
+      text.append(outcomeLabels[i]).append('[')
+          .append(String.format(java.util.Locale.ROOT, "%.4f", outcomes[i]))
+          .append(']');
+    }
+
+    return text.toString();
+  }
+
+  /**
+   * @param i a category index
+   * @return the outcome label stored at that index
+   */
+  @Override
+  public String getOutcome(int i) {
+    return outcomeLabels[i];
+  }
+
+  /**
+   * @param outcome an outcome label
+   * @return the index of the label, or -1 if the model does not know it
+   */
+  @Override
+  public int getIndex(String outcome) {
+    for (int i = 0; i < outcomeLabels.length; i++) {
+      if (outcomeLabels[i].equals(outcome)) {
+        return i;
+      }
+    }
+
+    return -1;
+  }
+
+  /**
+   * @return the number of outcome labels of this model
+   */
+  @Override
+  public int getNumOutcomes() {
+    return outcomeLabels.length;
+  }
+}