You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@opennlp.apache.org by ko...@apache.org on 2018/05/15 09:01:39 UTC

[opennlp] branch master updated: OPENNLP-1196: move ArrayMath to a more general package and add its tests (#314)

This is an automated email from the ASF dual-hosted git repository.

koji pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/opennlp.git


The following commit(s) were added to refs/heads/master by this push:
     new de079d8  OPENNLP-1196: move ArrayMath to a more general package and add its tests (#314)
de079d8 is described below

commit de079d82172220c0b3f2a5e1f1cedf73f4c1ba5a
Author: Koji Sekiguchi <ko...@rondhuit.com>
AuthorDate: Tue May 15 18:01:37 2018 +0900

    OPENNLP-1196: move ArrayMath to a more general package and add its tests (#314)
---
 .../ml/{maxent/quasinewton => }/ArrayMath.java     |   6 +-
 .../tools/ml/maxent/quasinewton/ArrayMath.java     |  97 +----------------
 .../tools/ml/maxent/quasinewton/LineSearch.java    |   2 +
 .../ml/maxent/quasinewton/NegLogLikelihood.java    |   1 +
 .../quasinewton/ParallelNegLogLikelihood.java      |   1 +
 .../tools/ml/maxent/quasinewton/QNMinimizer.java   |   1 +
 .../tools/ml/maxent/quasinewton/QNModel.java       |   1 +
 .../tools/ml/maxent/quasinewton/QNTrainer.java     |   3 +-
 .../test/java/opennlp/tools/ml/ArrayMathTest.java  | 115 +++++++++++++++++++++
 9 files changed, 129 insertions(+), 98 deletions(-)

diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java b/opennlp-tools/src/main/java/opennlp/tools/ml/ArrayMath.java
similarity index 96%
copy from opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
copy to opennlp-tools/src/main/java/opennlp/tools/ml/ArrayMath.java
index f8c2a31..3d40c69 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/ArrayMath.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package opennlp.tools.ml.maxent.quasinewton;
+package opennlp.tools.ml;
 
 import java.util.List;
 
@@ -77,7 +77,7 @@ public class ArrayMath {
   }
 
   public static double max(double[] x) {
-    int maxIdx = maxIdx(x);
+    int maxIdx = argmax(x);
     return x[maxIdx];
   }
 
@@ -87,7 +87,7 @@ public class ArrayMath {
    * @return index of the maximum element. Index of the first
    *     maximum element is returned if multiple maximums are found.
    */
-  public static int maxIdx(double[] x) {
+  public static int argmax(double[] x) {
     if (x == null || x.length == 0) {
       throw new IllegalArgumentException("Vector x is null or empty");
     }
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
index f8c2a31..ad8d9fc 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ArrayMath.java
@@ -17,69 +17,11 @@
 
 package opennlp.tools.ml.maxent.quasinewton;
 
-import java.util.List;
-
 /**
  * Utility class for simple vector arithmetic.
  */
-public class ArrayMath {
-
-  public static double innerProduct(double[] vecA, double[] vecB) {
-    if (vecA == null || vecB == null || vecA.length != vecB.length)
-      return Double.NaN;
-
-    double product = 0.0;
-    for (int i = 0; i < vecA.length; i++) {
-      product += vecA[i] * vecB[i];
-    }
-    return product;
-  }
-
-  /**
-   * L1-norm
-   */
-  public static double l1norm(double[] v) {
-    double norm = 0;
-    for (int i = 0; i < v.length; i++)
-      norm += Math.abs(v[i]);
-    return norm;
-  }
-
-  /**
-   * L2-norm
-   */
-  public static double l2norm(double[] v) {
-    return Math.sqrt(innerProduct(v, v));
-  }
-
-  /**
-   * Inverse L2-norm
-   */
-  public static double invL2norm(double[] v) {
-    return 1 / l2norm(v);
-  }
-
-  /**
-   * Computes \log(\sum_{i=1}^n e^{x_i}) using a maximum-element trick
-   * to avoid arithmetic overflow.
-   *
-   * @param x input vector
-   * @return log-sum of exponentials of vector elements
-   */
-  public static double logSumOfExps(double[] x) {
-    double max = max(x);
-    double sum = 0.0;
-    for (int i = 0; i < x.length; i++) {
-      if (x[i] != Double.NEGATIVE_INFINITY)
-        sum += Math.exp(x[i] - max);
-    }
-    return max + Math.log(sum);
-  }
-
-  public static double max(double[] x) {
-    int maxIdx = maxIdx(x);
-    return x[maxIdx];
-  }
+@Deprecated
+public class ArrayMath extends opennlp.tools.ml.ArrayMath {
 
   /**
    * Find index of maximum element in the vector x
@@ -88,39 +30,6 @@ public class ArrayMath {
    *     maximum element is returned if multiple maximums are found.
    */
   public static int maxIdx(double[] x) {
-    if (x == null || x.length == 0) {
-      throw new IllegalArgumentException("Vector x is null or empty");
-    }
-
-    int maxIdx = 0;
-    for (int i = 1; i < x.length; i++) {
-      if (x[maxIdx] < x[i])
-        maxIdx = i;
-    }
-    return maxIdx;
-  }
-
-  // === Not really related to math ===
-  /**
-   * Convert a list of Double objects into an array of primitive doubles
-   */
-  public static double[] toDoubleArray(List<Double> list) {
-    double[] arr = new double[list.size()];
-    for (int i = 0; i < arr.length; i++) {
-      arr[i] = list.get(i);
-    }
-    return arr;
-  }
-
-  /**
-   *  Convert a list of Integer objects into an array of primitive integers
-   */
-  public static int[] toIntArray(List<Integer> list) {
-    int[] arr = new int[list.size()];
-    for (int i = 0; i < arr.length; i++) {
-      arr[i] = list.get(i);
-    }
-    return arr;
+    return opennlp.tools.ml.ArrayMath.argmax(x);
   }
 }
-
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/LineSearch.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/LineSearch.java
index 0798708..17ac500 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/LineSearch.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/LineSearch.java
@@ -17,6 +17,8 @@
 
 package opennlp.tools.ml.maxent.quasinewton;
 
+import opennlp.tools.ml.ArrayMath;
+
 /**
  * Class that performs line search to find minimum
  */
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/NegLogLikelihood.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/NegLogLikelihood.java
index 7e5a243..0505e38 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/NegLogLikelihood.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/NegLogLikelihood.java
@@ -19,6 +19,7 @@ package opennlp.tools.ml.maxent.quasinewton;
 
 import java.util.Arrays;
 
+import opennlp.tools.ml.ArrayMath;
 import opennlp.tools.ml.model.DataIndexer;
 import opennlp.tools.ml.model.OnePassRealValueDataIndexer;
 
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ParallelNegLogLikelihood.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ParallelNegLogLikelihood.java
index d3e2d8c..36cacb3 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ParallelNegLogLikelihood.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/ParallelNegLogLikelihood.java
@@ -26,6 +26,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 
+import opennlp.tools.ml.ArrayMath;
 import opennlp.tools.ml.model.DataIndexer;
 
 /**
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNMinimizer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNMinimizer.java
index b279acf..adc7f5b 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNMinimizer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNMinimizer.java
@@ -17,6 +17,7 @@
 
 package opennlp.tools.ml.maxent.quasinewton;
 
+import opennlp.tools.ml.ArrayMath;
 import opennlp.tools.ml.maxent.quasinewton.LineSearch.LineSearchResult;
 
 /**
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNModel.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNModel.java
index 85f6dc3..030cb76 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNModel.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNModel.java
@@ -17,6 +17,7 @@
 
 package opennlp.tools.ml.maxent.quasinewton;
 
+import opennlp.tools.ml.ArrayMath;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.Context;
 
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java
index c1174d1..8faa193 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java
@@ -23,6 +23,7 @@ import java.util.List;
 import java.util.Map;
 
 import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.ArrayMath;
 import opennlp.tools.ml.maxent.quasinewton.QNMinimizer.Evaluator;
 import opennlp.tools.ml.model.AbstractModel;
 import opennlp.tools.ml.model.Context;
@@ -247,7 +248,7 @@ public class QNTrainer extends AbstractEventTrainer {
 
         double[] probs = new double[nOutcomes];
         QNModel.eval(context, value, probs, nOutcomes, nPredLabels, parameters);
-        int outcome = ArrayMath.maxIdx(probs);
+        int outcome = ArrayMath.argmax(probs);
         if (outcome == outcomeList[ei]) {
           nCorrect += nEventsSeen[ei];
         }
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/ArrayMathTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/ArrayMathTest.java
new file mode 100644
index 0000000..6b50aa5
--- /dev/null
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/ArrayMathTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ArrayMathTest {
+
+  @Test
+  public void testInnerProductDoubleNaN() throws Exception {
+    Assert.assertTrue(Double.isNaN(ArrayMath.innerProduct(null, new double[]{0})));
+    Assert.assertTrue(Double.isNaN(ArrayMath.innerProduct(new double[]{0}, null)));
+    Assert.assertTrue(Double.isNaN(ArrayMath.innerProduct(new double[]{0, 1, 2}, new double[]{0, 1, 2, 3})));
+  }
+
+  @Test
+  public void testInnerProduct() throws Exception {
+    Assert.assertEquals(0, ArrayMath.innerProduct(new double[] {}, new double[] {}), 0);
+    Assert.assertEquals(-1, ArrayMath.innerProduct(new double[] {1}, new double[] {-1}), 0);
+    Assert.assertEquals(14, ArrayMath.innerProduct(new double[] {1, 2, 3}, new double[] {1, 2, 3}), 0);
+  }
+
+  @Test
+  public void testL1Norm() throws Exception {
+    Assert.assertEquals(0, ArrayMath.l1norm(new double[]{}), 0);
+    Assert.assertEquals(0, ArrayMath.l1norm(new double[] {0}), 0);
+    Assert.assertEquals(2, ArrayMath.l1norm(new double[] {1, -1}), 0);
+    Assert.assertEquals(55, ArrayMath.l1norm(new double[] {1, -2, 3, -4, 5, -6, 7, -8, 9, -10}), 0);
+  }
+
+  @Test
+  public void testL2Norm() throws Exception {
+    Assert.assertEquals(0, ArrayMath.l2norm(new double[] {}), 0);
+    Assert.assertEquals(0, ArrayMath.l2norm(new double[] {0}), 0);
+    Assert.assertEquals(1.41421, ArrayMath.l2norm(new double[] {1, -1}), 0.001);
+    Assert.assertEquals(0.54772, ArrayMath.l2norm(new double[] {0.1, -0.2, 0.3, -0.4}), 0.001);
+  }
+
+  @Test
+  public void testInvL2Norm() throws Exception {
+    Assert.assertEquals(0.70711, ArrayMath.invL2norm(new double[] {1, -1}), 0.001);
+    Assert.assertEquals(1.82575, ArrayMath.invL2norm(new double[] {0.1, -0.2, 0.3, -0.4}), 0.001);
+  }
+
+  @Test
+  public void testLogSumOfExps() throws Exception {
+    Assert.assertEquals(0, ArrayMath.logSumOfExps(new double[] {0}), 0);
+    Assert.assertEquals(1, ArrayMath.logSumOfExps(new double[] {1}), 0);
+    Assert.assertEquals(2.048587, ArrayMath.logSumOfExps(new double[] {-1, 2}), 0.001);
+    Assert.assertEquals(1.472216, ArrayMath.logSumOfExps(new double[] {-0.1, 0.2, -0.3, 0.4}), 0.001);
+  }
+
+  @Test
+  public void testMax() throws Exception {
+    Assert.assertEquals(0, ArrayMath.max(new double[] {0}), 0);
+    Assert.assertEquals(0, ArrayMath.max(new double[] {0, 0, 0}), 0);
+    Assert.assertEquals(2, ArrayMath.max(new double[] {0, 1, 2}), 0);
+    Assert.assertEquals(200, ArrayMath.max(new double[] {100, 200, 2}), 0);
+    Assert.assertEquals(300, ArrayMath.max(new double[] {100, 200, 300, -10, -20}), 0);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testArgmaxException1() throws Exception {
+    ArrayMath.argmax(null);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testArgmaxException2() throws Exception {
+    ArrayMath.argmax(new double[]{});
+  }
+
+  @Test
+  public void testArgmax() throws Exception {
+    Assert.assertEquals(0, ArrayMath.argmax(new double[] {0}));
+    Assert.assertEquals(0, ArrayMath.argmax(new double[] {0, 0, 0}));
+    Assert.assertEquals(2, ArrayMath.argmax(new double[] {0, 1, 2}));
+    Assert.assertEquals(1, ArrayMath.argmax(new double[] {100, 200, 2}));
+    Assert.assertEquals(2, ArrayMath.argmax(new double[] {100, 200, 300, -10, -20}));
+  }
+
+  @Test
+  public void testToDoubleArray() throws Exception {
+    Assert.assertEquals(0, ArrayMath.toDoubleArray(Collections.EMPTY_LIST).length);
+    Assert.assertArrayEquals(new double[] {0}, ArrayMath.toDoubleArray(Arrays.asList(0D)), 0);
+    Assert.assertArrayEquals(new double[] {0, 1, -2.5, -0.3, 4},
+        ArrayMath.toDoubleArray(Arrays.asList(0D, 1D, -2.5D, -0.3D, 4D)), 0);
+  }
+
+  @Test
+  public void testToIntArray() throws Exception {
+    Assert.assertEquals(0, ArrayMath.toIntArray(Collections.EMPTY_LIST).length);
+    Assert.assertArrayEquals(new int[] {0}, ArrayMath.toIntArray(Arrays.asList(0)));
+    Assert.assertArrayEquals(new int[] {0, 1, -2, -3, 4},
+        ArrayMath.toIntArray(Arrays.asList(0, 1, -2, -3, 4)));
+  }
+}

-- 
To stop receiving notification emails like this one, please contact
koji@apache.org.