You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by ko...@apache.org on 2018/05/15 09:01:39 UTC
[opennlp] branch master updated: OPENNLP-1196: move ArrayMath to a
more general package and add its tests (#314)
This is an automated email from the ASF dual-hosted git repository.
koji pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/master by this push:
new de079d8 OPENNLP-1196: move ArrayMath to a more general package and add its tests (#314)
de079d8 is described below
commit de079d82172220c0b3f2a5e1f1cedf73f4c1ba5a
Author: Koji Sekiguchi <ko...@rondhuit.com>
AuthorDate: Tue May 15 18:01:37 2018 +0900
OPENNLP-1196: move ArrayMath to a more general package and add its tests (#314)
---
.../ml/{maxent/quasinewton => }/ArrayMath.java | 6 +-
.../tools/ml/maxent/quasinewton/ArrayMath.java | 97 +----------------
.../tools/ml/maxent/quasinewton/LineSearch.java | 2 +
.../ml/maxent/quasinewton/NegLogLikelihood.java | 1 +
.../quasinewton/ParallelNegLogLikelihood.java | 1 +
.../tools/ml/maxent/quasinewton/QNMinimizer.java | 1 +
.../tools/ml/maxent/quasinewton/QNModel.java | 1 +
.../tools/ml/maxent/quasinewton/QNTrainer.java | 3 +-
.../test/java/opennlp/tools/ml/ArrayMathTest.java | 115 +++++++++++++++++++++
9 files changed, 129 insertions(+), 98 deletions(-)
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java b/opennlp-tools/src/main/java/opennlp/tools/ml/ArrayMath.java
similarity index 96%
copy from opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
copy to opennlp-tools/src/main/java/opennlp/tools/ml/ArrayMath.java
index f8c2a31..3d40c69 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/ArrayMath.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package opennlp.tools.ml.maxent.quasinewton;
+package opennlp.tools.ml;
import java.util.List;
@@ -77,7 +77,7 @@ public class ArrayMath {
}
public static double max(double[] x) {
- int maxIdx = maxIdx(x);
+ int maxIdx = argmax(x);
return x[maxIdx];
}
@@ -87,7 +87,7 @@ public class ArrayMath {
* @return index of the maximum element. Index of the first
* maximum element is returned if multiple maximums are found.
*/
- public static int maxIdx(double[] x) {
+ public static int argmax(double[] x) {
if (x == null || x.length == 0) {
throw new IllegalArgumentException("Vector x is null or empty");
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
index f8c2a31..ad8d9fc 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
@@ -17,69 +17,11 @@
package opennlp.tools.ml.maxent.quasinewton;
-import java.util.List;
-
/**
* Utility class for simple vector arithmetic.
*/
-public class ArrayMath {
-
- public static double innerProduct(double[] vecA, double[] vecB) {
- if (vecA == null || vecB == null || vecA.length != vecB.length)
- return Double.NaN;
-
- double product = 0.0;
- for (int i = 0; i < vecA.length; i++) {
- product += vecA[i] * vecB[i];
- }
- return product;
- }
-
- /**
- * L1-norm
- */
- public static double l1norm(double[] v) {
- double norm = 0;
- for (int i = 0; i < v.length; i++)
- norm += Math.abs(v[i]);
- return norm;
- }
-
- /**
- * L2-norm
- */
- public static double l2norm(double[] v) {
- return Math.sqrt(innerProduct(v, v));
- }
-
- /**
- * Inverse L2-norm
- */
- public static double invL2norm(double[] v) {
- return 1 / l2norm(v);
- }
-
- /**
- * Computes \log(\sum_{i=1}^n e^{x_i}) using a maximum-element trick
- * to avoid arithmetic overflow.
- *
- * @param x input vector
- * @return log-sum of exponentials of vector elements
- */
- public static double logSumOfExps(double[] x) {
- double max = max(x);
- double sum = 0.0;
- for (int i = 0; i < x.length; i++) {
- if (x[i] != Double.NEGATIVE_INFINITY)
- sum += Math.exp(x[i] - max);
- }
- return max + Math.log(sum);
- }
-
- public static double max(double[] x) {
- int maxIdx = maxIdx(x);
- return x[maxIdx];
- }
+@Deprecated
+public class ArrayMath extends opennlp.tools.ml.ArrayMath {
/**
* Find index of maximum element in the vector x
@@ -88,39 +30,6 @@ public class ArrayMath {
* maximum element is returned if multiple maximums are found.
*/
public static int maxIdx(double[] x) {
- if (x == null || x.length == 0) {
- throw new IllegalArgumentException("Vector x is null or empty");
- }
-
- int maxIdx = 0;
- for (int i = 1; i < x.length; i++) {
- if (x[maxIdx] < x[i])
- maxIdx = i;
- }
- return maxIdx;
- }
-
- // === Not really related to math ===
- /**
- * Convert a list of Double objects into an array of primitive doubles
- */
- public static double[] toDoubleArray(List<Double> list) {
- double[] arr = new double[list.size()];
- for (int i = 0; i < arr.length; i++) {
- arr[i] = list.get(i);
- }
- return arr;
- }
-
- /**
- * Convert a list of Integer objects into an array of primitive integers
- */
- public static int[] toIntArray(List<Integer> list) {
- int[] arr = new int[list.size()];
- for (int i = 0; i < arr.length; i++) {
- arr[i] = list.get(i);
- }
- return arr;
+ return opennlp.tools.ml.ArrayMath.argmax(x);
}
}
-
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/LineSearch.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/LineSearch.java
index 0798708..17ac500 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/LineSearch.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/LineSearch.java
@@ -17,6 +17,8 @@
package opennlp.tools.ml.maxent.quasinewton;
+import opennlp.tools.ml.ArrayMath;
+
/**
* Class that performs line search to find minimum
*/
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/NegLogLikelihood.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/NegLogLikelihood.java
index 7e5a243..0505e38 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/NegLogLikelihood.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/NegLogLikelihood.java
@@ -19,6 +19,7 @@ package opennlp.tools.ml.maxent.quasinewton;
import java.util.Arrays;
+import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.OnePassRealValueDataIndexer;
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ParallelNegLogLikelihood.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ParallelNegLogLikelihood.java
index d3e2d8c..36cacb3 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ParallelNegLogLikelihood.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ParallelNegLogLikelihood.java
@@ -26,6 +26,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
+import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.DataIndexer;
/**
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNMinimizer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNMinimizer.java
index b279acf..adc7f5b 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNMinimizer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNMinimizer.java
@@ -17,6 +17,7 @@
package opennlp.tools.ml.maxent.quasinewton;
+import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.maxent.quasinewton.LineSearch.LineSearchResult;
/**
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNModel.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNModel.java
index 85f6dc3..030cb76 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNModel.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNModel.java
@@ -17,6 +17,7 @@
package opennlp.tools.ml.maxent.quasinewton;
+import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java
index c1174d1..8faa193 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java
@@ -23,6 +23,7 @@ import java.util.List;
import java.util.Map;
import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.maxent.quasinewton.QNMinimizer.Evaluator;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
@@ -247,7 +248,7 @@ public class QNTrainer extends AbstractEventTrainer {
double[] probs = new double[nOutcomes];
QNModel.eval(context, value, probs, nOutcomes, nPredLabels, parameters);
- int outcome = ArrayMath.maxIdx(probs);
+ int outcome = ArrayMath.argmax(probs);
if (outcome == outcomeList[ei]) {
nCorrect += nEventsSeen[ei];
}
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/ArrayMathTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/ArrayMathTest.java
new file mode 100644
index 0000000..6b50aa5
--- /dev/null
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/ArrayMathTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Unit tests for {@link ArrayMath}. */
+public class ArrayMathTest {
+  @Test
+  public void testInnerProductDoubleNaN() {
+    Assert.assertTrue(Double.isNaN(ArrayMath.innerProduct(null, new double[]{0})));
+    Assert.assertTrue(Double.isNaN(ArrayMath.innerProduct(new double[]{0}, null)));
+    Assert.assertTrue(Double.isNaN(ArrayMath.innerProduct(new double[]{0, 1, 2}, new double[]{0, 1, 2, 3})));
+  }
+
+  @Test
+  public void testInnerProduct() {
+    Assert.assertEquals(0, ArrayMath.innerProduct(new double[] {}, new double[] {}), 0);
+    Assert.assertEquals(-1, ArrayMath.innerProduct(new double[] {1}, new double[] {-1}), 0);
+    Assert.assertEquals(14, ArrayMath.innerProduct(new double[] {1, 2, 3}, new double[] {1, 2, 3}), 0);
+  }
+
+  @Test
+  public void testL1Norm() {
+    Assert.assertEquals(0, ArrayMath.l1norm(new double[]{}), 0);
+    Assert.assertEquals(0, ArrayMath.l1norm(new double[] {0}), 0);
+    Assert.assertEquals(2, ArrayMath.l1norm(new double[] {1, -1}), 0);
+    Assert.assertEquals(55, ArrayMath.l1norm(new double[] {1, -2, 3, -4, 5, -6, 7, -8, 9, -10}), 0);
+  }
+
+  @Test
+  public void testL2Norm() {
+    Assert.assertEquals(0, ArrayMath.l2norm(new double[] {}), 0);
+    Assert.assertEquals(0, ArrayMath.l2norm(new double[] {0}), 0);
+    Assert.assertEquals(1.41421, ArrayMath.l2norm(new double[] {1, -1}), 0.001);
+    Assert.assertEquals(0.54772, ArrayMath.l2norm(new double[] {0.1, -0.2, 0.3, -0.4}), 0.001);
+  }
+
+  @Test
+  public void testInvL2Norm() {
+    Assert.assertEquals(0.70711, ArrayMath.invL2norm(new double[] {1, -1}), 0.001);
+    Assert.assertEquals(1.82575, ArrayMath.invL2norm(new double[] {0.1, -0.2, 0.3, -0.4}), 0.001);
+  }
+
+  @Test
+  public void testLogSumOfExps() {
+    Assert.assertEquals(0, ArrayMath.logSumOfExps(new double[] {0}), 0);
+    Assert.assertEquals(1, ArrayMath.logSumOfExps(new double[] {1}), 0);
+    Assert.assertEquals(2.048587, ArrayMath.logSumOfExps(new double[] {-1, 2}), 0.001);
+    Assert.assertEquals(1.472216, ArrayMath.logSumOfExps(new double[] {-0.1, 0.2, -0.3, 0.4}), 0.001);
+  }
+
+  @Test
+  public void testMax() {
+    Assert.assertEquals(0, ArrayMath.max(new double[] {0}), 0);
+    Assert.assertEquals(0, ArrayMath.max(new double[] {0, 0, 0}), 0);
+    Assert.assertEquals(2, ArrayMath.max(new double[] {0, 1, 2}), 0);
+    Assert.assertEquals(200, ArrayMath.max(new double[] {100, 200, 2}), 0);
+    Assert.assertEquals(300, ArrayMath.max(new double[] {100, 200, 300, -10, -20}), 0);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testArgmaxException1() {
+    ArrayMath.argmax(null);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testArgmaxException2() {
+    ArrayMath.argmax(new double[]{});
+  }
+
+  @Test
+  public void testArgmax() {
+    Assert.assertEquals(0, ArrayMath.argmax(new double[] {0}));
+    Assert.assertEquals(0, ArrayMath.argmax(new double[] {0, 0, 0}));
+    Assert.assertEquals(2, ArrayMath.argmax(new double[] {0, 1, 2}));
+    Assert.assertEquals(1, ArrayMath.argmax(new double[] {100, 200, 2}));
+    Assert.assertEquals(2, ArrayMath.argmax(new double[] {100, 200, 300, -10, -20}));
+  }
+
+  @Test
+  public void testToDoubleArray() {
+    Assert.assertEquals(0, ArrayMath.toDoubleArray(Collections.emptyList()).length);
+    Assert.assertArrayEquals(new double[] {0}, ArrayMath.toDoubleArray(Arrays.asList(0D)), 0);
+    Assert.assertArrayEquals(new double[] {0, 1, -2.5, -0.3, 4},
+        ArrayMath.toDoubleArray(Arrays.asList(0D, 1D, -2.5D, -0.3D, 4D)), 0);
+  }
+
+  @Test
+  public void testToIntArray() {
+    Assert.assertEquals(0, ArrayMath.toIntArray(Collections.emptyList()).length);
+    Assert.assertArrayEquals(new int[] {0}, ArrayMath.toIntArray(Arrays.asList(0)));
+    Assert.assertArrayEquals(new int[] {0, 1, -2, -3, 4},
+        ArrayMath.toIntArray(Arrays.asList(0, 1, -2, -3, 4)));
+  }
+}
--
To stop receiving notification emails like this one, please contact
koji@apache.org.