You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by jo...@apache.org on 2012/10/02 16:52:35 UTC
svn commit: r1392944 [1/2] - in /opennlp/trunk/opennlp-maxent/src:
main/java/opennlp/maxent/io/ main/java/opennlp/maxent/quasinewton/
main/java/opennlp/model/ test/java/opennlp/maxent/quasinewton/
Author: joern
Date: Tue Oct 2 14:52:34 2012
New Revision: 1392944
URL: http://svn.apache.org/viewvc?rev=1392944&view=rev
Log:
OPENNLP-338 Added experimental support for L-BFGS training. Thanks to Hyosup Shim for providing a patch.
Added:
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelReader.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelWriter.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelReader.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelWriter.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelReader.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelWriter.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/ArrayMath.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/DifferentiableFunction.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/Function.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearch.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearchResult.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LogLikelihoodFunction.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNModel.java (with props)
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNTrainer.java (with props)
opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/
opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LineSearchTest.java (with props)
opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LogLikelihoodFunctionTest.java (with props)
opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QNTrainerTest.java (with props)
opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QuadraticFunction.java (with props)
opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QuadraticFunction02.java (with props)
Modified:
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/AbstractModel.java
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelReader.java
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelWriter.java
opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelReader.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelReader.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelReader.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelReader.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.io;
+
+import java.io.DataInputStream;
+
+import opennlp.model.BinaryFileDataReader;
+
+/**
+ * A reader for quasi-newton models stored in binary format.
+ */
+public class BinaryQNModelReader extends QNModelReader {
+
+  /**
+   * Builds a reader over a binary-encoded quasi-Newton model. The given stream
+   * is wrapped in a {@link BinaryFileDataReader}, which is handed to the
+   * {@link QNModelReader} base class for decoding.
+   *
+   * @param stream
+   *          the binary stream holding the serialized model
+   */
+  public BinaryQNModelReader(DataInputStream stream) {
+    super(new BinaryFileDataReader(stream));
+  }
+}
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelReader.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelWriter.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelWriter.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelWriter.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelWriter.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.io;
+
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.zip.GZIPOutputStream;
+
+import opennlp.model.AbstractModel;
+
+public class BinaryQNModelWriter extends QNModelWriter {
+ protected DataOutputStream output;
+
+ /**
+ * Constructor which takes a GISModel and a File and prepares itself to write
+ * the model to that file. Detects whether the file is gzipped or not based on
+ * whether the suffix contains ".gz".
+ *
+ * @param model
+ * The GISModel which is to be persisted.
+ * @param f
+ * The File in which the model is to be persisted.
+ */
+ public BinaryQNModelWriter(AbstractModel model, File f) throws IOException {
+
+ super(model);
+
+ if (f.getName().endsWith(".gz")) {
+ output = new DataOutputStream(new GZIPOutputStream(
+ new FileOutputStream(f)));
+ } else {
+ output = new DataOutputStream(new FileOutputStream(f));
+ }
+ }
+
+ /**
+ * Constructor which takes a GISModel and a DataOutputStream and prepares
+ * itself to write the model to that stream.
+ *
+ * @param model
+ * The GISModel which is to be persisted.
+ * @param dos
+ * The stream which will be used to persist the model.
+ */
+ public BinaryQNModelWriter(AbstractModel model, DataOutputStream dos) {
+ super(model);
+ output = dos;
+ }
+
+ public void writeUTF(String s) throws java.io.IOException {
+ output.writeUTF(s);
+ }
+
+ public void writeInt(int i) throws java.io.IOException {
+ output.writeInt(i);
+ }
+
+ public void writeDouble(double d) throws java.io.IOException {
+ output.writeDouble(d);
+ }
+
+ public void close() throws java.io.IOException {
+ output.flush();
+ output.close();
+ }
+}
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/BinaryQNModelWriter.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelReader.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelReader.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelReader.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelReader.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.io;
+
+import java.io.ObjectInputStream;
+
+import opennlp.model.ObjectDataReader;
+
+/**
+ * Reads a quasi-Newton model from an {@link ObjectInputStream}.
+ */
+public class ObjectQNModelReader extends QNModelReader {
+
+  /**
+   * Wraps {@code in} in an {@link ObjectDataReader} and hands it to the
+   * {@link QNModelReader} base class for decoding.
+   *
+   * @param in the object stream holding the serialized model
+   */
+  public ObjectQNModelReader(ObjectInputStream in) {
+    super(new ObjectDataReader(in));
+  }
+}
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelReader.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelWriter.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelWriter.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelWriter.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelWriter.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,58 @@
+package opennlp.maxent.io;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+
+import opennlp.model.AbstractModel;
+
+/**
+ * Persists a quasi-Newton model through an {@link ObjectOutputStream}.
+ */
+public class ObjectQNModelWriter extends QNModelWriter {
+
+  protected ObjectOutputStream output;
+
+  /**
+   * Prepares to write {@code model} to the given object stream.
+   *
+   * @param model the quasi-Newton model which is to be persisted
+   * @param oos the stream which will be used to persist the model; it is
+   *        closed by {@link #close()}
+   */
+  public ObjectQNModelWriter(AbstractModel model, ObjectOutputStream oos) {
+    super(model);
+    output = oos;
+  }
+
+  @Override
+  public void writeUTF(String s) throws IOException {
+    output.writeUTF(s);
+  }
+
+  @Override
+  public void writeInt(int i) throws IOException {
+    output.writeInt(i);
+  }
+
+  @Override
+  public void writeDouble(double d) throws IOException {
+    output.writeDouble(d);
+  }
+
+  /** Flushes and closes the underlying object stream. */
+  @Override
+  public void close() throws IOException {
+    output.flush();
+    output.close();
+  }
+}
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/ObjectQNModelWriter.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelReader.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelReader.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelReader.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelReader.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.io;
+
+import java.io.File;
+import java.io.IOException;
+
+import opennlp.maxent.quasinewton.QNModel;
+import opennlp.model.AbstractModel;
+import opennlp.model.AbstractModelReader;
+import opennlp.model.Context;
+import opennlp.model.DataReader;
+
+/**
+ * Reads a quasi-Newton model: a "QN" type tag, the predicate and outcome
+ * names, the per-context outcome/parameter arrays and a trailing flat
+ * parameter vector. Every array is preceded by its length. This is intended
+ * to mirror the order written by QNModelWriter#persist.
+ */
+public class QNModelReader extends AbstractModelReader {
+
+  public QNModelReader(DataReader dataReader) {
+    super(dataReader);
+  }
+
+  public QNModelReader(File file) throws IOException {
+    super(file);
+  }
+
+  /** Reads the type tag and warns on stdout if it is not "QN". */
+  @Override
+  public void checkModelType() throws IOException {
+    String modelType = readUTF();
+    if (!modelType.equals("QN"))
+      System.out.println("Error: attempting to load a " + modelType
+          + " model as a MAXENT_QN model." + " You should expect problems.");
+  }
+
+  @Override
+  public AbstractModel constructModel() throws IOException {
+    String[] predNames = getPredicates();
+    String[] outcomeNames = getOutcomes();
+    Context[] params = getParameters();
+    double[] parameters = getDoubleArrayParams();
+    return new QNModel(predNames, outcomeNames, params, parameters);
+  }
+
+  /** Reads a length-prefixed array of doubles. */
+  private double[] getDoubleArrayParams() throws IOException {
+    int numDouble = readCount("double");
+    double[] doubleArray = new double[numDouble];
+    for (int i = 0; i < numDouble; i++)
+      doubleArray[i] = readDouble();
+    return doubleArray;
+  }
+
+  /** Reads a length-prefixed array of ints. */
+  private int[] getIntArrayParams() throws IOException {
+    int numInt = readCount("int");
+    int[] intArray = new int[numInt];
+    for (int i = 0; i < numInt; i++)
+      intArray[i] = readInt();
+    return intArray;
+  }
+
+  /**
+   * Reads the per-context parameters: a context count followed, for each
+   * context, by its outcome indices and its parameter values.
+   */
+  protected Context[] getParameters() throws java.io.IOException {
+    int numContext = readCount("context");
+    Context[] params = new Context[numContext];
+
+    for (int i = 0; i < numContext; i++) {
+      int[] outcomePattern = getIntArrayParams();
+      double[] parameters = getDoubleArrayParams();
+      params[i] = new Context(outcomePattern, parameters);
+    }
+    return params;
+  }
+
+  /**
+   * Reads a length prefix. A negative value can only come from a corrupt or
+   * mismatched stream, so it is reported as an IOException instead of
+   * surfacing later as a NegativeArraySizeException.
+   */
+  private int readCount(String what) throws IOException {
+    int count = readInt();
+    if (count < 0) {
+      throw new IOException("Corrupt QN model: negative " + what + " count " + count);
+    }
+    return count;
+  }
+}
\ No newline at end of file
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelReader.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelWriter.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelWriter.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelWriter.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelWriter.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.io;
+
+import java.io.IOException;
+
+import opennlp.maxent.quasinewton.QNModel;
+import opennlp.model.AbstractModel;
+import opennlp.model.AbstractModelWriter;
+import opennlp.model.Context;
+import opennlp.model.IndexHashTable;
+
+/**
+ * Base class for writers that persist a {@link QNModel}. Subclasses provide
+ * the primitive write operations ({@code writeUTF}, {@code writeInt},
+ * {@code writeDouble}) and {@code close()}.
+ */
+public abstract class QNModelWriter extends AbstractModelWriter {
+  protected String[] outcomeNames;
+  protected String[] predNames;
+  protected Context[] params;
+  // NOTE(review): not read or written by this class; retained because it is protected.
+  protected double[] predParams;
+
+  protected IndexHashTable<String> pmap;
+  protected double[] parameters;
+
+  /**
+   * Captures the data structures of {@code model} for later persistence.
+   *
+   * @param model the model to persist; it is cast to {@link QNModel}, so other
+   *        model types fail with a ClassCastException
+   */
+  @SuppressWarnings("unchecked")
+  public QNModelWriter(AbstractModel model) {
+    Object[] data = model.getDataStructures();
+    params = (Context[]) data[0];
+    pmap = (IndexHashTable<String>) data[1];
+    outcomeNames = (String[]) data[2];
+
+    QNModel qnModel = (QNModel) model;
+    parameters = qnModel.getParameters();
+  }
+
+  /**
+   * Writes the model and then closes this writer. The output is: the "QN"
+   * type tag, the predicate names, the outcome names, each context's outcome
+   * indices and parameters, and finally the flat parameter vector. Every list
+   * is preceded by its length. The writer is closed even if writing fails, so
+   * the underlying stream is not leaked.
+   */
+  @Override
+  public void persist() throws IOException {
+    boolean written = false;
+    try {
+      writeModel();
+      written = true;
+    } finally {
+      if (written) {
+        close();
+      } else {
+        closeQuietly();
+      }
+    }
+  }
+
+  /** Emits the serialized model without closing the writer. */
+  private void writeModel() throws IOException {
+    // the type of model (QN)
+    writeUTF("QN");
+
+    // predicate names
+    predNames = new String[pmap.size()];
+    pmap.toArray(predNames);
+    writeInt(predNames.length);
+    for (String predName : predNames) {
+      writeUTF(predName);
+    }
+
+    // outcome names
+    writeInt(outcomeNames.length);
+    for (String outcomeName : outcomeNames) {
+      writeUTF(outcomeName);
+    }
+
+    // per-context outcome indices and parameters
+    writeInt(params.length);
+    for (Context context : params) {
+      int[] outcomes = context.getOutcomes();
+      writeInt(outcomes.length);
+      for (int outcome : outcomes) {
+        writeInt(outcome);
+      }
+      double[] contextParams = context.getParameters();
+      writeInt(contextParams.length);
+      for (double value : contextParams) {
+        writeDouble(value);
+      }
+    }
+
+    // flat parameter vector of the QN model
+    writeInt(parameters.length);
+    for (double value : parameters) {
+      writeDouble(value);
+    }
+  }
+
+  /** Closes the writer after a failed write without masking that failure. */
+  private void closeQuietly() {
+    try {
+      close();
+    } catch (IOException ignored) {
+      // The exception from the failed write is the one worth propagating.
+    }
+  }
+}
+
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/io/QNModelWriter.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/ArrayMath.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/ArrayMath.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/ArrayMath.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/ArrayMath.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+/**
+ * utility class for simple vector arithmetics.
+ */
+public class ArrayMath {
+
+  /**
+   * Computes the dot product of two vectors.
+   *
+   * @return the sum of element-wise products, or {@link Double#NaN} when either
+   *         vector is {@code null} or the lengths differ
+   */
+  public static double innerProduct(double[] vecA, double[] vecB) {
+    boolean compatible = vecA != null && vecB != null && vecA.length == vecB.length;
+    if (!compatible) {
+      return Double.NaN;
+    }
+
+    double sum = 0.0;
+    int idx = 0;
+    while (idx < vecA.length) {
+      sum += vecA[idx] * vecB[idx];
+      idx++;
+    }
+    return sum;
+  }
+
+  /**
+   * Returns a new array equal to {@code point + scale * vector}; the inputs are
+   * not modified.
+   *
+   * @return the moved point, or {@code null} when either array is {@code null}
+   *         or the lengths differ
+   */
+  public static double[] updatePoint(double[] point, double[] vector, double scale) {
+    if (point == null || vector == null || point.length != vector.length) {
+      return null;
+    }
+
+    double[] moved = new double[point.length];
+    for (int k = 0; k < moved.length; k++) {
+      moved[k] = point[k] + scale * vector[k];
+    }
+    return moved;
+  }
+}
+
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/ArrayMath.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/DifferentiableFunction.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/DifferentiableFunction.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/DifferentiableFunction.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/DifferentiableFunction.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+/**
+ * interface for a function that can be differentiated once.
+ */
+public interface DifferentiableFunction extends Function {
+
+  /**
+   * Returns the gradient of this function evaluated at {@code x}.
+   */
+  double[] gradientAt(double[] x);
+}
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/DifferentiableFunction.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/Function.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/Function.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/Function.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/Function.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+/**
+ * interface for a function.
+ */
+public interface Function {
+
+  /** Returns the value of this function at {@code x}. */
+  double valueAt(double[] x);
+
+  /** Returns the number of dimensions of this function's input. */
+  int getDomainDimension();
+}
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/Function.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearch.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearch.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearch.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearch.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+/**
+ * class that performs line search.
+ */
+public class LineSearch {
+  private static final double INITIAL_STEP_SIZE = 1.0;
+  private static final double MIN_STEP_SIZE = 1.0E-10;
+  // Sufficient-increase (Armijo) constant.
+  private static final double C1 = 0.0001;
+  // Curvature (weak Wolfe) constant.
+  private static final double C2 = 0.9;
+  // Growth factor applied to the step while no upper bracket is known.
+  private static final double TT = 16.0;
+
+  /**
+   * Same as {@link #doLineSearch(DifferentiableFunction, double[], LineSearchResult, boolean)}
+   * with {@code verbose == false}.
+   */
+  public static LineSearchResult doLineSearch(DifferentiableFunction function, double[] direction, LineSearchResult lsr) {
+    return doLineSearch(function, direction, lsr, false);
+  }
+
+  /**
+   * Searches along {@code direction}, starting from the "next" point stored in
+   * {@code lsr}, for a step size satisfying the weak Wolfe conditions for
+   * maximizing {@code function}.
+   *
+   * The search keeps a bracket [mu, upsilon] of step sizes. A step violating
+   * the sufficient-increase condition is too long and becomes upsilon. A step
+   * violating the curvature condition is too short and becomes mu. The next
+   * trial is the bracket midpoint, or a TT-fold larger step while no upper
+   * bound is known. If the bracket collapses, the returned step size is 0.0.
+   *
+   * NOTE(review): on that failure path the returned next point/value/gradient
+   * are those of the last rejected trial, as before; callers presumably check
+   * the step size.
+   *
+   * @param verbose when true, prints the start value, the value change and the
+   *        elapsed seconds
+   */
+  public static LineSearchResult doLineSearch(DifferentiableFunction function, double[] direction, LineSearchResult lsr, boolean verbose) {
+    int currFctEvalCount = lsr.getFctEvalCount();
+    double stepSize = INITIAL_STEP_SIZE;
+    double[] x = lsr.getNextPoint();
+    double valueAtX = lsr.getValueAtNext();
+    double[] gradAtX = lsr.getGradAtNext();
+    double[] nextPoint = null;
+    double[] gradAtNextPoint = null;
+    double valueAtNextPoint = 0.0;
+
+    // Directional derivative at the start point; invariant for the whole search.
+    double initialSlope = ArrayMath.innerProduct(direction, gradAtX);
+
+    double mu = 0;
+    double upsilon = Double.POSITIVE_INFINITY;
+
+    long startTime = System.currentTimeMillis();
+    while (true) {
+      nextPoint = ArrayMath.updatePoint(x, direction, stepSize);
+      valueAtNextPoint = function.valueAt(nextPoint);
+      currFctEvalCount++;
+      gradAtNextPoint = function.gradientAt(nextPoint);
+
+      if (!checkArmijoCond(valueAtX, valueAtNextPoint, initialSlope, stepSize)) {
+        upsilon = stepSize; // step too long
+      } else if (!checkCurvature(gradAtNextPoint, direction, initialSlope)) {
+        mu = stepSize; // step too short
+      } else {
+        break;
+      }
+
+      if (upsilon < Double.POSITIVE_INFINITY) {
+        // Bisect the bracket. Dividing the sum by TT could place the trial
+        // below mu, so acceptable steps inside the bracket were never tried
+        // and the search ended with step size 0.
+        stepSize = (mu + upsilon) / 2.0;
+      } else {
+        stepSize *= TT;
+      }
+
+      // Give up when the step is indistinguishable from the lower bound, or
+      // when rounding/overflow pushed it onto or past the upper bound (which
+      // would otherwise repeat the same rejected trial forever).
+      if (stepSize - mu < MIN_STEP_SIZE || !(stepSize < upsilon)) {
+        stepSize = 0.0;
+        break;
+      }
+    }
+    long endTime = System.currentTimeMillis();
+    long duration = endTime - startTime;
+
+    if (verbose) {
+      System.out.print("\t" + valueAtX);
+      System.out.print("\t" + (valueAtNextPoint - valueAtX));
+      System.out.print("\t" + (duration / 1000.0) + "\n");
+    }
+
+    LineSearchResult result = new LineSearchResult(stepSize, valueAtX, valueAtNextPoint, gradAtX, gradAtNextPoint, x, nextPoint, currFctEvalCount);
+    return result;
+  }
+
+  /**
+   * Sufficient-increase (Armijo) condition for maximization:
+   * f(x + a*p) > f(x) + C1 * a * p^T grad f(x).
+   */
+  private static boolean checkArmijoCond(double valueAtX, double valueAtNewPoint,
+      double initialSlope, double currStepSize) {
+    double armijo = valueAtX + (C1 * initialSlope * currStepSize);
+    return valueAtNewPoint > armijo;
+  }
+
+  /**
+   * Weak Wolfe curvature condition for maximization:
+   * p^T grad f(x + a*p) < C2 * p^T grad f(x).
+   */
+  private static boolean checkCurvature(double[] gradAtNewPoint, double[] direction,
+      double initialSlope) {
+    double newSlope = ArrayMath.innerProduct(direction, gradAtNewPoint);
+    return newSlope < C2 * initialSlope;
+  }
+}
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearch.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearchResult.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearchResult.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearchResult.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearchResult.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+/**
+ * class to store lineSearch result
+ */
+public class LineSearchResult {
+ public static LineSearchResult getInitialObject(double valueAtX, double[] gradAtX, double[] x, int maxFctEval) {
+ return new LineSearchResult(0.0, 0.0, valueAtX, null, gradAtX, null, x, maxFctEval);
+ }
+
+ public static LineSearchResult getInitialObject(double valueAtX, double[] gradAtX, double[] x) {
+ return new LineSearchResult(0.0, 0.0, valueAtX, null, gradAtX, null, x, QNTrainer.DEFAULT_MAX_FCT_EVAL);
+ }
+
+ private int fctEvalCount;
+ private double stepSize;
+ private double valueAtCurr;
+ private double valueAtNext;
+ private double[] gradAtCurr;
+ private double[] gradAtNext;
+ private double[] currPoint;
+ private double[] nextPoint;
+
+ public LineSearchResult(double stepSize, double valueAtX, double valurAtX_1,
+ double[] gradAtX, double[] gradAtX_1, double[] currPoint, double[] nextPoint, int fctEvalCount) {
+ this.stepSize = stepSize;
+ this.valueAtCurr = valueAtX;
+ this.valueAtNext = valurAtX_1;
+ this.gradAtCurr = gradAtX;
+ this.gradAtNext = gradAtX_1;
+ this.currPoint = currPoint;
+ this.nextPoint = nextPoint;
+ this.setFctEvalCount(fctEvalCount);
+ }
+
+ public double getStepSize() {
+ return stepSize;
+ }
+ public void setStepSize(double stepSize) {
+ this.stepSize = stepSize;
+ }
+ public double getValueAtCurr() {
+ return valueAtCurr;
+ }
+ public void setValueAtCurr(double valueAtCurr) {
+ this.valueAtCurr = valueAtCurr;
+ }
+ public double getValueAtNext() {
+ return valueAtNext;
+ }
+ public void setValueAtNext(double valueAtNext) {
+ this.valueAtNext = valueAtNext;
+ }
+ public double[] getGradAtCurr() {
+ return gradAtCurr;
+ }
+ public void setGradAtCurr(double[] gradAtCurr) {
+ this.gradAtCurr = gradAtCurr;
+ }
+ public double[] getGradAtNext() {
+ return gradAtNext;
+ }
+ public void setGradAtNext(double[] gradAtNext) {
+ this.gradAtNext = gradAtNext;
+ }
+ public double[] getCurrPoint() {
+ return currPoint;
+ }
+ public void setCurrPoint(double[] currPoint) {
+ this.currPoint = currPoint;
+ }
+ public double[] getNextPoint() {
+ return nextPoint;
+ }
+ public void setNextPoint(double[] nextPoint) {
+ this.nextPoint = nextPoint;
+ }
+ public int getFctEvalCount() {
+ return fctEvalCount;
+ }
+ public void setFctEvalCount(int fctEvalCount) {
+ this.fctEvalCount = fctEvalCount;
+ }
+}
\ No newline at end of file
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LineSearchResult.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LogLikelihoodFunction.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LogLikelihoodFunction.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LogLikelihoodFunction.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LogLikelihoodFunction.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import opennlp.model.DataIndexer;
+import opennlp.model.OnePassRealValueDataIndexer;
+
+/**
+ * Evaluate log likelihood and its gradient from DataIndexer.
+ */
+public class LogLikelihoodFunction implements DifferentiableFunction {
+ private int domainDimension;
+ private double value;
+ private double[] gradient;
+ private double[] lastX;
+ private double[] empiricalCount;
+ private int numOutcomes;
+ private int numFeatures;
+ private int numContexts;
+ private double[][] probModel;
+
+ private String[] outcomeLabels;
+ private String[] predLabels;
+
+ private int[][] outcomePatterns;
+
+ // infos from data index;
+ private final float[][] values;
+ private final int[][] contexts;
+ private final int[] outcomeList;
+ private final int[] numTimesEventsSeen;
+
+ public LogLikelihoodFunction(DataIndexer indexer) {
+ // get data from indexer.
+ if (indexer instanceof OnePassRealValueDataIndexer) {
+ this.values = indexer.getValues();
+ } else {
+ this.values = null;
+ }
+
+ this.contexts = indexer.getContexts();
+ this.outcomeList = indexer.getOutcomeList();
+ this.numTimesEventsSeen = indexer.getNumTimesEventsSeen();
+
+ this.outcomeLabels = indexer.getOutcomeLabels();
+ this.predLabels = indexer.getPredLabels();
+
+ this.numOutcomes = indexer.getOutcomeLabels().length;
+ this.numFeatures = indexer.getPredLabels().length;
+ this.numContexts = this.contexts.length;
+ this.domainDimension = numOutcomes * numFeatures;
+ this.probModel = new double[numContexts][numOutcomes];
+ this.gradient = null;
+ }
+
+ public double valueAt(double[] x) {
+ if (!checkLastX(x)) calculate(x);
+ return value;
+ }
+
+ public double[] gradientAt(double[] x) {
+ if (!checkLastX(x)) calculate(x);
+ return gradient;
+ }
+
+ public int getDomainDimension() {
+ return this.domainDimension;
+ }
+
+ public double[] getInitialPoint() {
+ return new double[domainDimension];
+ }
+
+ public String[] getPredLabels() {
+ return this.predLabels;
+ }
+
+ public String[] getOutcomeLabels() {
+ return this.outcomeLabels;
+ }
+
+ public int[][] getOutcomePatterns() {
+ return this.outcomePatterns;
+ }
+
+ private void calculate(double[] x) {
+ if (x.length != this.domainDimension) {
+ throw new IllegalArgumentException("x is invalid, its dimension is not equal to the function.");
+ }
+
+ initProbModel();
+ if (this.empiricalCount == null)
+ initEmpCount();
+
+ // sum up log likelihood and empirical feature count for gradient calculation.
+ double logLikelihood = 0.0;
+
+ for (int ci = 0; ci < numContexts; ci++) {
+ double voteSum = 0.0;
+
+ for (int af = 0; af < this.contexts[ci].length; af++) {
+ int vectorIndex = indexOf(this.outcomeList[ci], contexts[ci][af]);
+ double predValue = 1.0;
+ if (values != null) predValue = this.values[ci][af];
+ if (predValue == 0.0) continue;
+
+ voteSum += predValue * x[vectorIndex];
+ }
+ probModel[ci][this.outcomeList[ci]] = Math.exp(voteSum);
+
+ double totalVote = 0.0;
+ for (int i = 0; i < numOutcomes; i++) {
+ totalVote += probModel[ci][i];
+ }
+ for (int i = 0; i < numOutcomes; i++) {
+ probModel[ci][i] /= totalVote;
+ }
+ for (int i = 0; i < numTimesEventsSeen[ci]; i++) {
+ logLikelihood += Math.log(probModel[ci][this.outcomeList[ci]]);
+ }
+ }
+ this.value = logLikelihood;
+
+ // calculate gradient.
+ double[] expectedCount = new double[numOutcomes * numFeatures];
+ for (int ci = 0; ci < numContexts; ci++) {
+ for (int oi = 0; oi < numOutcomes; oi++) {
+ for (int af = 0; af < contexts[ci].length; af++) {
+ int vectorIndex = indexOf(oi, this.contexts[ci][af]);
+ double predValue = 1.0;
+ if (values != null) predValue = this.values[ci][af];
+ if (predValue == 0.0) continue;
+
+ expectedCount[vectorIndex] += predValue * probModel[ci][oi] * this.numTimesEventsSeen[ci];
+ }
+ }
+ }
+
+ double[] gradient = new double[domainDimension];
+ for (int i = 0; i < numOutcomes * numFeatures; i++) {
+ gradient[i] = expectedCount[i] - this.empiricalCount[i];
+ }
+ this.gradient = gradient;
+
+ // update last evaluated x.
+ this.lastX = x.clone();
+ }
+
+ /**
+ * @param x vector that represents point to evaluate at.
+ * @return check x is whether last evaluated point or not.
+ */
+ private boolean checkLastX(double[] x) {
+ if (this.lastX == null) return false;
+
+ for (int i = 0; i < x.length; i++) {
+ if (lastX[i] != x[i]) return false;
+ }
+ return true;
+ }
+
+ private int indexOf(int outcomeId, int featureId) {
+ return outcomeId * numFeatures + featureId;
+ }
+
+ private void initProbModel() {
+ for (int i = 0; i < this.probModel.length; i++) {
+ Arrays.fill(this.probModel[i], 1.0);
+ }
+ }
+
+ private void initEmpCount() {
+ this.empiricalCount = new double[numOutcomes * numFeatures];
+ this.outcomePatterns = new int[predLabels.length][];
+
+ for (int ci = 0; ci < numContexts; ci++) {
+ for (int af = 0; af < this.contexts[ci].length; af++) {
+ int vectorIndex = indexOf(this.outcomeList[ci], contexts[ci][af]);
+ if (values != null) {
+ empiricalCount[vectorIndex] += this.values[ci][af] * numTimesEventsSeen[ci];
+ } else {
+ empiricalCount[vectorIndex] += 1.0 * numTimesEventsSeen[ci];
+ }
+ }
+ }
+
+ for (int fi = 0; fi < this.outcomePatterns.length; fi++) {
+ ArrayList<Integer> pattern = new ArrayList<Integer>();
+ for (int oi = 0; oi < outcomeLabels.length; oi++) {
+ int countIndex = fi + (this.predLabels.length * oi);
+ if (this.empiricalCount[countIndex] > 0) {
+ pattern.add(oi);
+ }
+ }
+ outcomePatterns[fi] = new int[pattern.size()];
+ for (int i = 0; i < pattern.size(); i++) {
+ outcomePatterns[fi][i] = pattern.get(i);
+ }
+ }
+ }
+}
\ No newline at end of file
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/LogLikelihoodFunction.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNModel.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNModel.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNModel.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNModel.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+import opennlp.model.AbstractModel;
+import opennlp.model.Context;
+import opennlp.model.EvalParameters;
+import opennlp.model.UniformPrior;
+
+public class QNModel extends AbstractModel {
+ private static final double SMOOTHING_VALUE = 0.1;
+ private double[] parameters;
+ // FROM trainer
+ public QNModel(LogLikelihoodFunction monitor, double[] parameters) {
+ super(null, monitor.getPredLabels(), monitor.getOutcomeLabels());
+
+ int[][] outcomePatterns = monitor.getOutcomePatterns();
+ Context[] params = new Context[monitor.getPredLabels().length];
+ for (int ci = 0; ci < params.length; ci++) {
+ int[] outcomePattern = outcomePatterns[ci];
+ double[] alpha = new double[outcomePattern.length];
+ for (int oi = 0; oi < outcomePattern.length; oi++) {
+ alpha[oi] = parameters[ci + (outcomePattern[oi] * monitor.getPredLabels().length)];
+ }
+ params[ci] = new Context(outcomePattern, alpha);
+ }
+ this.evalParams = new EvalParameters(params, monitor.getOutcomeLabels().length);
+ this.prior = new UniformPrior();
+ this.modelType = ModelType.MaxentQn;
+
+ this.parameters = parameters;
+ }
+
+ // FROM model reader
+ public QNModel(String[] predNames, String[] outcomeNames, Context[] params, double[] parameters) {
+ super(params, predNames, outcomeNames);
+ this.prior = new UniformPrior();
+ this.modelType = ModelType.MaxentQn;
+
+ this.parameters = parameters;
+ }
+
+ public double[] eval(String[] context) {
+ return eval(context, new double[evalParams.getNumOutcomes()]);
+ }
+
+ private int getPredIndex(String predicate) {
+ return pmap.get(predicate);
+ }
+
+ public double[] eval(String[] context, double[] probs) {
+ return eval(context, null, probs);
+ }
+
+ public double[] eval(String[] context, float[] values) {
+ return eval(context, values, new double[evalParams.getNumOutcomes()]);
+ }
+
+ // TODO need implments for handlling with "probs".
+ private double[] eval(String[] context, float[] values, double[] probs) {
+ double[] result = new double[outcomeNames.length];
+ double[] table = new double[outcomeNames.length + 1];
+ for (int pi = 0; pi < context.length; pi++) {
+ int predIdx = getPredIndex(context[pi]);
+
+ for (int oi = 0; oi < outcomeNames.length; oi++) {
+ int paraIdx = oi * pmap.size() + predIdx;
+
+ double predValue = 1.0;
+ if (values != null) predValue = values[pi];
+ if (paraIdx < 0) {
+ table[oi] += predValue * SMOOTHING_VALUE;
+ } else {
+ table[oi] += predValue * parameters[paraIdx];
+ }
+
+ }
+ }
+
+ for (int oi = 0; oi < outcomeNames.length; oi++) {
+ table[oi] = Math.exp(table[oi]);
+ table[outcomeNames.length] += table[oi];
+ }
+ for (int oi = 0; oi < outcomeNames.length; oi++) {
+ result[oi] = table[oi] / table[outcomeNames.length];
+ }
+ return result;
+// double[] table = new double[outcomeNames.length];
+// Arrays.fill(table, 1.0 / outcomeNames.length);
+// return table;
+ }
+
+ public int getNumOutcomes() {
+ return this.outcomeNames.length;
+ }
+
+ public double[] getParameters() {
+ return this.parameters;
+ }
+
+ public boolean equals(Object obj) {
+ if (!(obj instanceof QNModel))
+ return false;
+
+ QNModel objModel = (QNModel) obj;
+ if (this.outcomeNames.length != objModel.outcomeNames.length)
+ return false;
+ for (int i = 0; i < this.outcomeNames.length; i++) {
+ if (!this.outcomeNames[i].equals(objModel.outcomeNames[i]))
+ return false;
+ }
+
+ if (this.pmap.size() != objModel.pmap.size())
+ return false;
+ String[] pmapArray = new String[pmap.size()];
+ pmap.toArray(pmapArray);
+ for (int i = 0; i < this.pmap.size(); i++) {
+ if (i != objModel.pmap.get(pmapArray[i]))
+ return false;
+ }
+
+ // compare evalParameters
+ Context[] contextComparing = objModel.evalParams.getParams();
+ if (this.evalParams.getParams().length != contextComparing.length)
+ return false;
+ for (int i = 0; i < this.evalParams.getParams().length; i++) {
+ if (this.evalParams.getParams()[i].getOutcomes().length != contextComparing[i].getOutcomes().length)
+ return false;
+ for (int j = 0; i < this.evalParams.getParams()[i].getOutcomes().length; i++) {
+ if (this.evalParams.getParams()[i].getOutcomes()[j] != contextComparing[i].getOutcomes()[j])
+ return false;
+ }
+
+ if (this.evalParams.getParams()[i].getParameters().length != contextComparing[i].getParameters().length)
+ return false;
+ for (int j = 0; i < this.evalParams.getParams()[i].getParameters().length; i++) {
+ if (this.evalParams.getParams()[i].getParameters()[j] != contextComparing[i].getParameters()[j])
+ return false;
+ }
+ }
+ return true;
+ }
+}
\ No newline at end of file
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNModel.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNTrainer.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNTrainer.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNTrainer.java (added)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNTrainer.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+import java.util.Arrays;
+
+import opennlp.model.DataIndexer;
+
+/**
+ * maxent model trainer using l-bfgs algorithm.
+ */
+public class QNTrainer {
+ // constants for optimization.
+ private static final double CONVERGE_TOLERANCE = 1.0E-10;
+ private static final int MAX_M = 15;
+ public static final int DEFAULT_M = 7;
+ public static final int MAX_FCT_EVAL = 3000;
+ public static final int DEFAULT_MAX_FCT_EVAL = 300;
+
+ // settings for objective function and optimizer.
+ private int dimension;
+ private int m;
+ private int maxFctEval;
+ private QNInfo updateInfo;
+ private boolean verbose;
+
+ // default constructor -- no log.
+ public QNTrainer() {
+ this(true);
+ }
+
+ // constructor -- to log.
+ public QNTrainer(boolean verbose) {
+ this(DEFAULT_M, verbose);
+ }
+
+ // constructor -- m : number of hessian updates to store.
+ public QNTrainer(int m) {
+ this(m, true);
+ }
+
+ // constructor -- to log, number of hessian updates to store.
+ public QNTrainer(int m, boolean verbose) {
+ this(m, DEFAULT_MAX_FCT_EVAL, verbose);
+ }
+
+ public QNTrainer(int m, int maxFctEval, boolean verbose) {
+ this.verbose = verbose;
+ if (m > MAX_M) {
+ this.m = MAX_M;
+ } else {
+ this.m = m;
+ }
+ if (maxFctEval < 0) {
+ this.maxFctEval = DEFAULT_MAX_FCT_EVAL;
+ } else if (maxFctEval > MAX_FCT_EVAL) {
+ this.maxFctEval = MAX_FCT_EVAL;
+ } else {
+ this.maxFctEval = maxFctEval;
+ }
+ }
+
+ public QNModel trainModel(DataIndexer indexer) {
+ LogLikelihoodFunction objectiveFunction = generateFunction(indexer);
+ this.dimension = objectiveFunction.getDomainDimension();
+ this.updateInfo = new QNInfo(this.m, this.dimension);
+
+ double[] initialPoint = objectiveFunction.getInitialPoint();
+ double initialValue = objectiveFunction.valueAt(initialPoint);
+ double[] initialGrad = objectiveFunction.gradientAt(initialPoint);
+
+ LineSearchResult lsr = LineSearchResult.getInitialObject(initialValue, initialGrad, initialPoint, 0);
+
+ int z = 0;
+ while (true) {
+ if (verbose) {
+ System.out.print(z++);
+ }
+ double[] direction = null;
+
+ direction = computeDirection(objectiveFunction, lsr);
+ lsr = LineSearch.doLineSearch(objectiveFunction, direction, lsr, verbose);
+
+ updateInfo.updateInfo(lsr);
+
+ if (isConverged(lsr))
+ break;
+ }
+ return new QNModel(objectiveFunction, lsr.getNextPoint());
+ }
+
+
+ private LogLikelihoodFunction generateFunction(DataIndexer indexer) {
+ return new LogLikelihoodFunction(indexer);
+ }
+
+ private double[] computeDirection(DifferentiableFunction monitor, LineSearchResult lsr) {
+ // implemented two-loop hessian update method.
+ double[] direction = lsr.getGradAtNext().clone();
+ double[] as = new double[m];
+
+ // first loop
+ for (int i = updateInfo.kCounter - 1; i >= 0; i--) {
+ as[i] = updateInfo.getRho(i) * ArrayMath.innerProduct(updateInfo.getS(i), direction);
+ for (int ii = 0; ii < dimension; ii++) {
+ direction[ii] = direction[ii] - as[i] * updateInfo.getY(i)[ii];
+ }
+ }
+
+ // second loop
+ for (int i = 0; i < updateInfo.kCounter; i++) {
+ double b = updateInfo.getRho(i) * ArrayMath.innerProduct(updateInfo.getY(i), direction);
+ for (int ii = 0; ii < dimension; ii++) {
+ direction[ii] = direction[ii] + (as[i] - b) * updateInfo.getS(i)[ii];
+ }
+ }
+
+ for (int i = 0; i < dimension; i++) {
+ direction[i] *= -1.0;
+ }
+
+ return direction;
+ }
+
+ // FIXME need an improvement in convergence condition
+ private boolean isConverged(LineSearchResult lsr) {
+ return CONVERGE_TOLERANCE > Math.abs(lsr.getValueAtNext() - lsr.getValueAtCurr())
+ || lsr.getFctEvalCount() > this.maxFctEval;
+ }
+
+ /**
+ * class to store vectors for hessian approximation update.
+ */
+ private class QNInfo {
+ private double[][] S;
+ private double[][] Y;
+ private double[] rho;
+ private int m;
+ private double[] diagonal;
+
+ private int kCounter;
+
+ // constructor
+ QNInfo(int numCorrection, int dimension) {
+ this.m = numCorrection;
+ this.kCounter = 0;
+ S = new double[this.m][];
+ Y = new double[this.m][];
+ rho = new double[this.m];
+ Arrays.fill(rho, Double.NaN);
+ diagonal = new double[dimension];
+ Arrays.fill(diagonal, 1.0);
+ }
+
+ public void updateInfo(LineSearchResult lsr) {
+ double[] s_k = new double[dimension];
+ double[] y_k = new double[dimension];
+ for (int i = 0; i < dimension; i++) {
+ s_k[i] = lsr.getNextPoint()[i] - lsr.getCurrPoint()[i];
+ y_k[i] = lsr.getGradAtNext()[i] - lsr.getGradAtCurr()[i];
+ }
+ this.updateSYRoh(s_k, y_k);
+ kCounter = kCounter < m ? kCounter + 1 : kCounter;
+ }
+
+ private void updateSYRoh(double[] s_k, double[] y_k) {
+ double newRoh = 1.0 / ArrayMath.innerProduct(y_k, s_k);
+ // add new ones.
+ if (kCounter < m) {
+ S[kCounter] = s_k.clone();
+ Y[kCounter] = y_k.clone();
+ rho[kCounter] = newRoh;
+ } else if (m > 0) {
+ // discard oldest vectors and add new ones.
+ for (int i = 0; i < m - 1; i++) {
+ S[i] = S[i + 1];
+ Y[i] = Y[i + 1];
+ rho[i] = rho[i + 1];
+ }
+ S[m - 1] = s_k.clone();
+ Y[m - 1] = y_k.clone();
+ rho[m - 1] = newRoh;
+ }
+ }
+
+ public double getRho(int updateIndex) {
+ return this.rho[updateIndex];
+ }
+
+ public double[] getS(int updateIndex) {
+ return S[updateIndex];
+ }
+
+ public double[] getY(int updateIndex) {
+ return Y[updateIndex];
+ }
+ }
+}
\ No newline at end of file
Propchange: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/maxent/quasinewton/QNTrainer.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Modified: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/AbstractModel.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/AbstractModel.java?rev=1392944&r1=1392943&r2=1392944&view=diff
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/AbstractModel.java (original)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/AbstractModel.java Tue Oct 2 14:52:34 2012
@@ -32,7 +32,7 @@ public abstract class AbstractModel impl
/** Prior distribution for this model. */
protected Prior prior;
- public enum ModelType {Maxent,Perceptron};
+ public enum ModelType {Maxent,Perceptron,MaxentQn};
/** The type of the model. */
protected ModelType modelType;
Modified: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelReader.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelReader.java?rev=1392944&r1=1392943&r2=1392944&view=diff
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelReader.java (original)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelReader.java Tue Oct 2 14:52:34 2012
@@ -23,6 +23,7 @@ import java.io.File;
import java.io.IOException;
import opennlp.maxent.io.GISModelReader;
+import opennlp.maxent.io.QNModelReader;
import opennlp.perceptron.PerceptronModelReader;
public class GenericModelReader extends AbstractModelReader {
@@ -45,6 +46,9 @@ public class GenericModelReader extends
else if (modelType.equals("GIS")) {
delegateModelReader = new GISModelReader(this.dataReader);
}
+ else if (modelType.equals("QN")) {
+ delegateModelReader = new QNModelReader(this.dataReader);
+ }
else {
throw new IOException("Unknown model format: "+modelType);
}
Modified: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelWriter.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelWriter.java?rev=1392944&r1=1392943&r2=1392944&view=diff
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelWriter.java (original)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/GenericModelWriter.java Tue Oct 2 14:52:34 2012
@@ -29,6 +29,7 @@ import java.io.OutputStreamWriter;
import java.util.zip.GZIPOutputStream;
import opennlp.maxent.io.BinaryGISModelWriter;
+import opennlp.maxent.io.BinaryQNModelWriter;
import opennlp.maxent.io.PlainTextGISModelWriter;
import opennlp.model.AbstractModel.ModelType;
import opennlp.perceptron.BinaryPerceptronModelWriter;
@@ -70,6 +71,9 @@ public class GenericModelWriter extends
else if (model.getModelType() == ModelType.Maxent) {
delegateWriter = new BinaryGISModelWriter(model,dos);
}
+ else if (model.getModelType() == ModelType.MaxentQn) {
+ delegateWriter = new BinaryQNModelWriter(model,dos);
+ }
}
private void init(AbstractModel model, BufferedWriter bw) {
@@ -105,5 +109,4 @@ public class GenericModelWriter extends
public void writeUTF(String s) throws IOException {
delegateWriter.writeUTF(s);
}
-
}
Modified: opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java?rev=1392944&r1=1392943&r2=1392944&view=diff
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java (original)
+++ opennlp/trunk/opennlp-maxent/src/main/java/opennlp/model/TrainUtil.java Tue Oct 2 14:52:34 2012
@@ -20,9 +20,9 @@
package opennlp.model;
import java.io.IOException;
-import java.util.HashMap;
import java.util.Map;
+import opennlp.maxent.quasinewton.QNTrainer;
import opennlp.perceptron.PerceptronTrainer;
import opennlp.perceptron.SimplePerceptronSequenceTrainer;
@@ -31,6 +31,7 @@ public class TrainUtil {
public static final String ALGORITHM_PARAM = "Algorithm";
public static final String MAXENT_VALUE = "MAXENT";
+ public static final String MAXENT_QN_VALUE = "MAXENT_QN_EXPERIMENTAL";
public static final String PERCEPTRON_VALUE = "PERCEPTRON";
public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";
@@ -99,7 +100,8 @@ public class TrainUtil {
String algorithmName = trainParams.get(ALGORITHM_PARAM);
- if (algorithmName != null && !(MAXENT_VALUE.equals(algorithmName) ||
+ if (algorithmName != null && !(MAXENT_VALUE.equals(algorithmName) ||
+ MAXENT_QN_VALUE.equals(algorithmName) ||
PERCEPTRON_VALUE.equals(algorithmName) ||
PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))) {
return false;
@@ -150,8 +152,8 @@ public class TrainUtil {
boolean sortAndMerge;
- if (MAXENT_VALUE.equals(algorithmName))
- sortAndMerge = true;
+ if (MAXENT_VALUE.equals(algorithmName) || MAXENT_QN_VALUE.equals(algorithmName))
+ sortAndMerge = true;
else if (PERCEPTRON_VALUE.equals(algorithmName))
sortAndMerge = false;
else
@@ -182,6 +184,11 @@ public class TrainUtil {
model = opennlp.maxent.GIS.trainModel(iterations, indexer,
true, false, null, 0, threads);
}
+ else if (MAXENT_QN_VALUE.equals(algorithmName)) {
+ int m = getIntParam(trainParams, "numOfUpdates", QNTrainer.DEFAULT_M, reportMap);
+ int maxFctEval = getIntParam(trainParams, "maxFctEval", QNTrainer.DEFAULT_MAX_FCT_EVAL, reportMap);
+ model = new QNTrainer(m, maxFctEval, true).trainModel(indexer);
+ }
else if (PERCEPTRON_VALUE.equals(algorithmName)) {
boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
Added: opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LineSearchTest.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LineSearchTest.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LineSearchTest.java (added)
+++ opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LineSearchTest.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+
+import org.junit.Test;
+
+public class LineSearchTest {
+ public static final double TOLERANCE = 0.01;
+
+ @Test
+ public void testLineSearchDeterminesSaneStepLength01() {
+ DifferentiableFunction objectiveFunction = new QuadraticFunction();
+ // given
+ double[] testX = new double[] { 0 };
+ double testValueX = objectiveFunction.valueAt(testX);
+ double[] testGradX = objectiveFunction.gradientAt(testX);
+ double[] testDirection = new double[] { 1 };
+ // when
+ LineSearchResult lsr = LineSearchResult.getInitialObject(testValueX, testGradX, testX);
+ double stepSize = LineSearch.doLineSearch(objectiveFunction, testDirection, lsr).getStepSize();
+ // then
+ boolean succCond = TOLERANCE < stepSize && stepSize <= 1;
+ assertTrue(succCond);
+ }
+
+ @Test
+ public void testLineSearchDeterminesSaneStepLength02() {
+ DifferentiableFunction objectiveFunction = new QuadraticFunction02();
+ // given
+ double[] testX = new double[] { -2 };
+ double testValueX = objectiveFunction.valueAt(testX);
+ double[] testGradX = objectiveFunction.gradientAt(testX);
+ double[] testDirection = new double[] { 1 };
+ // when
+ LineSearchResult lsr = LineSearchResult.getInitialObject(testValueX, testGradX, testX);
+ double stepSize = LineSearch.doLineSearch(objectiveFunction, testDirection, lsr).getStepSize();
+ // then
+ boolean succCond = TOLERANCE < stepSize && stepSize <= 1;
+ assertTrue(succCond);
+ }
+
+ @Test
+ public void testLineSearchFailsWithWrongDirection01() {
+ DifferentiableFunction objectiveFunction = new QuadraticFunction();
+ // given
+ double[] testX = new double[] { 0 };
+ double testValueX = objectiveFunction.valueAt(testX);
+ double[] testGradX = objectiveFunction.gradientAt(testX);
+ double[] testDirection = new double[] { -1 };
+ // when
+ LineSearchResult lsr = LineSearchResult.getInitialObject(testValueX, testGradX, testX);
+ double stepSize = LineSearch.doLineSearch(objectiveFunction, testDirection, lsr).getStepSize();
+ // then
+ boolean succCond = TOLERANCE < stepSize && stepSize <= 1;
+ assertFalse(succCond);
+ assertEquals(0.0, stepSize, TOLERANCE);
+ }
+
+ @Test
+ public void testLineSearchFailsWithWrongDirection02() {
+ DifferentiableFunction objectiveFunction = new QuadraticFunction02();
+ // given
+ double[] testX = new double[] { -2 };
+ double testValueX = objectiveFunction.valueAt(testX);
+ double[] testGradX = objectiveFunction.gradientAt(testX);
+ double[] testDirection = new double[] { -1 };
+ // when
+ LineSearchResult lsr = LineSearchResult.getInitialObject(testValueX, testGradX, testX);
+ double stepSize = LineSearch.doLineSearch(objectiveFunction, testDirection, lsr).getStepSize();
+ // then
+ boolean succCond = TOLERANCE < stepSize && stepSize <= 1;
+ assertFalse(succCond);
+ assertEquals(0.0, stepSize, TOLERANCE);
+ }
+
+ @Test
+ public void testLineSearchFailsWithWrongDirection03() {
+ DifferentiableFunction objectiveFunction = new QuadraticFunction();
+ // given
+ double[] testX = new double[] { 4 };
+ double testValueX = objectiveFunction.valueAt(testX);
+ double[] testGradX = objectiveFunction.gradientAt(testX);
+ double[] testDirection = new double[] { 1 };
+ // when
+ LineSearchResult lsr = LineSearchResult.getInitialObject(testValueX, testGradX, testX);
+ double stepSize = LineSearch.doLineSearch(objectiveFunction, testDirection, lsr).getStepSize();
+ // then
+ boolean succCond = TOLERANCE < stepSize && stepSize <= 1;
+ assertFalse(succCond);
+ assertEquals(0.0, stepSize, TOLERANCE);
+ }
+
+ @Test
+ public void testLineSearchFailsWithWrongDirection04() {
+ DifferentiableFunction objectiveFunction = new QuadraticFunction02();
+ // given
+ double[] testX = new double[] { 2 };
+ double testValueX = objectiveFunction.valueAt(testX);
+ double[] testGradX = objectiveFunction.gradientAt(testX);
+ double[] testDirection = new double[] { 1 };
+ // when
+ LineSearchResult lsr = LineSearchResult.getInitialObject(testValueX, testGradX, testX);
+ double stepSize = LineSearch.doLineSearch(objectiveFunction, testDirection, lsr).getStepSize();
+ // then
+ boolean succCond = TOLERANCE < stepSize && stepSize <= 1;
+ assertFalse(succCond);
+ assertEquals(0.0, stepSize, TOLERANCE);
+ }
+
+ @Test
+ public void testLineSearchFailsAtMaxima01() {
+ DifferentiableFunction objectiveFunction = new QuadraticFunction02();
+ // given
+ double[] testX = new double[] { 0 };
+ double testValueX = objectiveFunction.valueAt(testX);
+ double[] testGradX = objectiveFunction.gradientAt(testX);
+ double[] testDirection = new double[] { -1 };
+ // when
+ LineSearchResult lsr = LineSearchResult.getInitialObject(testValueX, testGradX, testX);
+ double stepSize = LineSearch.doLineSearch(objectiveFunction, testDirection, lsr).getStepSize();
+ // then
+ boolean succCond = TOLERANCE < stepSize && stepSize <= 1;
+ assertFalse(succCond);
+ assertEquals(0.0, stepSize, TOLERANCE);
+ }
+
+ @Test
+ public void testLineSearchFailsAtMaxima02() {
+ DifferentiableFunction objectiveFunction = new QuadraticFunction02();
+ // given
+ double[] testX = new double[] { 0 };
+ double testValueX = objectiveFunction.valueAt(testX);
+ double[] testGradX = objectiveFunction.gradientAt(testX);
+ double[] testDirection = new double[] { 1 };
+ // when
+ LineSearchResult lsr = LineSearchResult.getInitialObject(testValueX, testGradX, testX);
+ double stepSize = LineSearch.doLineSearch(objectiveFunction, testDirection, lsr).getStepSize();
+ // then
+ boolean succCond = TOLERANCE < stepSize && stepSize <= 1;
+ assertFalse(succCond);
+ assertEquals(0.0, stepSize, TOLERANCE);
+ }
+}
\ No newline at end of file
Propchange: opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LineSearchTest.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LogLikelihoodFunctionTest.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LogLikelihoodFunctionTest.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LogLikelihoodFunctionTest.java (added)
+++ opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LogLikelihoodFunctionTest.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.IOException;
+
+import opennlp.model.DataIndexer;
+import opennlp.model.OnePassRealValueDataIndexer;
+import opennlp.model.RealValueFileEventStream;
+
+import org.junit.Test;
+
+public class LogLikelihoodFunctionTest {
+ public final double TOLERANCE01 = 1.0E-06;
+ public final double TOLERANCE02 = 1.0E-10;
+
+ @Test
+ public void testDomainDimensionSanity() throws IOException {
+ // given
+ RealValueFileEventStream rvfes1 = new RealValueFileEventStream("src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt");
+ DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
+ LogLikelihoodFunction objectFunction = new LogLikelihoodFunction(testDataIndexer);
+ // when
+ int correctDomainDimension = testDataIndexer.getPredLabels().length * testDataIndexer.getOutcomeLabels().length;
+ // then
+ assertEquals(correctDomainDimension, objectFunction.getDomainDimension());
+ }
+
+ @Test
+ public void testInitialSanity() throws IOException {
+ // given
+ RealValueFileEventStream rvfes1 = new RealValueFileEventStream("src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt");
+ DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
+ LogLikelihoodFunction objectFunction = new LogLikelihoodFunction(testDataIndexer);
+ // when
+ double[] initial = objectFunction.getInitialPoint();
+ // then
+ for (int i = 0; i < initial.length; i++) {
+ assertEquals(0.0, initial[i], TOLERANCE01);
+ }
+ }
+
+ @Test
+ public void testGradientSanity() throws IOException {
+ // given
+ RealValueFileEventStream rvfes1 = new RealValueFileEventStream("src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt");
+ DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
+ LogLikelihoodFunction objectFunction = new LogLikelihoodFunction(testDataIndexer);
+ // when
+ double[] initial = objectFunction.getInitialPoint();
+ double[] gradientAtInitial = objectFunction.gradientAt(initial);
+ // then
+ assertNotNull(gradientAtInitial);
+ }
+
+ @Test
+ public void testValueAtInitialPoint() throws IOException {
+ // given
+ RealValueFileEventStream rvfes1 = new RealValueFileEventStream("src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt");
+ DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
+ LogLikelihoodFunction objectFunction = new LogLikelihoodFunction(testDataIndexer);
+ // when
+ double value = objectFunction.valueAt(objectFunction.getInitialPoint());
+ double expectedValue = -13.86294361;
+ // then
+ assertEquals(expectedValue, value, TOLERANCE01);
+ }
+
+ @Test
+ public void testValueAtNonInitialPoint01() throws IOException {
+ // given
+ RealValueFileEventStream rvfes1 = new RealValueFileEventStream("src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt");
+ DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
+ LogLikelihoodFunction objectFunction = new LogLikelihoodFunction(testDataIndexer);
+ // when
+ double[] nonInitialPoint = new double[] { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
+ double value = objectFunction.valueAt(nonInitialPoint);
+ double expectedValue = -0.000206886;
+ // then
+ assertEquals(expectedValue, value, TOLERANCE01);
+ }
+
+ @Test
+ public void testValueAtNonInitialPoint02() throws IOException {
+ // given
+ RealValueFileEventStream rvfes1 = new RealValueFileEventStream("src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt");
+ DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
+ LogLikelihoodFunction objectFunction = new LogLikelihoodFunction(testDataIndexer);
+ // when
+ double[] nonInitialPoint = new double[] { 2, 3, 2, 3, 3, 3, 2, 3, 2, 2 };
+ double value = objectFunction.valueAt(nonInitialPoint);
+ double expectedValue = -0.00000000285417;
+ // then
+ assertEquals(expectedValue, value, TOLERANCE02);
+ }
+
+ @Test
+ public void testGradientAtInitialPoint() throws IOException {
+ // given
+ RealValueFileEventStream rvfes1 = new RealValueFileEventStream("src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt");
+ DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
+ LogLikelihoodFunction objectFunction = new LogLikelihoodFunction(testDataIndexer);
+ // when
+ double[] gradientAtInitialPoint = objectFunction.gradientAt(objectFunction.getInitialPoint());
+ double[] expectedGradient = new double[] { 20, 8.5, -14, -17, -9, -20, -8.5, 14, 17, 9 };
+ // then
+ assertTrue(expectedGradient.length == gradientAtInitialPoint.length);
+ for (int i = 0; i < expectedGradient.length; i++) {
+ assertEquals(expectedGradient[i], gradientAtInitialPoint[i], TOLERANCE01);
+ }
+ }
+
+ @Test
+ public void testGradientAtNonInitialPoint() throws IOException {
+ // given
+ RealValueFileEventStream rvfes1 = new RealValueFileEventStream("src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt");
+ DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
+ LogLikelihoodFunction objectFunction = new LogLikelihoodFunction(testDataIndexer);
+ // when
+ double[] nonInitialPoint = new double[] { 2, 3, 2, 3, 3, 3, 2, 3, 2, 2 };
+ double[] gradientAtInitialPoint = objectFunction.gradientAt(nonInitialPoint);
+ double[] expectedGradient =
+ new double[] { 6.19368E-09, -3.04514E-16, 7.48224E-09, -7.15239E-09, 4.14274E-09,
+ -6.19368E-09, 0.0, -7.48225E-09, 7.15239E-09, -4.14274E-09};
+ // then
+ assertTrue(expectedGradient.length == gradientAtInitialPoint.length);
+ for (int i = 0; i < expectedGradient.length; i++) {
+ assertEquals(expectedGradient[i], gradientAtInitialPoint[i], TOLERANCE01);
+ }
+ }
+}
Propchange: opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/LogLikelihoodFunctionTest.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QNTrainerTest.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QNTrainerTest.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QNTrainerTest.java (added)
+++ opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QNTrainerTest.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+import static opennlp.PrepAttachDataUtil.createTrainingStream;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.List;
+
+import opennlp.model.AbstractModel;
+import opennlp.model.DataIndexer;
+import opennlp.model.Event;
+import opennlp.model.GenericModelReader;
+import opennlp.model.GenericModelWriter;
+import opennlp.model.MaxentModel;
+import opennlp.model.OnePassRealValueDataIndexer;
+import opennlp.model.RealValueFileEventStream;
+import opennlp.model.TwoPassDataIndexer;
+import opennlp.perceptron.PerceptronPrepAttachTest;
+
+import org.junit.Test;
+
+ public class QNTrainerTest {
+
+  /** Real-valued training data used by the small-model tests. */
+  private static final String TRAINING_DATA =
+      "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt";
+
+  /**
+   * Creates a fresh data indexer over {@link #TRAINING_DATA}.
+   * NOTE(review): the second constructor argument (1) is presumably the feature cutoff.
+   */
+  private static DataIndexer createRealValueDataIndexer() throws IOException {
+    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(TRAINING_DATA);
+    return new OnePassRealValueDataIndexer(rvfes1, 1);
+  }
+
+  @Test
+  public void testTrainModelReturnsAQNModel() throws Exception {
+    // given
+    DataIndexer testDataIndexer = createRealValueDataIndexer();
+    // when
+    QNModel trainedModel = new QNTrainer(false).trainModel(testDataIndexer);
+    // then
+    assertNotNull(trainedModel);
+  }
+
+  @Test
+  public void testInTinyDevSet() throws Exception {
+    // given
+    DataIndexer testDataIndexer = createRealValueDataIndexer();
+    // when
+    QNModel trainedModel = new QNTrainer(15, true).trainModel(testDataIndexer);
+    String[] features2Classify = new String[] {"feature2","feature3", "feature3", "feature3","feature3", "feature3", "feature3","feature3", "feature3", "feature3","feature3", "feature3"};
+    double[] eval = trainedModel.eval(features2Classify);
+    // then
+    assertNotNull(eval);
+  }
+
+  @Test
+  public void testInBigDevSet() throws IOException {
+    QNModel trainedModel = new QNTrainer(10, 1000, true).trainModel(new TwoPassDataIndexer(createTrainingStream()));
+    // then
+    testModel(trainedModel);
+  }
+
+  @Test
+  public void testModel() throws IOException {
+    // given
+    DataIndexer testDataIndexer = createRealValueDataIndexer();
+    // when
+    QNModel trainedModel = new QNTrainer(15, true).trainModel(testDataIndexer);
+    // then
+    assertTrue(trainedModel.equals(trainedModel));
+    assertFalse(trainedModel.equals(null));
+  }
+
+  @Test
+  public void testSerdeModel() throws IOException {
+    // given
+    DataIndexer testDataIndexer = createRealValueDataIndexer();
+    // when
+    QNModel trainedModel = new QNTrainer(5, 700, true).trainModel(testDataIndexer);
+
+    // Write to a temp file instead of the working directory so the test leaves no
+    // artifact behind. The ".bin" suffix of the original file name is kept because the
+    // generic model writer/reader presumably pick the on-disk format from the
+    // extension -- TODO confirm.
+    File modelFile = File.createTempFile("qn-test-model", ".bin");
+    try {
+      GenericModelWriter modelWriter = new GenericModelWriter(trainedModel, modelFile);
+      modelWriter.persist();
+
+      GenericModelReader modelReader = new GenericModelReader(modelFile);
+      AbstractModel readModel = modelReader.getModel();
+      QNModel deserModel = (QNModel) readModel;
+
+      // then
+      assertTrue(trainedModel.equals(deserModel));
+
+      String[] features2Classify = new String[] {"feature2","feature3", "feature3", "feature3","feature3", "feature3", "feature3","feature3", "feature3", "feature3","feature3", "feature3"};
+      double[] eval01 = trainedModel.eval(features2Classify);
+      double[] eval02 = deserModel.eval(features2Classify);
+
+      assertEquals(eval01.length, eval02.length);
+      for (int i = 0; i < eval01.length; i++) {
+        assertEquals(eval01[i], eval02[i], 0.00000001);
+      }
+    } finally {
+      // Best-effort cleanup; fall back to deletion at JVM exit if the file is still held open.
+      if (!modelFile.delete()) {
+        modelFile.deleteOnExit();
+      }
+    }
+  }
+
+  /**
+   * Evaluates {@code model} on the PPA dev set and prints its accuracy to stdout.
+   * Makes no assertion about the accuracy value.
+   *
+   * @param model the model to evaluate
+   * @throws IOException if the dev set cannot be read
+   */
+  public static void testModel(MaxentModel model) throws IOException {
+    List<Event> devEvents = readPpaFile("devset");
+
+    int total = 0;
+    int correct = 0;
+    for (Event ev: devEvents) {
+      String targetLabel = ev.getOutcome();
+      double[] ocs = model.eval(ev.getContext());
+
+      int best = 0;
+      for (int i = 1; i < ocs.length; i++) {
+        if (ocs[i] > ocs[best]) {
+          best = i;
+        }
+      }
+      String predictedLabel = model.getOutcome(best);
+
+      if (targetLabel.equals(predictedLabel)) {
+        correct++;
+      }
+      total++;
+    }
+
+    double accuracy = correct/(double)total;
+    System.out.println("Accuracy on PPA devset: (" + correct + "/" + total + ") " + accuracy);
+  }
+
+  /**
+   * Reads a whitespace-separated PPA file from the classpath under {@code /data/ppa/}.
+   * Each line yields one event whose outcome is column 5 and whose context is built
+   * from columns 1-4.
+   *
+   * @param filename the resource name below {@code /data/ppa/}
+   * @return the events, in file order
+   * @throws IOException if the resource is missing or cannot be read
+   */
+  private static List<Event> readPpaFile(String filename) throws IOException {
+
+    List<Event> events = new ArrayList<Event>();
+
+    String resource = "/data/ppa/" + filename;
+    InputStream in = PerceptronPrepAttachTest.class.getResourceAsStream(resource);
+    if (in == null) {
+      // Fail with a clear message instead of a NullPointerException from the reader.
+      throw new IOException("Test resource not found on classpath: " + resource);
+    }
+
+    try {
+      BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
+      String line;
+      while ((line = reader.readLine()) != null) {
+        String[] items = line.split("\\s+");
+        String label = items[5];
+        String[] context = { "verb=" + items[1], "noun=" + items[2],
+            "prep=" + items[3], "prep_obj=" + items[4] };
+        events.add(new Event(label, context));
+      }
+    } finally {
+      in.close();
+    }
+    return events;
+  }
+}
Propchange: opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QNTrainerTest.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added: opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QuadraticFunction.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QuadraticFunction.java?rev=1392944&view=auto
==============================================================================
--- opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QuadraticFunction.java (added)
+++ opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QuadraticFunction.java Tue Oct 2 14:52:34 2012
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package opennlp.maxent.quasinewton;
+
+/**
+ * sample function for unit tests of LineSearch
+ */
+public class QuadraticFunction implements DifferentiableFunction {
+
+ public double valueAt(double[] x) {
+ // -(x-2)^2 + 4;
+ return (Math.pow(x[0] - 2.0, 2.0) * -1.0) + 4.0;
+ }
+
+ public double[] gradientAt(double[] x) {
+ return new double[] {(-2.0 * x[0]) + 4.0};
+ }
+
+ public int getDomainDimension() {
+ return 1;
+ }
+}
Propchange: opennlp/trunk/opennlp-maxent/src/test/java/opennlp/maxent/quasinewton/QuadraticFunction.java
------------------------------------------------------------------------------
svn:mime-type = text/plain