You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hama.apache.org by yx...@apache.org on 2013/07/06 04:12:01 UTC
svn commit: r1500189 - in /hama/trunk: ./
ml/src/main/java/org/apache/hama/ml/distance/
ml/src/main/java/org/apache/hama/ml/kmeans/
ml/src/main/java/org/apache/hama/ml/math/
ml/src/main/java/org/apache/hama/ml/regression/
ml/src/main/java/org/apache/ha...
Author: yxjiang
Date: Sat Jul 6 02:12:01 2013
New Revision: 1500189
URL: http://svn.apache.org/r1500189
Log:
HAMA-773: Matrix/Vector operation does not validate the input argument.
Modified:
hama/trunk/CHANGES.txt
hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java
Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Sat Jul 6 02:12:01 2013
@@ -12,6 +12,7 @@ Release 0.6.3 (unreleased changes)
IMPROVEMENTS
HAMA-765: Add apply method to Vector/Matrix (Yexi Jiang)
+ HAMA-773: Matrix/Vector operation does not validate the input argument (Yexi Jiang)
Release 0.6.2 - June 26, 2013
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java Sat Jul 6 02:12:01 2013
@@ -50,7 +50,7 @@ public final class CosineDistance implem
double lengthSquaredv1 = vec1.pow(2).sum();
double lengthSquaredv2 = vec2.pow(2).sum();
- double dotProduct = vec2.dot(vec1);
+ double dotProduct = vec2.dotUnsafe(vec1);
double denominator = Math.sqrt(lengthSquaredv1)
* Math.sqrt(lengthSquaredv2);
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java Sat Jul 6 02:12:01 2013
@@ -36,7 +36,7 @@ public final class EuclidianDistance imp
@Override
public double measureDistance(DoubleVector vec1, DoubleVector vec2) {
- return Math.sqrt(vec2.subtract(vec1).pow(2).sum());
+ return Math.sqrt(vec2.subtractUnsafe(vec1).pow(2).sum());
}
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java Sat Jul 6 02:12:01 2013
@@ -162,7 +162,7 @@ public final class KMeansBSP
if (oldCenter == null) {
msgCenters[msg.getCenterIndex()] = newCenter;
} else {
- msgCenters[msg.getCenterIndex()] = oldCenter.add(newCenter);
+ msgCenters[msg.getCenterIndex()] = oldCenter.addUnsafe(newCenter);
}
}
// divide by how often we globally summed vectors
@@ -177,7 +177,7 @@ public final class KMeansBSP
for (int i = 0; i < msgCenters.length; i++) {
final DoubleVector oldCenter = centers[i];
if (msgCenters[i] != null) {
- double calculateError = oldCenter.subtract(msgCenters[i]).abs().sum();
+ double calculateError = oldCenter.subtractUnsafe(msgCenters[i]).abs().sum();
if (calculateError > 0.0d) {
centers[i] = msgCenters[i];
convergedCounter++;
@@ -241,7 +241,7 @@ public final class KMeansBSP
} else {
// add the vector to the center
newCenterArray[lowestDistantCenter] = newCenterArray[lowestDistantCenter]
- .add(key);
+ .addUnsafe(key);
summationCount[lowestDistantCenter]++;
}
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java Sat Jul 6 02:12:01 2013
@@ -21,6 +21,8 @@ import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
+import com.google.common.base.Preconditions;
+
/**
* Dense double matrix implementation, internally uses two dimensional double
* arrays.
@@ -384,7 +386,7 @@ public final class DenseDoubleMatrix imp
* @see de.jungblut.math.DoubleMatrix#multiply(de.jungblut.math.DoubleMatrix)
*/
@Override
- public final DoubleMatrix multiply(DoubleMatrix other) {
+ public final DoubleMatrix multiplyUnsafe(DoubleMatrix other) {
DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.getRowCount(),
other.getColumnCount());
@@ -412,7 +414,7 @@ public final class DenseDoubleMatrix imp
* )
*/
@Override
- public final DoubleMatrix multiplyElementWise(DoubleMatrix other) {
+ public final DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other) {
DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.numRows,
this.numColumns);
@@ -431,7 +433,7 @@ public final class DenseDoubleMatrix imp
* de.jungblut.math.DoubleMatrix#multiplyVector(de.jungblut.math.DoubleVector)
*/
@Override
- public final DoubleVector multiplyVector(DoubleVector v) {
+ public final DoubleVector multiplyVectorUnsafe(DoubleVector v) {
DoubleVector vector = new DenseDoubleVector(this.getRowCount());
for (int row = 0; row < numRows; row++) {
double sum = 0.0d;
@@ -494,7 +496,7 @@ public final class DenseDoubleMatrix imp
* @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleMatrix)
*/
@Override
- public DoubleMatrix subtract(DoubleMatrix other) {
+ public DoubleMatrix subtractUnsafe(DoubleMatrix other) {
DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns);
for (int i = 0; i < numRows; i++) {
for (int j = 0; j < numColumns; j++) {
@@ -509,7 +511,7 @@ public final class DenseDoubleMatrix imp
* @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleVector)
*/
@Override
- public DenseDoubleMatrix subtract(DoubleVector vec) {
+ public DenseDoubleMatrix subtractUnsafe(DoubleVector vec) {
DenseDoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(),
this.getColumnCount());
for (int i = 0; i < this.getColumnCount(); i++) {
@@ -523,7 +525,7 @@ public final class DenseDoubleMatrix imp
* @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleVector)
*/
@Override
- public DoubleMatrix divide(DoubleVector vec) {
+ public DoubleMatrix divideUnsafe(DoubleVector vec) {
DoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(),
this.getColumnCount());
for (int i = 0; i < this.getColumnCount(); i++) {
@@ -532,12 +534,22 @@ public final class DenseDoubleMatrix imp
return cop;
}
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public DoubleMatrix divide(DoubleVector vec) {
+ Preconditions.checkArgument(this.getColumnCount() == vec.getDimension(),
+ "Dimension mismatch.");
+ return this.divideUnsafe(vec);
+ }
+
/*
* (non-Javadoc)
* @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleMatrix)
*/
@Override
- public DoubleMatrix divide(DoubleMatrix other) {
+ public DoubleMatrix divideUnsafe(DoubleMatrix other) {
DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns);
for (int i = 0; i < numRows; i++) {
for (int j = 0; j < numColumns; j++) {
@@ -547,6 +559,13 @@ public final class DenseDoubleMatrix imp
return m;
}
+ /**
+ * {@inheritDoc}
+ *
+ * @throws IllegalArgumentException if the row or column counts of the two
+ * matrices differ.
+ */
+ @Override
+ public DoubleMatrix divide(DoubleMatrix other) {
+ Preconditions.checkArgument(this.getRowCount() == other.getRowCount()
+ && this.getColumnCount() == other.getColumnCount(), "Dimension mismatch.");
+ return divideUnsafe(other);
+ }
+
/*
* (non-Javadoc)
* @see de.jungblut.math.DoubleMatrix#divide(double)
@@ -775,7 +794,7 @@ public final class DenseDoubleMatrix imp
* Just a absolute error function.
*/
public static double error(DenseDoubleMatrix a, DenseDoubleMatrix b) {
- return a.subtract(b).sum();
+ return a.subtractUnsafe(b).sum();
}
@Override
@@ -795,20 +814,91 @@ public final class DenseDoubleMatrix imp
/**
* {@inheritDoc}
*/
- public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun) {
- if (this.numRows != other.getRowCount()
- || this.numColumns != other.getColumnCount()) {
- throw new IllegalArgumentException(
- "Cannot apply double double function to matrices with different sizes.");
- }
-
+ public DoubleMatrix applyToElements(DoubleMatrix other,
+ DoubleDoubleFunction fun) {
+ Preconditions
+ .checkArgument(this.numRows == other.getRowCount()
+ && this.numColumns == other.getColumnCount(),
+ "Cannot apply double double function to matrices with different sizes.");
+
for (int r = 0; r < this.numRows; ++r) {
for (int c = 0; c < this.numColumns; ++c) {
this.set(r, c, fun.apply(this.get(r, c), other.get(r, c)));
}
}
-
+
return this;
}
+ /*
+ * (non-Javadoc)
+ * Rejects operands whose inner dimensions differ (this.numColumns vs.
+ * other.getRowCount()) and otherwise delegates to multiplyUnsafe.
+ * @see org.apache.hama.ml.math.DoubleMatrix#multiply(DoubleMatrix)
+ */
+ @Override
+ public DoubleMatrix multiply(DoubleMatrix other) {
+ Preconditions
+ .checkArgument(
+ this.numColumns == other.getRowCount(),
+ String
+ .format(
+ "Matrix with size [%d, %d] cannot multiply matrix with size [%d, %d]",
+ this.numRows, this.numColumns, other.getRowCount(),
+ other.getColumnCount()));
+
+ return this.multiplyUnsafe(other);
+ }
+
+ /*
+ * (non-Javadoc)
+ * Checks that both matrices have the same row and column counts, then
+ * delegates to multiplyElementWiseUnsafe.
+ * @see org.apache.hama.ml.math.DoubleMatrix#multiplyElementWise(DoubleMatrix)
+ */
+ @Override
+ public DoubleMatrix multiplyElementWise(DoubleMatrix other) {
+ Preconditions.checkArgument(this.numRows == other.getRowCount()
+ && this.numColumns == other.getColumnCount(),
+ "Matrices with different dimensions cannot be multiplied elementwise.");
+ return this.multiplyElementWiseUnsafe(other);
+ }
+
+ /*
+ * (non-Javadoc)
+ * Checks that the vector dimension equals this matrix's column count, then
+ * delegates to multiplyVectorUnsafe.
+ * @see org.apache.hama.ml.math.DoubleMatrix#multiplyVector(DoubleVector)
+ */
+ @Override
+ public DoubleVector multiplyVector(DoubleVector v) {
+ Preconditions.checkArgument(this.numColumns == v.getDimension(),
+ "Dimension mismatch.");
+ return this.multiplyVectorUnsafe(v);
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math.
+ * DoubleMatrix)
+ */
+ @Override
+ public DoubleMatrix subtract(DoubleMatrix other) {
+ Preconditions.checkArgument(this.numRows == other.getRowCount()
+ && this.numColumns == other.getColumnCount(), "Dimension mismatch.");
+ return subtractUnsafe(other);
+ }
+
+ /*
+ * (non-Javadoc)
+ * Checks that the vector has one entry per column of this matrix, then
+ * delegates to subtractUnsafe(DoubleVector) and returns its result.
+ * @see org.apache.hama.ml.math.DoubleMatrix#subtract(DoubleVector)
+ */
+ @Override
+ public DoubleMatrix subtract(DoubleVector vec) {
+ Preconditions.checkArgument(this.numColumns == vec.getDimension(),
+ "Dimension mismatch.");
+ return this.subtractUnsafe(vec);
+ }
+
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java Sat Jul 6 02:12:01 2013
@@ -24,6 +24,7 @@ import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
+import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;
/**
@@ -112,16 +113,17 @@ public final class DenseDoubleVector imp
}
/**
- * {@inheritDoc}}
+ * {@inheritDoc}
*/
@Override
- public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func) {
+ public DoubleVector applyToElements(DoubleVector other,
+ DoubleDoubleFunction func) {
for (int i = 0; i < vector.length; i++) {
this.vector[i] = func.apply(vector[i], other.get(i));
}
return this;
}
-
+
/*
* (non-Javadoc)
* @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function.
@@ -157,7 +159,7 @@ public final class DenseDoubleVector imp
* @see de.jungblut.math.DoubleVector#add(de.jungblut.math.DoubleVector)
*/
@Override
- public final DoubleVector add(DoubleVector v) {
+ public final DoubleVector addUnsafe(DoubleVector v) {
DenseDoubleVector newv = new DenseDoubleVector(v.getLength());
for (int i = 0; i < v.getLength(); i++) {
newv.set(i, this.get(i) + v.get(i));
@@ -183,7 +185,7 @@ public final class DenseDoubleVector imp
* @see de.jungblut.math.DoubleVector#subtract(de.jungblut.math.DoubleVector)
*/
@Override
- public final DoubleVector subtract(DoubleVector v) {
+ public final DoubleVector subtractUnsafe(DoubleVector v) {
DoubleVector newv = new DenseDoubleVector(v.getLength());
for (int i = 0; i < v.getLength(); i++) {
newv.set(i, this.get(i) - v.get(i));
@@ -235,7 +237,7 @@ public final class DenseDoubleVector imp
* @see de.jungblut.math.DoubleVector#multiply(de.jungblut.math.DoubleVector)
*/
@Override
- public DoubleVector multiply(DoubleVector vector) {
+ public DoubleVector multiplyUnsafe(DoubleVector vector) {
DoubleVector v = new DenseDoubleVector(this.getLength());
for (int i = 0; i < v.getLength(); i++) {
v.set(i, this.get(i) * vector.get(i));
@@ -338,10 +340,10 @@ public final class DenseDoubleVector imp
* @see de.jungblut.math.DoubleVector#dot(de.jungblut.math.DoubleVector)
*/
@Override
- public double dot(DoubleVector s) {
+ public double dotUnsafe(DoubleVector vector) {
double dotProduct = 0.0d;
for (int i = 0; i < getLength(); i++) {
- dotProduct += this.get(i) * s.get(i);
+ dotProduct += this.get(i) * vector.get(i);
}
return dotProduct;
}
@@ -652,4 +654,54 @@ public final class DenseDoubleVector imp
return null;
}
+ /*
+ * (non-Javadoc)
+ * Checks that both vectors have the same dimension, then delegates to
+ * addUnsafe.
+ * @see org.apache.hama.ml.math.DoubleVector#add(DoubleVector)
+ */
+ @Override
+ public DoubleVector add(DoubleVector vector) {
+ Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+ "Dimensions of two vectors do not equal.");
+ return this.addUnsafe(vector);
+ }
+
+ /*
+ * (non-Javadoc)
+ * Checks that both vectors have the same dimension, then delegates to
+ * subtractUnsafe.
+ * @see org.apache.hama.ml.math.DoubleVector#subtract(DoubleVector)
+ */
+ @Override
+ public DoubleVector subtract(DoubleVector vector) {
+ Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+ "Dimensions of two vectors do not equal.");
+ return this.subtractUnsafe(vector);
+ }
+
+ /*
+ * (non-Javadoc)
+ * Checks that both vectors have the same dimension, then delegates to
+ * multiplyUnsafe.
+ * @see org.apache.hama.ml.math.DoubleVector#multiply(DoubleVector)
+ */
+ @Override
+ public DoubleVector multiply(DoubleVector vector) {
+ Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+ "Dimensions of two vectors do not equal.");
+ return this.multiplyUnsafe(vector);
+ }
+
+ /*
+ * (non-Javadoc)
+ * Checks that both vectors have the same dimension, then delegates to
+ * dotUnsafe.
+ * @see org.apache.hama.ml.math.DoubleVector#dot(DoubleVector)
+ */
+ @Override
+ public double dot(DoubleVector vector) {
+ Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+ "Dimensions of two vectors do not equal.");
+ return this.dotUnsafe(vector);
+ }
+
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java Sat Jul 6 02:12:01 2013
@@ -80,18 +80,47 @@ public interface DoubleMatrix {
/**
* Multiplies this matrix with the given other matrix.
+ *
+ * @param other the other matrix.
+ * @return
+ */
+ public DoubleMatrix multiplyUnsafe(DoubleMatrix other);
+
+ /**
+ * Validates the input and multiplies this matrix with the given other matrix.
+ *
+ * @param other the other matrix.
+ * @return
*/
public DoubleMatrix multiply(DoubleMatrix other);
/**
* Multiplies this matrix per element with a given matrix.
*/
+ public DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other);
+
+ /**
+ * Validates the input and multiplies this matrix per element with a given
+ * matrix.
+ *
+ * @param other the other matrix
+ * @return
+ */
public DoubleMatrix multiplyElementWise(DoubleMatrix other);
/**
* Multiplies this matrix with a given vector v. The returning vector contains
* the sum of the rows.
*/
+ public DoubleVector multiplyVectorUnsafe(DoubleVector v);
+
+ /**
+ * Multiplies this matrix with a given vector v. The returning vector contains
+ * the sum of the rows.
+ *
+ * @param v the vector
+ * @return
+ */
public DoubleVector multiplyVector(DoubleVector v);
/**
@@ -114,23 +143,58 @@ public interface DoubleMatrix {
/**
* Subtracts this matrix by the given other matrix.
*/
+ public DoubleMatrix subtractUnsafe(DoubleMatrix other);
+
+ /**
+ * Validates the input and subtracts this matrix by the given other matrix.
+ *
+ * @param other
+ * @return
+ */
public DoubleMatrix subtract(DoubleMatrix other);
/**
* Subtracts each element in a column by the related element in the given
* vector.
*/
+ public DoubleMatrix subtractUnsafe(DoubleVector vec);
+
+ /**
+ * Validates and subtracts each element in a column by the related element in
+ * the given vector.
+ *
+ * @param vec
+ * @return
+ */
public DoubleMatrix subtract(DoubleVector vec);
/**
* Divides each element in a column by the related element in the given
* vector.
*/
+ public DoubleMatrix divideUnsafe(DoubleVector vec);
+
+ /**
+ * Validates and divides each element in a column by the related element in
+ * the given vector.
+ *
+ * @param vec
+ * @return
+ */
public DoubleMatrix divide(DoubleVector vec);
/**
* Divides this matrix by the given other matrix. (Per element division).
*/
+ public DoubleMatrix divideUnsafe(DoubleMatrix other);
+
+ /**
+ * Validates and divides this matrix by the given other matrix. (Per element
+ * division).
+ *
+ * @param other
+ * @return
+ */
public DoubleMatrix divide(DoubleMatrix other);
/**
@@ -203,6 +267,7 @@ public interface DoubleMatrix {
* @param fun The function that takes two arguments.
* @return The matrix itself, supply for chain operation.
*/
- public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun);
+ public DoubleMatrix applyToElements(DoubleMatrix other,
+ DoubleDoubleFunction fun);
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java Sat Jul 6 02:12:01 2013
@@ -58,7 +58,7 @@ public interface DoubleVector {
* @param value the value at the index of the vector to set.
*/
public void set(int index, double value);
-
+
/**
* Apply a given {@link DoubleVectorFunction} to this vector and return a new
* one.
@@ -68,7 +68,7 @@ public interface DoubleVector {
*/
@Deprecated
public DoubleVector apply(DoubleVectorFunction func);
-
+
/**
* Apply a given {@link DoubleDoubleVectorFunction} to this vector and the
* other given vector.
@@ -97,15 +97,24 @@ public interface DoubleVector {
* @param func the function to apply on this and the other vector.
* @return a new vector with the result of the function of the two vectors.
*/
- public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func);
+ public DoubleVector applyToElements(DoubleVector other,
+ DoubleDoubleFunction func);
/**
* Adds the given {@link DoubleVector} to this vector.
*
- * @param v the other vector.
+ * @param vector the other vector.
+ * @return a new vector with the sum of both vectors at each element index.
+ */
+ public DoubleVector addUnsafe(DoubleVector vector);
+
+ /**
+ * Validates the input and adds the given {@link DoubleVector} to this vector.
+ *
+ * @param vector the other vector.
* @return a new vector with the sum of both vectors at each element index.
*/
- public DoubleVector add(DoubleVector v);
+ public DoubleVector add(DoubleVector vector);
/**
* Adds the given scalar to this vector.
@@ -118,10 +127,19 @@ public interface DoubleVector {
/**
* Subtracts this vector by the given {@link DoubleVector}.
*
- * @param v the other vector.
+ * @param vector the other vector.
* @return a new vector with the difference of both vectors.
*/
- public DoubleVector subtract(DoubleVector v);
+ public DoubleVector subtractUnsafe(DoubleVector vector);
+
+ /**
+ * Validates the input and subtracts this vector by the given
+ * {@link DoubleVector}.
+ *
+ * @param vector the other vector.
+ * @return a new vector with the difference of both vectors.
+ */
+ public DoubleVector subtract(DoubleVector vector);
/**
* Subtracts the given scalar to this vector. (vector - scalar).
@@ -153,6 +171,15 @@ public interface DoubleVector {
* @param vector the other vector.
* @return a new vector with the result of the operation.
*/
+ public DoubleVector multiplyUnsafe(DoubleVector vector);
+
+ /**
+ * Validates the input and multiplies the given {@link DoubleVector} with this
+ * vector.
+ *
+ * @param vector the other vector.
+ * @return a new vector with the result of the operation.
+ */
public DoubleVector multiply(DoubleVector vector);
/**
@@ -201,10 +228,19 @@ public interface DoubleVector {
/**
* Calculates the dot product between this vector and the given vector.
*
- * @param s the given vector s.
+ * @param vector the given vector.
+ * @return the dot product as a double.
+ */
+ public double dotUnsafe(DoubleVector vector);
+
+ /**
+ * Validates the input and calculates the dot product between this vector and
+ * the given vector.
+ *
+ * @param vector the given vector.
* @return the dot product as a double.
*/
- public double dot(DoubleVector s);
+ public double dot(DoubleVector vector);
/**
* Slices this vector from index 0 to the given length.
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java Sat Jul 6 02:12:01 2013
@@ -38,7 +38,7 @@ public class LinearRegressionModel imple
@Override
public double applyHypothesis(DoubleVector theta, DoubleVector x) {
- return theta.dot(x);
+ return theta.dotUnsafe(x);
}
@Override
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java Sat Jul 6 02:12:01 2013
@@ -53,7 +53,7 @@ public class LogisticRegressionModel imp
DoubleVector x) {
return BigDecimal.valueOf(1).divide(
BigDecimal.valueOf(1d).add(
- BigDecimal.valueOf(Math.exp(-1d * theta.dot(x)))),
+ BigDecimal.valueOf(Math.exp(-1d * theta.dotUnsafe(x)))),
MathContext.DECIMAL128);
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java Sat Jul 6 02:12:01 2013
@@ -119,7 +119,7 @@ public final class VectorWritable implem
}
public static int compareVector(DoubleVector a, DoubleVector o) {
- DoubleVector subtract = a.subtract(o);
+ DoubleVector subtract = a.subtractUnsafe(o);
return (int) subtract.sum();
}
Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java Sat Jul 6 02:12:01 2013
@@ -19,6 +19,8 @@ package org.apache.hama.ml.math;
import static org.junit.Assert.assertArrayEquals;
+import java.util.Arrays;
+
import org.junit.Test;
/**
@@ -57,12 +59,14 @@ public class TestDenseDoubleMatrix {
@Test
public void testDoubleDoubleFunction() {
double[][] values1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
- double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } };
- double[][] result = new double[][] { {3, 5, 7}, {9, 11, 13}, {15, 17, 19}};
+ double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 },
+ { 8, 9, 10 } };
+ double[][] result = new double[][] { { 3, 5, 7 }, { 9, 11, 13 },
+ { 15, 17, 19 } };
DenseDoubleMatrix mat1 = new DenseDoubleMatrix(values1);
DenseDoubleMatrix mat2 = new DenseDoubleMatrix(values2);
-
+
mat1.applyToElements(mat2, new DoubleDoubleFunction() {
@Override
@@ -83,4 +87,153 @@ public class TestDenseDoubleMatrix {
}
}
+ @Test
+ public void testMultiplyNormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 }, { 2, 1 } };
+ double[][] expMat = new double[][] { { 20, 14 }, { 56, 41 } };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+ DoubleMatrix actMatrix = matrix1.multiply(matrix2);
+ for (int r = 0; r < actMatrix.getRowCount(); ++r) {
+ assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(),
+ 0.000001);
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testMultiplyAbnormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+ matrix1.multiply(matrix2);
+ }
+
+ @Test
+ public void testMultiplyElementWiseNormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } };
+ double[][] expMat = new double[][] { { 6, 10, 12 }, { 12, 10, 6 } };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+ DoubleMatrix actMatrix = matrix1.multiplyElementWise(matrix2);
+ for (int r = 0; r < actMatrix.getRowCount(); ++r) {
+ assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(),
+ 0.000001);
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testMultiplyElementWiseAbnormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+ matrix1.multiplyElementWise(matrix2);
+ }
+
+ @Test
+ public void testMultiplyVectorNormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[] mat2 = new double[] { 6, 5, 4 };
+ double[] expVec = new double[] { 28, 73 };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleVector vector2 = new DenseDoubleVector(mat2);
+ DoubleVector actVec = matrix1.multiplyVector(vector2);
+ assertArrayEquals(expVec, actVec.toArray(), 0.000001);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testMultiplyVectorAbnormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[] vec2 = new double[] { 6, 5 };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleVector vector2 = new DenseDoubleVector(vec2);
+ matrix1.multiplyVector(vector2);
+ }
+
+ @Test
+ public void testSubtractNormal() {
+ double[][] mat1 = new double[][] {
+ {1, 2, 3},
+ {4, 5, 6}
+ };
+ double[][] mat2 = new double[][] {
+ {6, 5, 4},
+ {3, 2, 1}
+ };
+ double[][] expMat = new double[][] {
+ {-5, -3, -1},
+ {1, 3, 5}
+ };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+ DoubleMatrix actMatrix = matrix1.subtract(matrix2);
+ for (int r = 0; r < actMatrix.getRowCount(); ++r) {
+ assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), 0.000001);
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testSubtractAbnormal() {
+ double[][] mat1 = new double[][] {
+ {1, 2, 3},
+ {4, 5, 6}
+ };
+ double[][] mat2 = new double[][] {
+ {6, 5},
+ {4, 3}
+ };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+ matrix1.subtract(matrix2);
+ }
+
+ @Test
+ public void testDivideVectorNormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[] mat2 = new double[] { 6, 5, 4 };
+ double[][] expVec = new double[][] { {1.0 / 6, 2.0 / 5, 3.0 / 4}, {4.0 / 6, 5.0 / 5, 6.0 / 4} };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleVector vector2 = new DenseDoubleVector(mat2);
+ DoubleMatrix expMat = new DenseDoubleMatrix(expVec);
+ DoubleMatrix actMat = matrix1.divide(vector2);
+ for (int r = 0; r < actMat.getRowCount(); ++r) {
+ assertArrayEquals(expMat.getRowVector(r).toArray(), actMat.getRowVector(r).toArray(), 0.000001);
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testDivideVectorAbnormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[] vec2 = new double[] { 6, 5 };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleVector vector2 = new DenseDoubleVector(vec2);
+ matrix1.divide(vector2);
+ }
+
+ @Test
+ public void testDivideNormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } };
+ double[][] expMat = new double[][] { { 1.0 / 6, 2.0 / 5, 3.0 / 4 }, { 4.0 / 3, 5.0 / 2, 6.0 / 1 } };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+ DoubleMatrix actMatrix = matrix1.divide(matrix2);
+ for (int r = 0; r < actMatrix.getRowCount(); ++r) {
+ assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(),
+ 0.000001);
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testDivideAbnormal() {
+ double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+ double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } };
+ DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+ DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+ matrix1.divide(matrix2);
+ }
+
}
Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java Sat Jul 6 02:12:01 2013
@@ -18,8 +18,11 @@
package org.apache.hama.ml.math;
import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
/**
* Testcase for {@link DenseDoubleVector}
@@ -77,4 +80,79 @@ public class TestDenseDoubleVector {
assertArrayEquals(result, vec1.toArray(), 0.0001);
}
+
+ /**
+ * Adds the vectors {1, 2, 3} and {4, 5, 6} and expects {5, 7, 9}, within a
+ * tolerance of 0.000001.
+ */
+ @Test
+ public void testAddNormal() {
+ DoubleVector left = new DenseDoubleVector(new double[] { 1, 2, 3 });
+ DoubleVector right = new DenseDoubleVector(new double[] { 4, 5, 6 });
+ double[] expected = { 5, 7, 9 };
+ double[] actual = left.add(right).toArray();
+ assertArrayEquals(expected, actual, 0.000001);
+ }
+
+ /**
+ * Adding a 2-element vector to a 3-element vector is expected to fail with
+ * an {@link IllegalArgumentException}.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testAddAbnormal() {
+ DoubleVector threeElements = new DenseDoubleVector(new double[] { 1, 2, 3 });
+ DoubleVector twoElements = new DenseDoubleVector(new double[] { 4, 5 });
+ // The call itself is the assertion: it must throw.
+ threeElements.add(twoElements);
+ }
+
+ /**
+ * Subtracts {4, 5, 6} from {1, 2, 3} and expects {-3, -3, -3}, within a
+ * tolerance of 0.000001.
+ */
+ @Test
+ public void testSubtractNormal() {
+ DoubleVector minuend = new DenseDoubleVector(new double[] { 1, 2, 3 });
+ DoubleVector subtrahend = new DenseDoubleVector(new double[] { 4, 5, 6 });
+ double[] expected = { -3, -3, -3 };
+ double[] actual = minuend.subtract(subtrahend).toArray();
+ assertArrayEquals(expected, actual, 0.000001);
+ }
+
+ /**
+ * Subtracting a 2-element vector from a 3-element vector is expected to
+ * fail with an {@link IllegalArgumentException}.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testSubtractAbnormal() {
+ DoubleVector threeElements = new DenseDoubleVector(new double[] { 1, 2, 3 });
+ DoubleVector twoElements = new DenseDoubleVector(new double[] { 4, 5 });
+ // The call itself is the assertion: it must throw.
+ threeElements.subtract(twoElements);
+ }
+
+ /**
+ * Multiplies {1, 2, 3} by {4, 5, 6} and expects the element-wise products
+ * {4, 10, 18}, within a tolerance of 0.000001.
+ */
+ @Test
+ public void testMultiplyNormal() {
+ DoubleVector left = new DenseDoubleVector(new double[] { 1, 2, 3 });
+ DoubleVector right = new DenseDoubleVector(new double[] { 4, 5, 6 });
+ double[] expected = { 4, 10, 18 };
+ double[] actual = left.multiply(right).toArray();
+ assertArrayEquals(expected, actual, 0.000001);
+ }
+
+ /**
+ * Multiplying a 3-element vector by a 2-element vector is expected to fail
+ * with an {@link IllegalArgumentException}.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testMultiplyAbnormal() {
+ DoubleVector threeElements = new DenseDoubleVector(new double[] { 1, 2, 3 });
+ DoubleVector twoElements = new DenseDoubleVector(new double[] { 4, 5 });
+ // The call itself is the assertion: it must throw.
+ threeElements.multiply(twoElements);
+ }
+
+ /**
+ * Computes the dot product of {1, 2, 3} and {4, 5, 6} and expects 32.0
+ * (1*4 + 2*5 + 3*6), within a tolerance of 0.000001.
+ */
+ @Test
+ public void testDotNormal() {
+ DoubleVector left = new DenseDoubleVector(new double[] { 1, 2, 3 });
+ DoubleVector right = new DenseDoubleVector(new double[] { 4, 5, 6 });
+ double product = left.dot(right);
+ assertEquals(32.0, product, 0.000001);
+ }
+
+ /**
+ * Taking the dot product of a 3-element vector and a 2-element vector is
+ * expected to fail with an {@link IllegalArgumentException}.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testDotAbnormal() {
+ double[] arr1 = new double[] {1, 2, 3};
+ double[] arr2 = new double[] {4, 5};
+ DoubleVector vec1 = new DenseDoubleVector(arr1);
+ DoubleVector vec2 = new DenseDoubleVector(arr2);
+ // Must call dot(), not add(): this test targets the argument validation
+ // of the dot product; add() is already covered by testAddAbnormal.
+ vec1.dot(vec2);
+ }
}