You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sm...@apache.org on 2016/03/08 21:33:16 UTC

mahout git commit: MAHOUT-1640: Better collections would significantly improve vector-operation speed, closes apache/mahout#81

Repository: mahout
Updated Branches:
  refs/heads/master a3cdff6d6 -> 284651dc0


MAHOUT-1640: Better collections would significantly improve vector-operation speed, closes apache/mahout#81


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/284651dc
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/284651dc
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/284651dc

Branch: refs/heads/master
Commit: 284651dc05c0c57ba094ba738a20db6cbe3bbcd7
Parents: a3cdff6
Author: smarthi <sm...@apache.org>
Authored: Tue Mar 8 15:32:57 2016 -0500
Committer: smarthi <sm...@apache.org>
Committed: Tue Mar 8 15:32:57 2016 -0500

----------------------------------------------------------------------
 LICENSE.txt                                     |   2 +-
 math/pom.xml                                    |   6 +
 .../mahout/math/RandomAccessSparseVector.java   | 154 +++++++++----------
 .../math/TestRandomAccessSparseVector.java      |   2 +-
 .../java/org/apache/mahout/math/VectorTest.java |  13 +-
 5 files changed, 91 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/LICENSE.txt
----------------------------------------------------------------------
diff --git a/LICENSE.txt b/LICENSE.txt
index dc49724..92ceb9b 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -2,7 +2,7 @@
 The following license applies to software from the
 Apache Software Foundation.
 It also applies to software from the Uncommons Watchmaker and Math
-projects, Google Guava software, and MongoDB.org driver software
+projects, Google Guava software, MongoDB.org driver software and fastutil.
 --------------------------------------------------------------------------
 
                                  Apache License

http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/math/pom.xml
----------------------------------------------------------------------
diff --git a/math/pom.xml b/math/pom.xml
index 0b946d5..fda8f18 100644
--- a/math/pom.xml
+++ b/math/pom.xml
@@ -146,6 +146,12 @@
     </dependency>
 
     <dependency>
+       <groupId>it.unimi.dsi</groupId>
+       <artifactId>fastutil</artifactId>
+       <version>7.0.11</version>
+    </dependency>
+
+    <dependency>
       <groupId>org.slf4j</groupId>
       <artifactId>slf4j-api</artifactId>
     </dependency>

http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
index 3efac7e..9316915 100644
--- a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
+++ b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
@@ -17,21 +17,23 @@
 
 package org.apache.mahout.math;
 
+import it.unimi.dsi.fastutil.doubles.DoubleIterator;
+import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
+import it.unimi.dsi.fastutil.ints.Int2DoubleMap.Entry;
+import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.objects.ObjectIterator;
+
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 
-import org.apache.mahout.math.list.DoubleArrayList;
-import org.apache.mahout.math.map.OpenIntDoubleHashMap;
-import org.apache.mahout.math.map.OpenIntDoubleHashMap.MapElement;
 import org.apache.mahout.math.set.AbstractSet;
 
-
 /** Implements vector that only stores non-zero doubles */
 public class RandomAccessSparseVector extends AbstractVector {
 
   private static final int INITIAL_CAPACITY = 11;
 
-  private OpenIntDoubleHashMap values;
+  private Int2DoubleOpenHashMap values;
 
   /** For serialization purposes only. */
   public RandomAccessSparseVector() {
@@ -44,7 +46,7 @@ public class RandomAccessSparseVector extends AbstractVector {
 
   public RandomAccessSparseVector(int cardinality, int initialCapacity) {
     super(cardinality);
-    values = new OpenIntDoubleHashMap(initialCapacity);
+    values = new Int2DoubleOpenHashMap(initialCapacity, .5f);
   }
 
   public RandomAccessSparseVector(Vector other) {
@@ -54,14 +56,14 @@ public class RandomAccessSparseVector extends AbstractVector {
     }
   }
 
-  private RandomAccessSparseVector(int cardinality, OpenIntDoubleHashMap values) {
+  private RandomAccessSparseVector(int cardinality, Int2DoubleOpenHashMap values) {
     super(cardinality);
     this.values = values;
   }
 
   public RandomAccessSparseVector(RandomAccessSparseVector other, boolean shallowCopy) {
     super(other.size());
-    values = shallowCopy ? other.values : (OpenIntDoubleHashMap)other.values.clone();
+    values = shallowCopy ? other.values : other.values.clone();
   }
 
   @Override
@@ -71,7 +73,7 @@ public class RandomAccessSparseVector extends AbstractVector {
 
   @Override
   public RandomAccessSparseVector clone() {
-    return new RandomAccessSparseVector(size(), (OpenIntDoubleHashMap) values.clone());
+    return new RandomAccessSparseVector(size(), values.clone());
   }
 
   @Override
@@ -123,7 +125,7 @@ public class RandomAccessSparseVector extends AbstractVector {
   public void setQuick(int index, double value) {
     invalidateCachedLength();
     if (value == 0.0) {
-      values.removeKey(index);
+      values.remove(index);
     } else {
       values.put(index, value);
     }
@@ -132,7 +134,7 @@ public class RandomAccessSparseVector extends AbstractVector {
   @Override
   public void incrementQuick(int index, double increment) {
     invalidateCachedLength();
-    values.adjustOrPutValue(index, increment, increment);
+    values.addTo( index, increment);
   }
 
 
@@ -153,14 +155,9 @@ public class RandomAccessSparseVector extends AbstractVector {
 
   @Override
   public int getNumNonZeroElements() {
-    DoubleArrayList elementValues = values.values();
-    int numMappedElements = elementValues.size();
+    final DoubleIterator iterator = values.values().iterator();
     int numNonZeros = 0;
-    for (int index = 0; index < numMappedElements; index++) {
-      if (elementValues.getQuick(index) != 0) {
-        numNonZeros++;
-      }
-    }
+    for( int i = values.size(); i-- != 0; ) if ( iterator.nextDouble() != 0 ) numNonZeros++;
     return numNonZeros;
   }
 
@@ -190,6 +187,49 @@ public class RandomAccessSparseVector extends AbstractVector {
   }
    */
 
+  private final class NonZeroIterator implements Iterator<Element> {
+    final ObjectIterator<Int2DoubleMap.Entry> fastIterator = values.int2DoubleEntrySet().fastIterator();
+    final RandomAccessElement element = new RandomAccessElement( fastIterator );
+
+    @Override
+    public boolean hasNext() {
+      return fastIterator.hasNext();
+    }
+
+    @Override
+    public Element next() {
+      if ( ! hasNext() ) throw new NoSuchElementException();
+      element.entry = fastIterator.next();
+      return element;
+    }
+}
+
+  final class RandomAccessElement implements Element {
+    Int2DoubleMap.Entry entry;
+    final ObjectIterator<Int2DoubleMap.Entry> fastIterator;
+
+    public RandomAccessElement( ObjectIterator<Entry> fastIterator ) {
+      super();
+      this.fastIterator = fastIterator;
+    }
+
+    @Override
+    public double get() {
+      return entry.getDoubleValue();
+    }
+
+    @Override
+    public int index() {
+      return entry.getIntKey();
+    }
+
+    @Override
+    public void set( double value ) {
+      invalidateCachedLength();
+      if (value == 0.0) fastIterator.remove();
+      else entry.setValue( value );
+    }
+  }
   /**
    * NOTE: this implementation reuses the Vector.Element instance for each call of next(). If you need to preserve the
    * instance, you need to make a copy of it
@@ -199,7 +239,7 @@ public class RandomAccessSparseVector extends AbstractVector {
    */
   @Override
   public Iterator<Element> iterateNonZero() {
-    return new NonDefaultIterator();
+    return new NonZeroIterator();
   }
 
   @Override
@@ -207,54 +247,30 @@ public class RandomAccessSparseVector extends AbstractVector {
     return new AllIterator();
   }
 
-  private final class NonDefaultIterator implements Iterator<Element> {
-    private final class NonDefaultElement implements Element {
-      @Override
-      public double get() {
-        return mapElement.get();
-      }
-
-      @Override
-      public int index() {
-        return mapElement.index();
-      }
-
-      @Override
-      public void set(double value) {
-        invalidateCachedLength();
-        mapElement.set(value);
-      }
-    }
-
-
-    private MapElement mapElement;
-    private final NonDefaultElement element = new NonDefaultElement();
-
-    private final Iterator<MapElement> iterator;
-
-    private NonDefaultIterator() {
-      this.iterator = values.iterator();
-    }
+  final class GeneralElement implements Element {
+    int index;
+    double value;
 
     @Override
-    public boolean hasNext() {
-      return iterator.hasNext();
+    public double get() {
+      return value;
     }
 
     @Override
-    public Element next() {
-      mapElement = iterator.next(); // This will throw an exception at the end of enumeration.
-      return element;
+    public int index() {
+      return index;
     }
 
     @Override
-    public void remove() {
-      throw new UnsupportedOperationException();
+    public void set( double value ) {
+      invalidateCachedLength();
+      if (value == 0.0) values.remove( index );
+      else values.put( index, value );
     }
-  }
+}
 
   private final class AllIterator implements Iterator<Element> {
-    private final RandomAccessElement element = new RandomAccessElement();
+    private final GeneralElement element = new GeneralElement();
 
     private AllIterator() {
       element.index = -1;
@@ -270,7 +286,7 @@ public class RandomAccessSparseVector extends AbstractVector {
       if (!hasNext()) {
         throw new NoSuchElementException();
       }
-      element.index++;
+      element.value = values.get( ++element.index );
       return element;
     }
 
@@ -279,28 +295,4 @@ public class RandomAccessSparseVector extends AbstractVector {
       throw new UnsupportedOperationException();
     }
   }
-
-  private final class RandomAccessElement implements Element {
-    int index;
-
-    @Override
-    public double get() {
-      return values.get(index);
-    }
-
-    @Override
-    public int index() {
-      return index;
-    }
-
-    @Override
-    public void set(double value) {
-      invalidateCachedLength();
-      if (value == 0.0) {
-        values.removeKey(index);
-      } else {
-        values.put(index, value);
-      }
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
----------------------------------------------------------------------
diff --git a/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java b/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
index 088bba0..ecc005d 100644
--- a/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
+++ b/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
@@ -50,7 +50,7 @@ public final class TestRandomAccessSparseVector extends AbstractVectorTest<Rando
     w.set(13, 100500.);
     w.set(19, 3.141592);
 
-    for (String token : Splitter.on(',').split(w.toString().substring(1, w.toString().length() - 2))) {
+    for (String token : Splitter.on(',').split(w.toString().substring(1, w.toString().length() - 1))) {
       String[] tokens = token.split(":");
       assertEquals(Double.parseDouble(tokens[1]), w.get(Integer.parseInt(tokens[0])), 0.0);
     }

http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/math/src/test/java/org/apache/mahout/math/VectorTest.java
----------------------------------------------------------------------
diff --git a/math/src/test/java/org/apache/mahout/math/VectorTest.java b/math/src/test/java/org/apache/mahout/math/VectorTest.java
index 67dc1e9..d355499 100644
--- a/math/src/test/java/org/apache/mahout/math/VectorTest.java
+++ b/math/src/test/java/org/apache/mahout/math/VectorTest.java
@@ -18,6 +18,7 @@
 package org.apache.mahout.math;
 
 import java.util.Collection;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 import java.util.Set;
@@ -919,17 +920,23 @@ public final class VectorTest extends MahoutTestCase {
     Iterator<Element> it = vector.nonZeroes().iterator();
     Element element = null;
     int i = 0;
+    HashSet<Integer> indexes = new HashSet<Integer>();
     while (it.hasNext()) {  // hasNext is called more often than next
       if (i % 2 == 0) {
         element = it.next();
+        indexes.add(element.index());
       }
       //noinspection ConstantConditions
-      assertEquals(element.index(), 2* (i/2));
-      assertEquals(element.get(), vector.get(2* (i/2)), 0);
+      assertEquals(element.get(), vector.get(element.index()), 0);
       ++i;
     }
     assertEquals(7, i);  // Last element is print only once.
-
+	assertEquals(4, indexes.size());
+	assertTrue(indexes.contains(0));
+	assertTrue(indexes.contains(2));
+	assertTrue(indexes.contains(4));
+	assertTrue(indexes.contains(6));
+	
     // Test all iterator.
     it = vector.all().iterator();
     element = null;