You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sm...@apache.org on 2016/03/08 21:33:16 UTC
mahout git commit: MAHOUT-1640: Better collections would significantly
improve vector-operation speed, closes apache/mahout#81
Repository: mahout
Updated Branches:
refs/heads/master a3cdff6d6 -> 284651dc0
MAHOUT-1640: Better collections would significantly improve vector-operation speed, closes apache/mahout#81
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/284651dc
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/284651dc
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/284651dc
Branch: refs/heads/master
Commit: 284651dc05c0c57ba094ba738a20db6cbe3bbcd7
Parents: a3cdff6
Author: smarthi <sm...@apache.org>
Authored: Tue Mar 8 15:32:57 2016 -0500
Committer: smarthi <sm...@apache.org>
Committed: Tue Mar 8 15:32:57 2016 -0500
----------------------------------------------------------------------
LICENSE.txt | 2 +-
math/pom.xml | 6 +
.../mahout/math/RandomAccessSparseVector.java | 154 +++++++++----------
.../math/TestRandomAccessSparseVector.java | 2 +-
.../java/org/apache/mahout/math/VectorTest.java | 13 +-
5 files changed, 91 insertions(+), 86 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/LICENSE.txt
----------------------------------------------------------------------
diff --git a/LICENSE.txt b/LICENSE.txt
index dc49724..92ceb9b 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -2,7 +2,7 @@
The following license applies to software from the
Apache Software Foundation.
It also applies to software from the Uncommons Watchmaker and Math
-projects, Google Guava software, and MongoDB.org driver software
+projects, Google Guava software, MongoDB.org driver software and fastutil.
--------------------------------------------------------------------------
Apache License
http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/math/pom.xml
----------------------------------------------------------------------
diff --git a/math/pom.xml b/math/pom.xml
index 0b946d5..fda8f18 100644
--- a/math/pom.xml
+++ b/math/pom.xml
@@ -146,6 +146,12 @@
</dependency>
<dependency>
+ <groupId>it.unimi.dsi</groupId>
+ <artifactId>fastutil</artifactId>
+ <version>7.0.11</version>
+ </dependency>
+
+ <dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
index 3efac7e..9316915 100644
--- a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
+++ b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
@@ -17,21 +17,23 @@
package org.apache.mahout.math;
+import it.unimi.dsi.fastutil.doubles.DoubleIterator;
+import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
+import it.unimi.dsi.fastutil.ints.Int2DoubleMap.Entry;
+import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.objects.ObjectIterator;
+
import java.util.Iterator;
import java.util.NoSuchElementException;
-import org.apache.mahout.math.list.DoubleArrayList;
-import org.apache.mahout.math.map.OpenIntDoubleHashMap;
-import org.apache.mahout.math.map.OpenIntDoubleHashMap.MapElement;
import org.apache.mahout.math.set.AbstractSet;
-
/** Implements vector that only stores non-zero doubles */
public class RandomAccessSparseVector extends AbstractVector {
private static final int INITIAL_CAPACITY = 11;
- private OpenIntDoubleHashMap values;
+ private Int2DoubleOpenHashMap values;
/** For serialization purposes only. */
public RandomAccessSparseVector() {
@@ -44,7 +46,7 @@ public class RandomAccessSparseVector extends AbstractVector {
public RandomAccessSparseVector(int cardinality, int initialCapacity) {
super(cardinality);
- values = new OpenIntDoubleHashMap(initialCapacity);
+ values = new Int2DoubleOpenHashMap(initialCapacity, .5f);
}
public RandomAccessSparseVector(Vector other) {
@@ -54,14 +56,14 @@ public class RandomAccessSparseVector extends AbstractVector {
}
}
- private RandomAccessSparseVector(int cardinality, OpenIntDoubleHashMap values) {
+ private RandomAccessSparseVector(int cardinality, Int2DoubleOpenHashMap values) {
super(cardinality);
this.values = values;
}
public RandomAccessSparseVector(RandomAccessSparseVector other, boolean shallowCopy) {
super(other.size());
- values = shallowCopy ? other.values : (OpenIntDoubleHashMap)other.values.clone();
+ values = shallowCopy ? other.values : other.values.clone();
}
@Override
@@ -71,7 +73,7 @@ public class RandomAccessSparseVector extends AbstractVector {
@Override
public RandomAccessSparseVector clone() {
- return new RandomAccessSparseVector(size(), (OpenIntDoubleHashMap) values.clone());
+ return new RandomAccessSparseVector(size(), values.clone());
}
@Override
@@ -123,7 +125,7 @@ public class RandomAccessSparseVector extends AbstractVector {
public void setQuick(int index, double value) {
invalidateCachedLength();
if (value == 0.0) {
- values.removeKey(index);
+ values.remove(index);
} else {
values.put(index, value);
}
@@ -132,7 +134,7 @@ public class RandomAccessSparseVector extends AbstractVector {
@Override
public void incrementQuick(int index, double increment) {
invalidateCachedLength();
- values.adjustOrPutValue(index, increment, increment);
+ values.addTo( index, increment);
}
@@ -153,14 +155,9 @@ public class RandomAccessSparseVector extends AbstractVector {
@Override
public int getNumNonZeroElements() {
- DoubleArrayList elementValues = values.values();
- int numMappedElements = elementValues.size();
+ final DoubleIterator iterator = values.values().iterator();
int numNonZeros = 0;
- for (int index = 0; index < numMappedElements; index++) {
- if (elementValues.getQuick(index) != 0) {
- numNonZeros++;
- }
- }
+ for( int i = values.size(); i-- != 0; ) if ( iterator.nextDouble() != 0 ) numNonZeros++;
return numNonZeros;
}
@@ -190,6 +187,49 @@ public class RandomAccessSparseVector extends AbstractVector {
}
*/
+ private final class NonZeroIterator implements Iterator<Element> {
+ final ObjectIterator<Int2DoubleMap.Entry> fastIterator = values.int2DoubleEntrySet().fastIterator();
+ final RandomAccessElement element = new RandomAccessElement( fastIterator );
+
+ @Override
+ public boolean hasNext() {
+ return fastIterator.hasNext();
+ }
+
+ @Override
+ public Element next() {
+ if ( ! hasNext() ) throw new NoSuchElementException();
+ element.entry = fastIterator.next();
+ return element;
+ }
+}
+
+ final class RandomAccessElement implements Element {
+ Int2DoubleMap.Entry entry;
+ final ObjectIterator<Int2DoubleMap.Entry> fastIterator;
+
+ public RandomAccessElement( ObjectIterator<Entry> fastIterator ) {
+ super();
+ this.fastIterator = fastIterator;
+ }
+
+ @Override
+ public double get() {
+ return entry.getDoubleValue();
+ }
+
+ @Override
+ public int index() {
+ return entry.getIntKey();
+ }
+
+ @Override
+ public void set( double value ) {
+ invalidateCachedLength();
+ if (value == 0.0) fastIterator.remove();
+ else entry.setValue( value );
+ }
+ }
/**
* NOTE: this implementation reuses the Vector.Element instance for each call of next(). If you need to preserve the
* instance, you need to make a copy of it
@@ -199,7 +239,7 @@ public class RandomAccessSparseVector extends AbstractVector {
*/
@Override
public Iterator<Element> iterateNonZero() {
- return new NonDefaultIterator();
+ return new NonZeroIterator();
}
@Override
@@ -207,54 +247,30 @@ public class RandomAccessSparseVector extends AbstractVector {
return new AllIterator();
}
- private final class NonDefaultIterator implements Iterator<Element> {
- private final class NonDefaultElement implements Element {
- @Override
- public double get() {
- return mapElement.get();
- }
-
- @Override
- public int index() {
- return mapElement.index();
- }
-
- @Override
- public void set(double value) {
- invalidateCachedLength();
- mapElement.set(value);
- }
- }
-
-
- private MapElement mapElement;
- private final NonDefaultElement element = new NonDefaultElement();
-
- private final Iterator<MapElement> iterator;
-
- private NonDefaultIterator() {
- this.iterator = values.iterator();
- }
+ final class GeneralElement implements Element {
+ int index;
+ double value;
@Override
- public boolean hasNext() {
- return iterator.hasNext();
+ public double get() {
+ return value;
}
@Override
- public Element next() {
- mapElement = iterator.next(); // This will throw an exception at the end of enumeration.
- return element;
+ public int index() {
+ return index;
}
@Override
- public void remove() {
- throw new UnsupportedOperationException();
+ public void set( double value ) {
+ invalidateCachedLength();
+ if (value == 0.0) values.remove( index );
+ else values.put( index, value );
}
- }
+}
private final class AllIterator implements Iterator<Element> {
- private final RandomAccessElement element = new RandomAccessElement();
+ private final GeneralElement element = new GeneralElement();
private AllIterator() {
element.index = -1;
@@ -270,7 +286,7 @@ public class RandomAccessSparseVector extends AbstractVector {
if (!hasNext()) {
throw new NoSuchElementException();
}
- element.index++;
+ element.value = values.get( ++element.index );
return element;
}
@@ -279,28 +295,4 @@ public class RandomAccessSparseVector extends AbstractVector {
throw new UnsupportedOperationException();
}
}
-
- private final class RandomAccessElement implements Element {
- int index;
-
- @Override
- public double get() {
- return values.get(index);
- }
-
- @Override
- public int index() {
- return index;
- }
-
- @Override
- public void set(double value) {
- invalidateCachedLength();
- if (value == 0.0) {
- values.removeKey(index);
- } else {
- values.put(index, value);
- }
- }
- }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
----------------------------------------------------------------------
diff --git a/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java b/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
index 088bba0..ecc005d 100644
--- a/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
+++ b/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
@@ -50,7 +50,7 @@ public final class TestRandomAccessSparseVector extends AbstractVectorTest<Rando
w.set(13, 100500.);
w.set(19, 3.141592);
- for (String token : Splitter.on(',').split(w.toString().substring(1, w.toString().length() - 2))) {
+ for (String token : Splitter.on(',').split(w.toString().substring(1, w.toString().length() - 1))) {
String[] tokens = token.split(":");
assertEquals(Double.parseDouble(tokens[1]), w.get(Integer.parseInt(tokens[0])), 0.0);
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/284651dc/math/src/test/java/org/apache/mahout/math/VectorTest.java
----------------------------------------------------------------------
diff --git a/math/src/test/java/org/apache/mahout/math/VectorTest.java b/math/src/test/java/org/apache/mahout/math/VectorTest.java
index 67dc1e9..d355499 100644
--- a/math/src/test/java/org/apache/mahout/math/VectorTest.java
+++ b/math/src/test/java/org/apache/mahout/math/VectorTest.java
@@ -18,6 +18,7 @@
package org.apache.mahout.math;
import java.util.Collection;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Set;
@@ -919,17 +920,23 @@ public final class VectorTest extends MahoutTestCase {
Iterator<Element> it = vector.nonZeroes().iterator();
Element element = null;
int i = 0;
+ HashSet<Integer> indexes = new HashSet<Integer>();
while (it.hasNext()) { // hasNext is called more often than next
if (i % 2 == 0) {
element = it.next();
+ indexes.add(element.index());
}
//noinspection ConstantConditions
- assertEquals(element.index(), 2* (i/2));
- assertEquals(element.get(), vector.get(2* (i/2)), 0);
+ assertEquals(element.get(), vector.get(element.index()), 0);
++i;
}
assertEquals(7, i); // Last element is print only once.
-
+ assertEquals(4, indexes.size());
+ assertTrue(indexes.contains(0));
+ assertTrue(indexes.contains(2));
+ assertTrue(indexes.contains(4));
+ assertTrue(indexes.contains(6));
+
// Test all iterator.
it = vector.all().iterator();
element = null;