You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@commons.apache.org by oe...@apache.org on 2015/08/20 18:01:37 UTC

[2/2] [math] MATH-1258: check for equal array lengths in distance functions

MATH-1258: check for equal array lengths in distance functions

Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/7934bfea
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/7934bfea
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/7934bfea

Branch: refs/heads/MATH_3_X
Commit: 7934bfea106206d2840ba062eef105001601588a
Parents: 9cb16d5
Author: Otmar Ertl <ot...@gmail.com>
Authored: Thu Aug 20 17:46:54 2015 +0200
Committer: Otmar Ertl <ot...@gmail.com>
Committed: Thu Aug 20 17:55:42 2015 +0200

----------------------------------------------------------------------
 .../math3/ml/distance/CanberraDistance.java     |  6 +-
 .../math3/ml/distance/ChebyshevDistance.java    |  4 +-
 .../math3/ml/distance/DistanceMeasure.java      |  5 +-
 .../math3/ml/distance/EarthMoversDistance.java  |  6 +-
 .../math3/ml/distance/EuclideanDistance.java    |  4 +-
 .../math3/ml/distance/ManhattanDistance.java    |  4 +-
 .../apache/commons/math3/util/MathArrays.java   | 80 ++++++++++++++++----
 7 files changed, 89 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
index d997352..d467c3b 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
@@ -16,7 +16,9 @@
  */
 package org.apache.commons.math3.ml.distance;
 
+import org.apache.commons.math3.exception.DimensionMismatchException;
 import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
 
 /**
  * Calculates the Canberra distance between two points.
@@ -29,7 +31,9 @@ public class CanberraDistance implements DistanceMeasure {
     private static final long serialVersionUID = -6972277381587032228L;
 
     /** {@inheritDoc} */
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
+        MathArrays.checkEqualLength(a, b);
         double sum = 0;
         for (int i = 0; i < a.length; i++) {
             final double num = FastMath.abs(a[i] - b[i]);

http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
index 9eecb15..05dccb5 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
@@ -16,6 +16,7 @@
  */
 package org.apache.commons.math3.ml.distance;
 
+import org.apache.commons.math3.exception.DimensionMismatchException;
 import org.apache.commons.math3.util.MathArrays;
 
 /**
@@ -29,7 +30,8 @@ public class ChebyshevDistance implements DistanceMeasure {
     private static final long serialVersionUID = -4694868171115238296L;
 
     /** {@inheritDoc} */
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
         return MathArrays.distanceInf(a, b);
     }
 

http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
index 98bfc89..ff9c27f 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
@@ -18,6 +18,8 @@ package org.apache.commons.math3.ml.distance;
 
 import java.io.Serializable;
 
+import org.apache.commons.math3.exception.DimensionMismatchException;
+
 /**
  * Interface for distance measures of n-dimensional vectors.
  *
@@ -33,6 +35,7 @@ public interface DistanceMeasure extends Serializable {
      * @param a the first vector
      * @param b the second vector
      * @return the distance between the two vectors
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    double compute(double[] a, double[] b);
+    double compute(double[] a, double[] b) throws DimensionMismatchException;
 }

http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
index 13f2654..2518624 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
@@ -16,7 +16,9 @@
  */
 package org.apache.commons.math3.ml.distance;
 
+import org.apache.commons.math3.exception.DimensionMismatchException;
 import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
 
 /**
  * Calculates the Earh Mover's distance (also known as Wasserstein metric) between two distributions.
@@ -31,7 +33,9 @@ public class EarthMoversDistance implements DistanceMeasure {
     private static final long serialVersionUID = -5406732779747414922L;
 
     /** {@inheritDoc} */
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
+        MathArrays.checkEqualLength(a, b);
         double lastDistance = 0;
         double totalDistance = 0;
         for (int i = 0; i < a.length; i++) {

http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
index 5d8029e..187badc 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
@@ -16,6 +16,7 @@
  */
 package org.apache.commons.math3.ml.distance;
 
+import org.apache.commons.math3.exception.DimensionMismatchException;
 import org.apache.commons.math3.util.MathArrays;
 
 /**
@@ -29,7 +30,8 @@ public class EuclideanDistance implements DistanceMeasure {
     private static final long serialVersionUID = 1717556319784040040L;
 
     /** {@inheritDoc} */
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
         return MathArrays.distance(a, b);
     }
 

http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
index 9e898c1..2eebe1b 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
@@ -16,6 +16,7 @@
  */
 package org.apache.commons.math3.ml.distance;
 
+import org.apache.commons.math3.exception.DimensionMismatchException;
 import org.apache.commons.math3.util.MathArrays;
 
 /**
@@ -29,7 +30,8 @@ public class ManhattanDistance implements DistanceMeasure {
     private static final long serialVersionUID = -9108154600539125566L;
 
     /** {@inheritDoc} */
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
         return MathArrays.distance1(a, b);
     }
 

http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/util/MathArrays.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/util/MathArrays.java b/src/main/java/org/apache/commons/math3/util/MathArrays.java
index 5bf7890..46a8716 100644
--- a/src/main/java/org/apache/commons/math3/util/MathArrays.java
+++ b/src/main/java/org/apache/commons/math3/util/MathArrays.java
@@ -194,8 +194,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>1</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static double distance1(double[] p1, double[] p2) {
+    public static double distance1(double[] p1, double[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
         double sum = 0;
         for (int i = 0; i < p1.length; i++) {
             sum += FastMath.abs(p1[i] - p2[i]);
@@ -209,13 +212,16 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>1</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static int distance1(int[] p1, int[] p2) {
-      int sum = 0;
-      for (int i = 0; i < p1.length; i++) {
-          sum += FastMath.abs(p1[i] - p2[i]);
-      }
-      return sum;
+    public static int distance1(int[] p1, int[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
+        int sum = 0;
+        for (int i = 0; i < p1.length; i++) {
+            sum += FastMath.abs(p1[i] - p2[i]);
+        }
+        return sum;
     }
 
     /**
@@ -224,8 +230,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>2</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static double distance(double[] p1, double[] p2) {
+    public static double distance(double[] p1, double[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
         double sum = 0;
         for (int i = 0; i < p1.length; i++) {
             final double dp = p1[i] - p2[i];
@@ -251,8 +260,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>2</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static double distance(int[] p1, int[] p2) {
+    public static double distance(int[] p1, int[] p2)
+    throws DimensionMismatchException {
+      checkEqualLength(p1, p2);
       double sum = 0;
       for (int i = 0; i < p1.length; i++) {
           final double dp = p1[i] - p2[i];
@@ -267,8 +279,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>&infin;</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static double distanceInf(double[] p1, double[] p2) {
+    public static double distanceInf(double[] p1, double[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
         double max = 0;
         for (int i = 0; i < p1.length; i++) {
             max = FastMath.max(max, FastMath.abs(p1[i] - p2[i]));
@@ -282,8 +297,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>&infin;</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static int distanceInf(int[] p1, int[] p2) {
+    public static int distanceInf(int[] p1, int[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
         int max = 0;
         for (int i = 0; i < p1.length; i++) {
             max = FastMath.max(max, FastMath.abs(p1[i] - p2[i]));
@@ -399,6 +417,42 @@ public class MathArrays {
         checkEqualLength(a, b, true);
     }
 
+
+    /**
+     * Check that both arrays have the same length.
+     *
+     * @param a Array.
+     * @param b Array.
+     * @param abort Whether to throw an exception if the check fails.
+     * @return {@code true} if the arrays have the same length.
+     * @throws DimensionMismatchException if the lengths differ and
+     * {@code abort} is {@code true}.
+     */
+    public static boolean checkEqualLength(int[] a,
+                                           int[] b,
+                                           boolean abort) {
+        if (a.length == b.length) {
+            return true;
+        } else {
+            if (abort) {
+                throw new DimensionMismatchException(a.length, b.length);
+            }
+            return false;
+        }
+    }
+
+    /**
+     * Check that both arrays have the same length.
+     *
+     * @param a Array.
+     * @param b Array.
+     * @throws DimensionMismatchException if the lengths differ.
+     */
+    public static void checkEqualLength(int[] a,
+                                        int[] b) {
+        checkEqualLength(a, b, true);
+    }
+    
     /**
      * Check that the given array is sorted.
      *
@@ -884,10 +938,8 @@ public class MathArrays {
      */
     public static double linearCombination(final double[] a, final double[] b)
         throws DimensionMismatchException {
+        checkEqualLength(a, b);
         final int len = a.length;
-        if (len != b.length) {
-            throw new DimensionMismatchException(len, b.length);
-        }
 
         if (len == 1) {
             // Revert to scalar multiplication.