You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/06/09 16:31:19 UTC
svn commit: r783014 - in /lucene/mahout/trunk/core: pom.xml
src/main/java/org/apache/mahout/matrix/AbstractVector.java
src/main/java/org/apache/mahout/matrix/Vector.java
src/test/java/org/apache/mahout/matrix/VectorTest.java
Author: gsingers
Date: Tue Jun 9 14:31:19 2009
New Revision: 783014
URL: http://svn.apache.org/viewvc?rev=783014&view=rev
Log:
MAHOUT-130: add normalization factors to Vector
Modified:
lucene/mahout/trunk/core/pom.xml
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java
Modified: lucene/mahout/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/pom.xml?rev=783014&r1=783013&r2=783014&view=diff
==============================================================================
--- lucene/mahout/trunk/core/pom.xml (original)
+++ lucene/mahout/trunk/core/pom.xml Tue Jun 9 14:31:19 2009
@@ -537,12 +537,12 @@
<dependency>
<groupId>org.apache.lucene</groupId>
<artifactId>lucene-analyzers</artifactId>
- <version>2.3.2</version>
+ <version>2.9-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.lucene</groupId>
<artifactId>lucene-core</artifactId>
- <version>2.3.2</version>
+ <version>2.9-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.mahout.commons</groupId>
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java?rev=783014&r1=783013&r2=783014&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java Tue Jun 9 14:31:19 2009
@@ -116,6 +116,61 @@
return divide(divSq);
}
+ /**
+  * Returns a new Vector holding this Vector's values divided by its L_power norm.
+  * <p>
+  * The norm is computed from element magnitudes (absolute values), so negative elements add to the
+  * norm instead of cancelling out (power 1), being ignored (infinite power) or producing NaN
+  * (fractional powers).
+  * <p>
+  * NOTE(review): when every element is zero the computed norm is 0; the result then depends on how
+  * divide() treats a zero divisor - confirm against divide().
+  *
+  * @param power the norm to use; must be {@code >= 0} and may be {@link Double#POSITIVE_INFINITY}
+  * @return the result of dividing this Vector by the computed norm
+  * @throws IllegalArgumentException if power is negative or NaN
+  */
+ public Vector normalize(double power) {
+   // NaN fails every comparison below, so reject it up front together with negative powers
+   if (Double.isNaN(power) || power < 0) {
+     throw new IllegalArgumentException("Power must be >= 0");
+   }
+   if (power == 2) {
+     // the L_2 case already has a dedicated implementation
+     return normalize();
+   }
+   int size = cardinality();
+   double norm = 0;
+   if (Double.isInfinite(power)) {
+     // L_infinity norm: the largest magnitude
+     for (int i = 0; i < size; i++) {
+       norm = Math.max(norm, Math.abs(getQuick(i)));
+     }
+   } else if (power == 1) {
+     // L_1 norm: the sum of magnitudes
+     for (int i = 0; i < size; i++) {
+       norm += Math.abs(getQuick(i));
+     }
+   } else if (power == 0) {
+     // L_0 "norm": the number of non-zero elements
+     for (int i = 0; i < size; i++) {
+       if (getQuick(i) != 0) {
+         norm++;
+       }
+     }
+   } else {
+     // general case: (sum |x_i|^power)^(1/power)
+     double sum = 0;
+     for (int i = 0; i < size; i++) {
+       sum += Math.pow(Math.abs(getQuick(i)), power);
+     }
+     norm = Math.pow(sum, 1.0 / power);
+   }
+   return divide(norm);
+ }
+
+
+ /**
+  * Returns the largest element value, scanning indices 0 to cardinality() - 1 with getQuick().
+  *
+  * @return the maximum value, or {@link Double#NEGATIVE_INFINITY} if there are no elements to scan
+  */
+ @Override
+ public double maxValue() {
+   // Double.MIN_VALUE is the smallest POSITIVE double, so seeding with it would hide every
+   // zero or negative maximum; negative infinity is the correct identity for max
+   double result = Double.NEGATIVE_INFINITY;
+   int size = cardinality();
+   for (int i = 0; i < size; i++) {
+     result = Math.max(result, getQuick(i));
+   }
+   return result;
+ }
+
+ /**
+  * Returns the index of the largest element value; on ties the lowest index wins because only a
+  * strictly greater value replaces the current maximum.
+  *
+  * @return the index of the maximum value, or -1 if no element is greater than
+  *         {@link Double#NEGATIVE_INFINITY} (for example when there are no elements to scan)
+  */
+ @Override
+ public int maxValueIndex() {
+   int result = -1;
+   // Double.MIN_VALUE is the smallest POSITIVE double; seeding with it made this method return -1
+   // for any Vector without a strictly positive element
+   double max = Double.NEGATIVE_INFINITY;
+   int size = cardinality();
+   for (int i = 0; i < size; i++) {
+     double candidate = getQuick(i);
+     if (candidate > max) {
+       max = candidate;
+       result = i;
+     }
+   }
+   return result;
+ }
+
@Override
public Vector plus(double x) {
Vector result = copy();
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java?rev=783014&r1=783013&r2=783014&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java Tue Jun 9 14:31:19 2009
@@ -223,13 +223,40 @@
Vector minus(Vector x);
/**
- * Return a new matrix containing the normalized values of the recipient
+ * Return a new Vector containing the normalized (L_2 norm) values of the recipient
*
* @return a new Vector
*/
Vector normalize();
/**
+ * Return a new Vector containing the normalized (L_power norm) values of the recipient.
+ * <p/>
+ * See http://en.wikipedia.org/wiki/Lp_space
+ * <p/>
+ * Technically, when 0 &lt; power &lt; 1, we don't have a norm, just a metric, but we'll overload this here.
+ * <p/>
+ * Also supports power == 0 (number of non-zero elements) and power == {@link Double#POSITIVE_INFINITY} (max element). Again, see the Wikipedia page for more info.
+ *
+ * @param power The power to use. Must be &gt;= 0. May also be {@link Double#POSITIVE_INFINITY}. See the Wikipedia link for more on this.
+ * @return a new Vector
+ * @throws IllegalArgumentException if power is negative (this is what the AbstractVector implementation does)
+ */
+ Vector normalize(double power);
+
+ /**
+ * Find the largest element value held by the Vector.
+ *
+ * @return The maximum value in the Vector
+ */
+ double maxValue();
+
+ /**
+ * Find the position of the largest element value held by the Vector.
+ * <p/>
+ * NOTE(review): the AbstractVector implementation starts from -1, so -1 comes back when it finds no
+ * maximum (for example when there are no elements to scan) - confirm for other implementations.
+ *
+ * @return The index of the maximum value
+ */
+ int maxValueIndex();
+
+ /**
* Return a new matrix containing the sum of each value of the recipient and
* the argument
*
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java?rev=783014&r1=783013&r2=783014&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java Tue Jun 9 14:31:19 2009
@@ -17,11 +17,11 @@
package org.apache.mahout.matrix;
+import junit.framework.TestCase;
+
import java.util.Date;
import java.util.Random;
-import junit.framework.TestCase;
-
public class VectorTest extends TestCase {
public VectorTest(String s) {
@@ -45,6 +45,78 @@
assertEquals(result + " does not equal: " + 32, 32.0, result);
}
+ /**
+  * Checks normalize() and normalize(power) on the vector (1, 2, 3) for powers 2, 1, 3,
+  * positive infinity and 0, and checks that a negative power is rejected.
+  */
+ public void testNormalize() throws Exception {
+   SparseVector vec1 = new SparseVector(3);
+
+   vec1.setQuick(0, 1);
+   vec1.setQuick(1, 2);
+   vec1.setQuick(2, 3);
+   Vector norm = vec1.normalize();
+   assertNotNull("norm1 is null and it shouldn't be", norm);
+   Vector expected = new SparseVector(3);
+
+   // L_2: each element divided by sqrt(1 + 4 + 9)
+   expected.setQuick(0, 0.2672612419124244);
+   expected.setQuick(1, 0.5345224838248488);
+   expected.setQuick(2, 0.8017837257372732);
+   assertTrue("norm is not equal to expected", norm.equals(expected));
+
+   // power 2 must agree with the no-argument normalize()
+   norm = vec1.normalize(2);
+   assertTrue("norm is not equal to expected", norm.equals(expected));
+
+   // L_1: each element divided by 1 + 2 + 3 = 6
+   norm = vec1.normalize(1);
+   expected.setQuick(0, 1.0/6);
+   expected.setQuick(1, 2.0/6);
+   expected.setQuick(2, 3.0/6);
+   assertTrue("norm is not equal to expected", norm.equals(expected));
+
+   // L_3: 1^3 + 2^3 + 3^3 = 36, so we divide by the cube root of 36
+   norm = vec1.normalize(3);
+   double cube = Math.pow(36, 1.0/3);
+   expected = vec1.divide(cube);
+   assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected));
+
+   norm = vec1.normalize(Double.POSITIVE_INFINITY);
+   //The max is 3, so we divide by that.
+   expected.setQuick(0, 1.0/3);
+   expected.setQuick(1, 2.0/3);
+   expected.setQuick(2, 3.0/3);
+   assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected));
+
+   norm = vec1.normalize(0);
+   //All three elements are non-zero, so we divide by 3.
+   expected.setQuick(0, 1.0/3);
+   expected.setQuick(1, 2.0/3);
+   expected.setQuick(2, 3.0/3);
+   assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected));
+
+   try {
+     vec1.normalize(-1);
+     fail("normalize(-1) should have thrown an IllegalArgumentException");
+   } catch (IllegalArgumentException e) {
+     //expected
+   }
+
+ }
+
+ /**
+  * Verifies maxValue() and maxValueIndex() on a three-element sparse vector
+  * whose largest entry (3) is stored at index 1.
+  */
+ public void testMax() throws Exception {
+   SparseVector vector = new SparseVector(3);
+   double[] entries = {1, 3, 2};
+   for (int i = 0; i < entries.length; i++) {
+     vector.setQuick(i, entries[i]);
+   }
+
+   double largest = vector.maxValue();
+   String valueMessage = largest + " does not equal: " + 3;
+   assertTrue(valueMessage, largest == 3);
+
+   int position = vector.maxValueIndex();
+   String indexMessage = position + " does not equal: " + 1;
+   assertTrue(indexMessage, position == 1);
+
+ }
+
public void testDenseVector() throws Exception {
DenseVector vec1 = new DenseVector(3);
DenseVector vec2 = new DenseVector(3);
@@ -75,12 +147,12 @@
}
public void testEnumeration() throws Exception {
- double[] apriori = { 0, 1, 2, 3, 4 };
+ double[] apriori = {0, 1, 2, 3, 4};
- doTestEnumeration(apriori, new VectorView(new DenseVector(new double[] {
- -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }), 2, 5));
+ doTestEnumeration(apriori, new VectorView(new DenseVector(new double[]{
+ -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), 2, 5));
- doTestEnumeration(apriori, new DenseVector(new double[] { 0, 1, 2, 3, 4 }));
+ doTestEnumeration(apriori, new DenseVector(new double[]{0, 1, 2, 3, 4}));
SparseVector sparse = new SparseVector(5);
sparse.set(0, 0);
@@ -110,7 +182,7 @@
long tRef = t1 - t0;
assertTrue(tOpt < tRef);
System.out.println("testSparseVectorTimesX tRef=tOpt=" + (tRef - tOpt)
- + " ms for 10 iterations");
+ + " ms for 10 iterations");
for (int i = 0; i < 50000; i++)
assertEquals("i=" + i, rRef.getQuick(i), rOpt.getQuick(i));
}
@@ -134,7 +206,7 @@
long tRef = t1 - t0;
assertTrue(tOpt < tRef);
System.out.println("testSparseVectorTimesV tRef=tOpt=" + (tRef - tOpt)
- + " ms for 10 iterations");
+ + " ms for 10 iterations");
for (int i = 0; i < 50000; i++)
assertEquals("i=" + i, rRef.getQuick(i), rOpt.getQuick(i));
}