You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/06/09 16:31:19 UTC

svn commit: r783014 - in /lucene/mahout/trunk/core: pom.xml src/main/java/org/apache/mahout/matrix/AbstractVector.java src/main/java/org/apache/mahout/matrix/Vector.java src/test/java/org/apache/mahout/matrix/VectorTest.java

Author: gsingers
Date: Tue Jun  9 14:31:19 2009
New Revision: 783014

URL: http://svn.apache.org/viewvc?rev=783014&view=rev
Log:
MAHOUT-130: add normalization factors to Vector

Modified:
    lucene/mahout/trunk/core/pom.xml
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java

Modified: lucene/mahout/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/pom.xml?rev=783014&r1=783013&r2=783014&view=diff
==============================================================================
--- lucene/mahout/trunk/core/pom.xml (original)
+++ lucene/mahout/trunk/core/pom.xml Tue Jun  9 14:31:19 2009
@@ -537,12 +537,12 @@
     <dependency>
       <groupId>org.apache.lucene</groupId>
       <artifactId>lucene-analyzers</artifactId>
-      <version>2.3.2</version>
+      <version>2.9-SNAPSHOT</version>
     </dependency>
     <dependency>
       <groupId>org.apache.lucene</groupId>
       <artifactId>lucene-core</artifactId>
-      <version>2.3.2</version>
+      <version>2.9-SNAPSHOT</version>
     </dependency>
     <dependency>
       <groupId>org.apache.mahout.commons</groupId>

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java?rev=783014&r1=783013&r2=783014&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java Tue Jun  9 14:31:19 2009
@@ -116,6 +116,61 @@
     return divide(divSq);
   }
 
+  /**
+   * Returns the result of dividing this vector by its L_power norm, taken over element magnitudes.
+   *
+   * @throws IllegalArgumentException if power is negative or NaN
+   */
+  @Override
+  public Vector normalize(double power) {
+    // The negated test also rejects NaN, which used to slip through to the "Unreachable" branch.
+    if (!(power >= 0)) {
+      throw new IllegalArgumentException("Power must be >= 0");
+    }
+    if (power == 2) {
+      return normalize();
+    }
+    // Norms use |x_i|: signed values made L_1 cancel out, L_infinity miss negatives and L_p yield NaN.
+    double val = 0;
+    for (int i = 0; i < cardinality(); i++) {
+      double abs = Math.abs(getQuick(i));
+      if (Double.isInfinite(power)) {
+        val = Math.max(val, abs); // L_infinity: largest magnitude
+      } else if (power == 0) {
+        val += abs != 0 ? 1 : 0; // L_0: number of non-zero elements
+      } else {
+        val += Math.pow(abs, power); // L_1 and general L_p: sum of |x_i|^power
+      }
+    }
+    // Only the summed powers still need their power-th root (a no-op for power == 1).
+    boolean needsRoot = power > 0 && !Double.isInfinite(power);
+    return divide(needsRoot ? Math.pow(val, 1.0 / power) : val);
+  }
+
+  /** Returns the largest element value, or negative infinity when there are no elements to compare. */
+  @Override
+  public double maxValue() {
+    double result = Double.NEGATIVE_INFINITY; // not Double.MIN_VALUE: that is the smallest positive double
+    for (int i = 0; i < cardinality(); i++) {
+      result = Math.max(result, getQuick(i));
+    }
+    return result;
+  }
+  /** Returns the index of the largest element, or -1 when no element exceeds negative infinity. */
+  @Override
+  public int maxValueIndex() {
+    int result = -1;
+    double max = Double.NEGATIVE_INFINITY; // not Double.MIN_VALUE: that is the smallest positive double
+    for (int i = 0; i < cardinality(); i++) {
+      double tmp = getQuick(i);
+      if (tmp > max) {
+        max = tmp;
+        result = i;
+      }
+    }
+    return result;
+  }
+
   @Override
   public Vector plus(double x) {
     Vector result = copy();

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java?rev=783014&r1=783013&r2=783014&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Vector.java Tue Jun  9 14:31:19 2009
@@ -223,13 +223,40 @@
   Vector minus(Vector x);
 
   /**
-   * Return a new matrix containing the normalized values of the recipient
+   * Return a new Vector containing the normalized (L_2 norm) values of the recipient
    *
    * @return a new Vector
    */
   Vector normalize();
 
   /**
+   * Return a new Vector containing the normalized (L_power norm) values of the recipient.
+   * <p/>
+   * See http://en.wikipedia.org/wiki/Lp_space
+   * <p/>
+   * Technically, when 0 &lt; power &lt; 1, we don't have a norm, just a metric, but we'll overload this here.
+   * <p/>
+   * Also supports power == 0 (number of non-zero elements) and power = {@link Double#POSITIVE_INFINITY} (max element).  Again, see the Wikipedia page for more info
+   * 
+   *
+   * @param power The power to use.  Must be >= 0.  May also be {@link Double#POSITIVE_INFINITY}.  See the Wikipedia link for more on this.
+   * @return a new Vector
+   */
+  Vector normalize(double power);
+
+  /**
+   * Find the largest element value in this Vector.
+   * @return The maximum value in the Vector
+   */
+  double maxValue();
+
+  /**
+   * Find the position of the largest element value in this Vector.
+   * @return The index of the maximum value
+   */
+  int maxValueIndex();
+
+  /**
    * Return a new matrix containing the sum of each value of the recipient and
    * the argument
    *

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java?rev=783014&r1=783013&r2=783014&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java Tue Jun  9 14:31:19 2009
@@ -17,11 +17,11 @@
 
 package org.apache.mahout.matrix;
 
+import junit.framework.TestCase;
+
 import java.util.Date;
 import java.util.Random;
 
-import junit.framework.TestCase;
-
 public class VectorTest extends TestCase {
 
   public VectorTest(String s) {
@@ -45,6 +45,78 @@
     assertEquals(result + " does not equal: " + 32, 32.0, result);
   }
 
+  /** Checks normalize() and normalize(power) for powers 2, 1, 3, infinity and 0, plus rejection of -1. */
+  public void testNormalize() throws Exception {
+    SparseVector vec1 = new SparseVector(3);
+
+    vec1.setQuick(0, 1);
+    vec1.setQuick(1, 2);
+    vec1.setQuick(2, 3);
+
+    // L_2: divide by sqrt(1 + 4 + 9) = sqrt(14)
+    Vector norm = vec1.normalize();
+    assertNotNull("norm1 is null and it shouldn't be", norm);
+    Vector expected = new SparseVector(3);
+    expected.setQuick(0, 0.2672612419124244);
+    expected.setQuick(1, 0.5345224838248488);
+    expected.setQuick(2, 0.8017837257372732);
+    assertTrue("norm is not equal to expected", norm.equals(expected));
+
+    // An explicit power of 2 must agree with the no-arg form
+    norm = vec1.normalize(2);
+    assertTrue("norm is not equal to expected", norm.equals(expected));
+
+    // L_1: divide by 1 + 2 + 3 = 6
+    norm = vec1.normalize(1);
+    expected.setQuick(0, 1.0/6);
+    expected.setQuick(1, 2.0/6);
+    expected.setQuick(2, 3.0/6);
+    assertTrue("norm is not equal to expected", norm.equals(expected));
+
+    // L_3: divide by the cube root of 1 + 8 + 27 = 36
+    norm = vec1.normalize(3);
+    double cube = Math.pow(36, 1.0/3);
+    expected = vec1.divide(cube);
+    assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected));
+
+    // L_infinity: the max is 3, so we divide by that.
+    norm = vec1.normalize(Double.POSITIVE_INFINITY);
+    expected.setQuick(0, 1.0/3);
+    expected.setQuick(1, 2.0/3);
+    expected.setQuick(2, 3.0/3);
+    assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected));
+
+    // L_0: all 3 elements are non-zero, so we divide by that count.
+    norm = vec1.normalize(0);
+    expected.setQuick(0, 1.0/3);
+    expected.setQuick(1, 2.0/3);
+    expected.setQuick(2, 3.0/3);
+    assertTrue("norm: " + norm.asFormatString() + " is not equal to expected: " + expected.asFormatString(), norm.equals(expected));
+
+    // Negative powers are rejected
+    try {
+      vec1.normalize(-1);
+      fail("normalize(-1) should have thrown IllegalArgumentException");
+    } catch (IllegalArgumentException e) {
+      //expected
+    }
+  }
+
+  /** Verifies maxValue() and maxValueIndex() on the vector (1, 3, 2). */
+  public void testMax() throws Exception {
+    SparseVector sample = new SparseVector(3);
+    sample.setQuick(0, 1);
+    sample.setQuick(1, 3);
+    sample.setQuick(2, 2);
+
+    // The largest element is 3 and it sits at index 1.
+    double largest = sample.maxValue();
+    assertTrue(largest + " does not equal: " + 3, largest == 3);
+
+    int position = sample.maxValueIndex();
+    assertTrue(position + " does not equal: " + 1, position == 1);
+  }
+
   public void testDenseVector() throws Exception {
     DenseVector vec1 = new DenseVector(3);
     DenseVector vec2 = new DenseVector(3);
@@ -75,12 +147,12 @@
   }
 
   public void testEnumeration() throws Exception {
-    double[] apriori = { 0, 1, 2, 3, 4 };
+    double[] apriori = {0, 1, 2, 3, 4};
 
-    doTestEnumeration(apriori, new VectorView(new DenseVector(new double[] {
-        -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }), 2, 5));
+    doTestEnumeration(apriori, new VectorView(new DenseVector(new double[]{
+            -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), 2, 5));
 
-    doTestEnumeration(apriori, new DenseVector(new double[] { 0, 1, 2, 3, 4 }));
+    doTestEnumeration(apriori, new DenseVector(new double[]{0, 1, 2, 3, 4}));
 
     SparseVector sparse = new SparseVector(5);
     sparse.set(0, 0);
@@ -110,7 +182,7 @@
     long tRef = t1 - t0;
     assertTrue(tOpt < tRef);
     System.out.println("testSparseVectorTimesX tRef=tOpt=" + (tRef - tOpt)
-        + " ms for 10 iterations");
+            + " ms for 10 iterations");
     for (int i = 0; i < 50000; i++)
       assertEquals("i=" + i, rRef.getQuick(i), rOpt.getQuick(i));
   }
@@ -134,7 +206,7 @@
     long tRef = t1 - t0;
     assertTrue(tOpt < tRef);
     System.out.println("testSparseVectorTimesV tRef=tOpt=" + (tRef - tOpt)
-        + " ms for 10 iterations");
+            + " ms for 10 iterations");
     for (int i = 0; i < 50000; i++)
       assertEquals("i=" + i, rRef.getQuick(i), rOpt.getQuick(i));
   }