You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by oe...@apache.org on 2015/08/20 18:01:37 UTC
[2/2] [math] MATH-1258: check for equal array lengths in distance
functions
MATH-1258: check for equal array lengths in distance functions
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/7934bfea
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/7934bfea
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/7934bfea
Branch: refs/heads/MATH_3_X
Commit: 7934bfea106206d2840ba062eef105001601588a
Parents: 9cb16d5
Author: Otmar Ertl <ot...@gmail.com>
Authored: Thu Aug 20 17:46:54 2015 +0200
Committer: Otmar Ertl <ot...@gmail.com>
Committed: Thu Aug 20 17:55:42 2015 +0200
----------------------------------------------------------------------
.../math3/ml/distance/CanberraDistance.java | 6 +-
.../math3/ml/distance/ChebyshevDistance.java | 4 +-
.../math3/ml/distance/DistanceMeasure.java | 5 +-
.../math3/ml/distance/EarthMoversDistance.java | 6 +-
.../math3/ml/distance/EuclideanDistance.java | 4 +-
.../math3/ml/distance/ManhattanDistance.java | 4 +-
.../apache/commons/math3/util/MathArrays.java | 80 ++++++++++++++++----
7 files changed, 89 insertions(+), 20 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
index d997352..d467c3b 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java
@@ -16,7 +16,9 @@
*/
package org.apache.commons.math3.ml.distance;
+import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
/**
* Calculates the Canberra distance between two points.
@@ -29,7 +31,9 @@ public class CanberraDistance implements DistanceMeasure {
private static final long serialVersionUID = -6972277381587032228L;
/** {@inheritDoc} */
- public double compute(double[] a, double[] b) {
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
+ MathArrays.checkEqualLength(a, b);
double sum = 0;
for (int i = 0; i < a.length; i++) {
final double num = FastMath.abs(a[i] - b[i]);
http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
index 9eecb15..05dccb5 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java
@@ -16,6 +16,7 @@
*/
package org.apache.commons.math3.ml.distance;
+import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.util.MathArrays;
/**
@@ -29,7 +30,8 @@ public class ChebyshevDistance implements DistanceMeasure {
private static final long serialVersionUID = -4694868171115238296L;
/** {@inheritDoc} */
- public double compute(double[] a, double[] b) {
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
return MathArrays.distanceInf(a, b);
}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
index 98bfc89..ff9c27f 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java
@@ -18,6 +18,8 @@ package org.apache.commons.math3.ml.distance;
import java.io.Serializable;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+
/**
* Interface for distance measures of n-dimensional vectors.
*
@@ -33,6 +35,7 @@ public interface DistanceMeasure extends Serializable {
* @param a the first vector
* @param b the second vector
* @return the distance between the two vectors
+ * @throws DimensionMismatchException if the array lengths differ.
*/
- double compute(double[] a, double[] b);
+ double compute(double[] a, double[] b) throws DimensionMismatchException;
}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
index 13f2654..2518624 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java
@@ -16,7 +16,9 @@
*/
package org.apache.commons.math3.ml.distance;
+import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
/**
* Calculates the Earth Mover's distance (also known as Wasserstein metric) between two distributions.
@@ -31,7 +33,9 @@ public class EarthMoversDistance implements DistanceMeasure {
private static final long serialVersionUID = -5406732779747414922L;
/** {@inheritDoc} */
- public double compute(double[] a, double[] b) {
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
+ MathArrays.checkEqualLength(a, b);
double lastDistance = 0;
double totalDistance = 0;
for (int i = 0; i < a.length; i++) {
http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
index 5d8029e..187badc 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java
@@ -16,6 +16,7 @@
*/
package org.apache.commons.math3.ml.distance;
+import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.util.MathArrays;
/**
@@ -29,7 +30,8 @@ public class EuclideanDistance implements DistanceMeasure {
private static final long serialVersionUID = 1717556319784040040L;
/** {@inheritDoc} */
- public double compute(double[] a, double[] b) {
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
return MathArrays.distance(a, b);
}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
index 9e898c1..2eebe1b 100644
--- a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
+++ b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java
@@ -16,6 +16,7 @@
*/
package org.apache.commons.math3.ml.distance;
+import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.util.MathArrays;
/**
@@ -29,7 +30,8 @@ public class ManhattanDistance implements DistanceMeasure {
private static final long serialVersionUID = -9108154600539125566L;
/** {@inheritDoc} */
- public double compute(double[] a, double[] b) {
+ public double compute(double[] a, double[] b)
+ throws DimensionMismatchException {
return MathArrays.distance1(a, b);
}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/util/MathArrays.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/util/MathArrays.java b/src/main/java/org/apache/commons/math3/util/MathArrays.java
index 5bf7890..46a8716 100644
--- a/src/main/java/org/apache/commons/math3/util/MathArrays.java
+++ b/src/main/java/org/apache/commons/math3/util/MathArrays.java
@@ -194,8 +194,11 @@ public class MathArrays {
* @param p1 the first point
* @param p2 the second point
* @return the L<sub>1</sub> distance between the two points
+ * @throws DimensionMismatchException if the array lengths differ.
*/
- public static double distance1(double[] p1, double[] p2) {
+ public static double distance1(double[] p1, double[] p2)
+ throws DimensionMismatchException {
+ checkEqualLength(p1, p2);
double sum = 0;
for (int i = 0; i < p1.length; i++) {
sum += FastMath.abs(p1[i] - p2[i]);
@@ -209,13 +212,16 @@ public class MathArrays {
* @param p1 the first point
* @param p2 the second point
* @return the L<sub>1</sub> distance between the two points
+ * @throws DimensionMismatchException if the array lengths differ.
*/
- public static int distance1(int[] p1, int[] p2) {
- int sum = 0;
- for (int i = 0; i < p1.length; i++) {
- sum += FastMath.abs(p1[i] - p2[i]);
- }
- return sum;
+ public static int distance1(int[] p1, int[] p2)
+ throws DimensionMismatchException {
+ checkEqualLength(p1, p2);
+ int sum = 0;
+ for (int i = 0; i < p1.length; i++) {
+ sum += FastMath.abs(p1[i] - p2[i]);
+ }
+ return sum;
}
/**
@@ -224,8 +230,11 @@ public class MathArrays {
* @param p1 the first point
* @param p2 the second point
* @return the L<sub>2</sub> distance between the two points
+ * @throws DimensionMismatchException if the array lengths differ.
*/
- public static double distance(double[] p1, double[] p2) {
+ public static double distance(double[] p1, double[] p2)
+ throws DimensionMismatchException {
+ checkEqualLength(p1, p2);
double sum = 0;
for (int i = 0; i < p1.length; i++) {
final double dp = p1[i] - p2[i];
@@ -251,8 +260,11 @@ public class MathArrays {
* @param p1 the first point
* @param p2 the second point
* @return the L<sub>2</sub> distance between the two points
+ * @throws DimensionMismatchException if the array lengths differ.
*/
- public static double distance(int[] p1, int[] p2) {
+ public static double distance(int[] p1, int[] p2)
+ throws DimensionMismatchException {
+ checkEqualLength(p1, p2);
double sum = 0;
for (int i = 0; i < p1.length; i++) {
final double dp = p1[i] - p2[i];
@@ -267,8 +279,11 @@ public class MathArrays {
* @param p1 the first point
* @param p2 the second point
* @return the L<sub>∞</sub> distance between the two points
+ * @throws DimensionMismatchException if the array lengths differ.
*/
- public static double distanceInf(double[] p1, double[] p2) {
+ public static double distanceInf(double[] p1, double[] p2)
+ throws DimensionMismatchException {
+ checkEqualLength(p1, p2);
double max = 0;
for (int i = 0; i < p1.length; i++) {
max = FastMath.max(max, FastMath.abs(p1[i] - p2[i]));
@@ -282,8 +297,11 @@ public class MathArrays {
* @param p1 the first point
* @param p2 the second point
* @return the L<sub>∞</sub> distance between the two points
+ * @throws DimensionMismatchException if the array lengths differ.
*/
- public static int distanceInf(int[] p1, int[] p2) {
+ public static int distanceInf(int[] p1, int[] p2)
+ throws DimensionMismatchException {
+ checkEqualLength(p1, p2);
int max = 0;
for (int i = 0; i < p1.length; i++) {
max = FastMath.max(max, FastMath.abs(p1[i] - p2[i]));
@@ -399,6 +417,42 @@ public class MathArrays {
checkEqualLength(a, b, true);
}
+
+ /**
+ * Check that both arrays have the same length.
+ *
+ * @param a Array.
+ * @param b Array.
+ * @param abort Whether to throw an exception if the check fails.
+ * @return {@code true} if the arrays have the same length.
+ * @throws DimensionMismatchException if the lengths differ and
+ * {@code abort} is {@code true}.
+ */
+ public static boolean checkEqualLength(int[] a,
+ int[] b,
+ boolean abort) {
+ if (a.length == b.length) {
+ return true;
+ } else {
+ if (abort) {
+ throw new DimensionMismatchException(a.length, b.length);
+ }
+ return false;
+ }
+ }
+
+ /**
+ * Check that both arrays have the same length.
+ *
+ * @param a Array.
+ * @param b Array.
+ * @throws DimensionMismatchException if the lengths differ.
+ */
+ public static void checkEqualLength(int[] a,
+ int[] b) {
+ checkEqualLength(a, b, true);
+ }
+
/**
* Check that the given array is sorted.
*
@@ -884,10 +938,8 @@ public class MathArrays {
*/
public static double linearCombination(final double[] a, final double[] b)
throws DimensionMismatchException {
+ checkEqualLength(a, b);
final int len = a.length;
- if (len != b.length) {
- throw new DimensionMismatchException(len, b.length);
- }
if (len == 1) {
// Revert to scalar multiplication.