You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hama.apache.org by yx...@apache.org on 2014/04/11 04:01:34 UTC
svn commit: r1586531 - in /hama/trunk: CHANGES.txt
commons/src/main/java/org/apache/hama/commons/math/SparseDoubleVector.java
commons/src/test/java/org/apache/hama/commons/math/TestSparseDoubleVector.java
Author: yxjiang
Date: Fri Apr 11 02:01:33 2014
New Revision: 1586531
URL: http://svn.apache.org/r1586531
Log:
HAMA-863: Implement SparseDoubleVector
Added:
hama/trunk/commons/src/main/java/org/apache/hama/commons/math/SparseDoubleVector.java
hama/trunk/commons/src/test/java/org/apache/hama/commons/math/TestSparseDoubleVector.java
Modified:
hama/trunk/CHANGES.txt
Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1586531&r1=1586530&r2=1586531&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Fri Apr 11 02:01:33 2014
@@ -3,6 +3,7 @@ Hama Change Log
Release 0.7.0 (unreleased changes)
NEW FEATURES
+ HAMA-863: Implement SparseVector (Yexi Jiang)
BUG FIXES
Added: hama/trunk/commons/src/main/java/org/apache/hama/commons/math/SparseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/commons/src/main/java/org/apache/hama/commons/math/SparseDoubleVector.java?rev=1586531&view=auto
==============================================================================
--- hama/trunk/commons/src/main/java/org/apache/hama/commons/math/SparseDoubleVector.java (added)
+++ hama/trunk/commons/src/main/java/org/apache/hama/commons/math/SparseDoubleVector.java Fri Apr 11 02:01:33 2014
@@ -0,0 +1,817 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.commons.math;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+
+/**
+ * The implementation of SparseVector.
+ *
+ */
+public class SparseDoubleVector implements DoubleVector {
+
+ private int dimension;
+ private double defaultValue; // 0 by default
+ private Map<Integer, Double> elements; // the non-default value
+
+ /**
+ * Initialize a sparse vector with given dimension with default value 0.
+ *
+ * @param dimension
+ */
+ public SparseDoubleVector(int dimension) {
+ this(dimension, 0.0);
+ }
+
+ /**
+ * Initialize a sparse vector with given dimension and given default value 0.
+ *
+ * @param dimension
+ * @param defaultValue
+ */
+ public SparseDoubleVector(int dimension, double defaultValue) {
+ this.elements = new HashMap<Integer, Double>();
+ this.defaultValue = defaultValue;
+ this.dimension = dimension;
+ }
+
+ /**
+ * Get the value of a given index.
+ */
+ @Override
+ public double get(int index) {
+ Preconditions.checkArgument(index < this.dimension,
+ "Index out of max allowd dimension of sparse vector.");
+ Double val = this.elements.get(index);
+ if (val == null) {
+ val = this.defaultValue;
+ }
+ return val;
+ }
+
+ /**
+ * Get the dimension of the vector.
+ */
+ @Override
+ public int getLength() {
+ return this.getDimension();
+ }
+
+ /**
+ * Get the dimension of the vector.
+ */
+ @Override
+ public int getDimension() {
+ return this.dimension;
+ }
+
+ /**
+ * Set the value of a given index.
+ */
+ @Override
+ public void set(int index, double value) {
+ Preconditions.checkArgument(index < this.dimension, String.format(
+ "Index %d out of max allowd dimension %d of sparse vector.", index,
+ this.dimension));
+ if (value == this.defaultValue) {
+ this.elements.remove(index);
+ } else {
+ this.elements.put(index, value);
+ }
+ }
+
+ /**
+ * Apply a function to the copy of current vector and return the result.
+ */
+ @Override
+ public DoubleVector applyToElements(DoubleFunction func) {
+ SparseDoubleVector newVec = new SparseDoubleVector(this.dimension,
+ func.apply(this.defaultValue));
+ // apply function to all non-empty entries
+ for (Map.Entry<Integer, Double> entry : this.elements.entrySet()) {
+ newVec.elements.put(entry.getKey(), func.apply(entry.getValue()));
+ }
+
+ return newVec;
+ }
+
+ /**
+ * Apply a binary function to the copy of current vector with another vector
+ * and then return the result.
+ */
+ @Override
+ public DoubleVector applyToElements(DoubleVector other,
+ DoubleDoubleFunction func) {
+
+ Preconditions.checkArgument(this.getDimension() == other.getDimension(),
+ "Dimension of two vectors should be equal.");
+
+ double otherDefaultValue = 0.0;
+ if (other instanceof SparseDoubleVector) {
+ otherDefaultValue = ((SparseDoubleVector) other).defaultValue;
+ }
+ double newDefaultValue = this.defaultValue;
+ if (other instanceof SparseDoubleVector) {
+ newDefaultValue = func.apply(this.defaultValue, otherDefaultValue);
+ }
+
+ SparseDoubleVector newVec = new SparseDoubleVector(this.dimension,
+ newDefaultValue);
+
+ Iterator<DoubleVectorElement> thisItr = this.iterateNonDefault();
+ Iterator<DoubleVectorElement> otherItr = other.iterateNonDefault();
+
+ DoubleVectorElement thisCur = null;
+ if (thisItr.hasNext()) {
+ thisCur = thisItr.next();
+ }
+ DoubleVectorElement otherCur = null;
+ if (otherItr.hasNext()) {
+ otherCur = otherItr.next();
+ }
+
+ while (thisCur != null || otherCur != null) {
+ if (thisCur == null) { // the iterator of current vector reaches the end
+ newVec.set(otherCur.getIndex(),
+ func.apply(this.defaultValue, otherCur.getValue()));
+ if (newVec.get(otherCur.getIndex()) == newVec.defaultValue) {
+ // remove if the value equals the default value
+ newVec.elements.remove(otherCur.getIndex());
+ }
+ if (otherItr.hasNext()) {
+ otherCur = otherItr.next();
+ } else {
+ otherCur = null;
+ }
+ } else if (otherCur == null) {
+ newVec.set(thisCur.getIndex(),
+ func.apply(thisCur.getValue(), otherDefaultValue));
+ if (newVec.get(thisCur.getIndex()) == newVec.defaultValue) {
+ // remove if the value equals the default value
+ newVec.elements.remove(thisCur.getIndex());
+ }
+ if (thisItr.hasNext()) {
+ thisCur = thisItr.next();
+ } else {
+ thisCur = null;
+ }
+ } else {
+ int curIdx = 0;
+ if (thisCur.getIndex() < otherCur.getIndex()) {
+ curIdx = thisCur.getIndex();
+ newVec.set(curIdx, func.apply(thisCur.getValue(), otherDefaultValue));
+ if (thisItr.hasNext()) {
+ thisCur = thisItr.next();
+ } else {
+ thisCur = null;
+ }
+ } else if (thisCur.getIndex() > otherCur.getIndex()) {
+ curIdx = otherCur.getIndex();
+ newVec
+ .set(curIdx, func.apply(this.defaultValue, otherCur.getValue()));
+ if (otherItr.hasNext()) {
+ otherCur = otherItr.next();
+ } else {
+ otherCur = null;
+ }
+ } else {
+ curIdx = thisCur.getIndex();
+ newVec.set(curIdx,
+ func.apply(thisCur.getValue(), otherCur.getValue()));
+ if (thisItr.hasNext()) {
+ thisCur = thisItr.next();
+ } else {
+ thisCur = null;
+ }
+ if (otherItr.hasNext()) {
+ otherCur = otherItr.next();
+ } else {
+ otherCur = null;
+ }
+ }
+ if (newVec.get(curIdx) == newVec.defaultValue) {
+ // remove if the value equals the default value
+ newVec.elements.remove(curIdx);
+ }
+ }
+ }
+
+ return newVec;
+ }
+
+ /**
+ * Add another vector.
+ */
+ @Override
+ public DoubleVector addUnsafe(DoubleVector vector) {
+ return this.applyToElements(vector, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x1, double x2) {
+ return x1 + x2;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Add another vector to the copy of the current vector and return the result.
+ */
+ @Override
+ public DoubleVector add(DoubleVector vector) {
+ Preconditions.checkArgument(this.dimension == vector.getDimension(),
+ "Dimensions of two vectors are not the same.");
+ return this.addUnsafe(vector);
+ }
+
+ /**
+ * Add a scalar.
+ */
+ @Override
+ public DoubleVector add(double scalar) {
+ final double val = scalar;
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return value + val;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Subtract a vector from current vector.
+ */
+ @Override
+ public DoubleVector subtractUnsafe(DoubleVector vector) {
+ return this.applyToElements(vector, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x1, double x2) {
+ return x1 - x2;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ return 0;
+ }
+ });
+ }
+
+ /**
+ * Subtract a vector from current vector.
+ */
+ @Override
+ public DoubleVector subtract(DoubleVector vector) {
+ Preconditions.checkArgument(this.dimension == vector.getDimension(),
+ "Dimensions of two vector are not the same.");
+ return this.subtractUnsafe(vector);
+ }
+
+ /**
+ * Subtract a scalar from a copy of the current vector.
+ */
+ @Override
+ public DoubleVector subtract(double scalar) {
+ final double val = scalar;
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return value - val;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Subtract a scalar from a copy of the current vector and return the result.
+ */
+ @Override
+ public DoubleVector subtractFrom(double scalar) {
+ final double val = scalar;
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return val - value;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Multiply a copy of the current vector by a scalar and return the result.
+ */
+ @Override
+ public DoubleVector multiply(double scalar) {
+ final double val = scalar;
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return val * value;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Multiply a copy of the current vector with another vector and return the
+ * result.
+ */
+ @Override
+ public DoubleVector multiplyUnsafe(DoubleVector vector) {
+ return this.applyToElements(vector, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x1, double x2) {
+ return x1 * x2;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Multiply a copy of the current vector with another vector and return the
+ * result.
+ */
+ @Override
+ public DoubleVector multiply(DoubleVector vector) {
+ Preconditions.checkArgument(this.dimension == vector.getDimension(),
+ "Dimensions of two vectors are not the same.");
+ return this.multiplyUnsafe(vector);
+ }
+
+ /**
+ * Multiply a copy of the current vector with a matrix and return the result,
+ * i.e. r = v * M.
+ */
+ @Override
+ public DoubleVector multiply(DoubleMatrix matrix) {
+ Preconditions
+ .checkArgument(this.dimension == matrix.getColumnCount(),
+ "The dimension of vector does not equal to the dimension of the matrix column.");
+ return this.multiplyUnsafe(matrix);
+ }
+
+ /**
+ * Multiply a copy of the current vector with a matrix and return the result,
+ * i.e. r = v * M.
+ */
+ @Override
+ public DoubleVector multiplyUnsafe(DoubleMatrix matrix) {
+ // currently the result is a dense double vector
+ DoubleVector res = new DenseDoubleVector(matrix.getColumnCount());
+ int columns = matrix.getColumnCount();
+ for (int i = 0; i < columns; ++i) {
+ res.set(i, this.dotUnsafe(matrix.getColumnVector(i)));
+ }
+ return res;
+ }
+
+ /**
+ * Divide a copy of the current vector by a scala.
+ */
+ @Override
+ public DoubleVector divide(double scalar) {
+ Preconditions.checkArgument(scalar != 0, "Scalar cannot be 0.");
+ final double factor = scalar;
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return value / factor;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#divideFrom(double)
+ */
+ @Override
+ public DoubleVector divideFrom(double scalar) {
+ Preconditions.checkArgument(scalar != 0, "Scalar cannot be 0.");
+ final double factor = scalar;
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return factor / value; // value can be 0!
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Apply the power to each element of a copy of the current vector.
+ */
+ @Override
+ public DoubleVector pow(int x) {
+ final int p = x;
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return Math.pow(value, p);
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Apply the abs elementwise to a copy of current vector.
+ */
+ @Override
+ public DoubleVector abs() {
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return Math.abs(value);
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Apply the sqrt elementwise to a copy of current vector.
+ */
+ @Override
+ public DoubleVector sqrt() {
+ return this.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return Math.sqrt(value);
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+ }
+
+ /**
+ * Get the sum of all the elements.
+ */
+ @Override
+ public double sum() {
+ double sum = 0.0;
+ Iterator<DoubleVectorElement> itr = this.iterate();
+ while (itr.hasNext()) {
+ sum += itr.next().getValue();
+ }
+ return sum;
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see
+ * org.apache.hama.commons.math.DoubleVector#dotUnsafe(org.apache.hama.commons
+ * .math.DoubleVector)
+ */
+ @Override
+ public double dotUnsafe(DoubleVector vector) {
+ return this.multiplyUnsafe(vector).sum();
+ }
+
+ /**
+ * Apply dot onto a copy of current vector and another vector.
+ */
+ @Override
+ public double dot(DoubleVector vector) {
+ Preconditions.checkArgument(this.dimension == vector.getDimension(),
+ "Dimensions of two vectors are not equal.");
+ return this.dotUnsafe(vector);
+ }
+
+ /**
+ * Obtain a sub-vector from the current vector.
+ */
+ @Override
+ public DoubleVector slice(int length) {
+ Preconditions.checkArgument(length >= 0 && length < this.dimension,
+ String.format("length must be in range [0, %d).", this.dimension));
+ return this.sliceUnsafe(length);
+ }
+
+ /**
+ * Obtain a sub-vector from the current vector.
+ */
+ @Override
+ public DoubleVector sliceUnsafe(int length) {
+ DoubleVector newVector = new SparseDoubleVector(length, this.defaultValue);
+ for (Map.Entry<Integer, Double> entry : this.elements.entrySet()) {
+ if (entry.getKey() >= length) {
+ continue;
+ }
+ newVector.set(entry.getKey(), entry.getValue());
+ }
+ return newVector;
+ }
+
+ /**
+ * Obtain a sub-vector of the current vector
+ *
+ * @param start inclusive
+ * @param end inclusive
+ */
+ @Override
+ public DoubleVector slice(int start, int end) {
+ Preconditions.checkArgument(start >= 0 && end < this.dimension, String
+ .format("Start and end range should be in [0, %d].", this.dimension));
+ return this.sliceUnsafe(start, end);
+ }
+
+ /**
+ * @param start inclusive
+ * @param end inclusive
+ */
+ @Override
+ public DoubleVector sliceUnsafe(int start, int end) {
+ SparseDoubleVector slicedVec = new SparseDoubleVector(end - start + 1,
+ this.defaultValue);
+ for (Map.Entry<Integer, Double> entry : this.elements.entrySet()) {
+ if (entry.getKey() >= start && entry.getKey() <= end) {
+ slicedVec.elements.put(entry.getKey() - start, entry.getValue());
+ }
+ if (entry.getKey() > end) {
+ continue;
+ }
+ }
+ return slicedVec;
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#max()
+ */
+ @Override
+ public double max() {
+ double max = this.defaultValue;
+ for (Map.Entry<Integer, Double> entry : this.elements.entrySet()) {
+ max = Math.max(max, entry.getValue());
+ }
+ return max;
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#min()
+ */
+ @Override
+ public double min() {
+ double min = this.defaultValue;
+ for (Map.Entry<Integer, Double> entry : this.elements.entrySet()) {
+ min = Math.min(min, entry.getValue());
+ }
+ return min;
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#toArray()
+ */
+ @Override
+ public double[] toArray() {
+ throw new UnsupportedOperationException(
+ "SparseDoubleVector does not support toArray() method.");
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#deepCopy()
+ */
+ @Override
+ public DoubleVector deepCopy() {
+ SparseDoubleVector copy = new SparseDoubleVector(this.dimension);
+ copy.elements = new HashMap<Integer, Double>(this.elements.size());
+ copy.elements.putAll(this.elements);
+ return copy;
+ }
+
+ /**
+ * Generate the iterator that iterates the non-default values.
+ */
+ @Override
+ public Iterator<DoubleVectorElement> iterateNonDefault() {
+ return new NonDefaultIterator();
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#iterate()
+ */
+ @Override
+ public Iterator<DoubleVectorElement> iterate() {
+ return new DefaultIterator();
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#isSparse()
+ */
+ @Override
+ public boolean isSparse() {
+ return true;
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#isNamed()
+ */
+ @Override
+ public boolean isNamed() {
+ return false;
+ }
+
+ /*
+ * (non-Javadoc)
+ * @see org.apache.hama.commons.math.DoubleVector#getName()
+ */
+ @Override
+ public String getName() {
+ return null;
+ }
+
+ /**
+ * Non-zero iterator for vector elements.
+ */
+ private final class NonDefaultIterator extends
+ AbstractIterator<DoubleVectorElement> {
+ private final DoubleVectorElement element = new DoubleVectorElement();
+
+ private final int entryDimension;
+ private final Map<Integer, Double> entries;
+ private final double defaultV = defaultValue;
+ private int currentIndex = 0;
+
+ public NonDefaultIterator() {
+ this.entryDimension = dimension;
+ this.entries = elements;
+ }
+
+ @Override
+ protected DoubleVectorElement computeNext() {
+ DoubleVectorElement elem = getNext();
+ // skip the entries with default values
+ while (elem != null && elem.getValue() == this.defaultV) {
+ elem = getNext();
+ }
+ return elem;
+ }
+
+ private DoubleVectorElement getNext() {
+ if (currentIndex < entryDimension) {
+ Double value = entries.get(currentIndex);
+ element.setIndex(currentIndex);
+ if (value == null) {
+ element.setValue(defaultV);
+ } else {
+ element.setValue(value);
+ }
+ ++currentIndex;
+ return element;
+ } else {
+ return endOfData();
+ }
+ }
+ }
+
+ private final class DefaultIterator extends
+ AbstractIterator<DoubleVectorElement> {
+
+ private final DoubleVectorElement element = new DoubleVectorElement();
+
+ private final int entryDimension;
+ private final Map<Integer, Double> entries;
+ private final double defaultV = defaultValue;
+ private int currentIndex = 0;
+
+ public DefaultIterator() {
+ this.entryDimension = dimension;
+ this.entries = elements;
+ }
+
+ @Override
+ protected DoubleVectorElement computeNext() {
+ if (currentIndex < entryDimension) {
+ Double value = entries.get(currentIndex);
+ element.setIndex(currentIndex);
+ if (value == null) {
+ element.setValue(defaultV);
+ } else {
+ element.setValue(value);
+ }
+ ++currentIndex;
+ return element;
+ } else {
+ return endOfData();
+ }
+ }
+
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ Iterator<DoubleVectorElement> itr = this.iterate();
+ sb.append('{');
+ while (itr.hasNext()) {
+ DoubleVectorElement elem = itr.next();
+ sb.append(String.format("%s: %f, ", elem.getIndex(), elem.getValue()));
+ }
+ sb.replace(sb.length() - 2, sb.length() - 1, "}");
+ return sb.toString();
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ long temp;
+ temp = Double.doubleToLongBits(defaultValue);
+ result = prime * result + (int) (temp ^ (temp >>> 32));
+ result = prime * result + dimension;
+ result = prime * result + ((elements == null) ? 0 : elements.hashCode());
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof DoubleVector) {
+ DoubleVector otherVec = (DoubleVector) other;
+ if (this.dimension != otherVec.getDimension()) {
+ return false;
+ }
+ // check non-default entries first
+ for (Map.Entry<Integer, Double> entry : this.elements.entrySet()) {
+ if (Math.abs(entry.getValue() - otherVec.get(entry.getKey())) > 0.000001) {
+ return false;
+ }
+ }
+ // check default values
+ for (int i = 0; i < this.dimension; ++i) {
+ if (Math.abs(this.get(i) - otherVec.get(i)) > 0.000001) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+ return false;
+ }
+}
Added: hama/trunk/commons/src/test/java/org/apache/hama/commons/math/TestSparseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/commons/src/test/java/org/apache/hama/commons/math/TestSparseDoubleVector.java?rev=1586531&view=auto
==============================================================================
--- hama/trunk/commons/src/test/java/org/apache/hama/commons/math/TestSparseDoubleVector.java (added)
+++ hama/trunk/commons/src/test/java/org/apache/hama/commons/math/TestSparseDoubleVector.java Fri Apr 11 02:01:33 2014
@@ -0,0 +1,393 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hama.commons.math;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import java.util.Iterator;
+
+import org.apache.hama.commons.math.DoubleVector.DoubleVectorElement;
+import org.junit.Test;
+
+/**
+ * The test cases of {@link SparseDoubleVector}.
+ *
+ */
+public class TestSparseDoubleVector {
+
+
+
+ @Test
+ public void testBasic() {
+ DoubleVector v1 = new SparseDoubleVector(10);
+ for (int i = 0; i < 10; ++i) {
+ assertEquals(v1.get(i), 0.0, 0.000001);
+ }
+
+ DoubleVector v2 = new SparseDoubleVector(10, 2.5);
+ for (int i = 0; i < 10; ++i) {
+ assertEquals(v2.get(i), 2.5, 0.000001);
+ }
+
+ assertEquals(v1.getDimension(), 10);
+ assertEquals(v2.getLength(), 10);
+
+ v1.set(5, 2);
+ assertEquals(v1.get(5), 2, 0.000001);
+ }
+
+ @Test
+ public void testIterators() {
+ DoubleVector v1 = new SparseDoubleVector(10, 5.5);
+ Iterator<DoubleVectorElement> itr1 = v1.iterate();
+ int idx1 = 0;
+ while (itr1.hasNext()) {
+ DoubleVectorElement elem = itr1.next();
+ assertEquals(idx1++, elem.getIndex());
+ assertEquals(5.5, elem.getValue(), 0.000001);
+ }
+
+ v1.set(2, 20);
+ v1.set(6, 30);
+
+ Iterator<DoubleVectorElement> itr2 = v1.iterateNonDefault();
+ DoubleVectorElement elem = itr2.next();
+ assertEquals(2, elem.getIndex());
+ assertEquals(20, elem.getValue(), 0.000001);
+ elem = itr2.next();
+ assertEquals(6, elem.getIndex());
+ assertEquals(30, elem.getValue(), 0.000001);
+
+ assertFalse(itr2.hasNext());
+ }
+
+ @Test
+ public void testApplyToElements() {
+ // v1 = {5.5, 5.5, 5.5, 5.5, 5.5, 5.5, 5.5, 5.5, 5.5, 5.5}
+ DoubleVector v1 = new SparseDoubleVector(10, 5.5);
+
+ // v2 = {60.6, 60.5, 60.5, 60.5, 60.5, 60.5, 60.5, 60.5, 60.5, 60.5}
+ DoubleVector v2 = v1.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return value * 11;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return 0;
+ }
+ });
+
+ // v3 = {4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5}
+ DoubleVector v3 = v1.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return value / 2 + 1.75;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return 0;
+ }
+ });
+
+ // v4 = {66, 66, 66, 66, 66, 66, 66, 66, 66, 66}
+ DoubleVector v4 = v1.applyToElements(v2, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x1, double x2) {
+ return x1 + x2;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ return 0;
+ }
+ });
+
+ for (int i = 0; i < 10; ++i) {
+ assertEquals(v1.get(i), 5.5, 0.000001);
+ assertEquals(v2.get(i), 60.5, 0.000001);
+ assertEquals(v3.get(i), 4.5, 0.000001);
+ assertEquals(v4.get(i), 66, 0.000001);
+ }
+
+ // v3 = {4.5, 4.5, 4.5, 10, 4.5, 4.5, 10, 4.5, 200, 4.5}
+ v3.set(3, 10);
+ v3.set(6, 10);
+ v3.set(8, 200);
+
+ // v4 = {66, 66, 66, 66, 66, 100, 66, 66, 1, 66}
+ v4.set(5, 100);
+ v4.set(8, 1);
+
+ // v5 = {615, 615, 615, 560, 615, 955, 560, 615, -1990, 615}
+ DoubleVector v5 = v4.applyToElements(v3, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x1, double x2) {
+ return (x1 - x2) * 10;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ return 0;
+ }
+ });
+
+ // v6 = {615, 615, 615, 560, 615, 955, 560, 615, -1990, 615}
+ DoubleVector v6 = new SparseDoubleVector(10, 615);
+ v6.set(3, 560);
+ v6.set(5, 955);
+ v6.set(6, 560);
+ v6.set(8, -1990);
+
+ for (int i = 0; i < v5.getDimension(); ++i) {
+ assertEquals(v5.get(i), v6.get(i), 0.000001);
+ }
+
+ // v7 = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}
+ DoubleVector v7 = new DenseDoubleVector(new double[] { 0.0, 1.0, 2.0, 3.0,
+ 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 });
+
+ DoubleVector v8 = v5.applyToElements(v7, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x1, double x2) {
+ return (x1 + x2) * 3.3;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ return 0;
+ }
+ });
+
+ DoubleVector v9 = v6.applyToElements(v7, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double x1, double x2) {
+ return (x1 + x2) * 3.3;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ return 0;
+ }
+ });
+
+ for (int i = 0; i < v7.getDimension(); ++i) {
+ assertEquals(v8.get(i), v9.get(i), 0.000001);
+ }
+
+ }
+
+ @Test
+ public void testAdd() {
+ // addition of two sparse vectors
+ DoubleVector spVec1 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec2 = new SparseDoubleVector(10);
+ for (int i = 0; i < spVec2.getDimension(); ++i) {
+ spVec2.set(i, 1.5);
+ }
+
+ DoubleVector expRes1 = new SparseDoubleVector(10, 3.0);
+ assertEquals(expRes1, spVec1.add(spVec2));
+
+ // addition of one sparse vector and one dense vector
+ DoubleVector dsVec1 = new DenseDoubleVector(10);
+ for (int i = 0; i < dsVec1.getDimension(); ++i) {
+ dsVec1.set(i, 1.5);
+ }
+
+ DoubleVector expRes2 = new DenseDoubleVector(10);
+ for (int i = 0; i < expRes2.getDimension(); ++i) {
+ expRes2.set(i, 3.0);
+ }
+ assertEquals(expRes2, dsVec1.add(spVec2));
+ }
+
+ @Test
+ public void testSubtract() {
+ // subtract two sparse vectors
+ DoubleVector spVec1 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec2 = new SparseDoubleVector(10);
+ DoubleVector spVec3 = new SparseDoubleVector(10, 2.2);
+ for (int i = 0; i < spVec2.getDimension(); ++i) {
+ spVec2.set(i, 1.2);
+ }
+ DoubleVector expRes1 = new SparseDoubleVector(10, 0.3);
+ assertEquals(expRes1, spVec1.subtract(spVec2));
+
+ DoubleVector expRes2 = new SparseDoubleVector(10, -0.7);
+ assertEquals(expRes2, spVec1.subtract(spVec3));
+
+ // subtract one sparse vector from a dense vector
+ DoubleVector dsVec1 = new DenseDoubleVector(10);
+ for (int i = 0; i < dsVec1.getDimension(); ++i) {
+ dsVec1.set(i, 1.7);
+ }
+ DoubleVector expRes3 = new SparseDoubleVector(10, 0.2);
+ assertEquals(expRes3, dsVec1.subtract(spVec1));
+
+ // subtract one dense vector from a sparse vector
+ DoubleVector expRes4 = new SparseDoubleVector(10, -0.2);
+ assertEquals(expRes4, spVec1.subtract(dsVec1));
+ }
+
+ @Test
+ public void testMultiplyScala() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec2 = new SparseDoubleVector(10);
+ DoubleVector spVec3 = new SparseDoubleVector(10, 2.2);
+
+ DoubleVector spRes1 = spVec1.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return value * 3;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+ });
+
+ assertEquals(spRes1, spVec1.multiply(3));
+ assertEquals(spVec2, spVec2.multiply(1000));
+ assertEquals(spVec2, spVec1.multiply(0));
+ assertEquals(spVec1, spVec3.multiply(1.5 / 2.2));
+ }
+
+ @Test
+ public void testMultiply() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec2 = new SparseDoubleVector(10);
+ DoubleVector spVec3 = new SparseDoubleVector(10, 2.2);
+
+ DoubleVector spRes1 = spVec1.applyToElements(spVec3,
+ new DoubleDoubleFunction() {
+ @Override
+ public double apply(double first, double second) {
+ return first * second;
+ }
+
+ @Override
+ public double applyDerivative(double value, double second) {
+ throw new UnsupportedOperationException();
+ }
+ });
+
+ assertEquals(spRes1, spVec1.multiply(spVec3));
+ assertEquals(spVec2, spVec1.multiply(spVec2));
+ }
+
+ @Test
+ public void testDivide() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec2 = new SparseDoubleVector(10, 6.0);
+ DoubleVector spVec3 = new SparseDoubleVector(10, 2.2);
+
+ assertEquals(spVec3, spVec1.divide(1.5 / 2.2));
+ assertEquals(spVec2, spVec1.divideFrom(9));
+ }
+
+ @Test
+ public void testPow() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec2 = new SparseDoubleVector(10, 1);
+ DoubleVector spVec3 = new SparseDoubleVector(10, 2.25);
+
+ assertEquals(spVec3, spVec1.pow(2));
+ assertEquals(spVec2, spVec1.pow(0));
+ }
+
+ @Test
+ public void testAbs() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec2 = new SparseDoubleVector(10, 0);
+ DoubleVector spVec3 = new SparseDoubleVector(10, -1.5);
+
+ assertEquals(spVec1, spVec3.abs());
+ assertEquals(spVec2, spVec2.abs());
+ }
+
+ @Test
+ public void testSqrt() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 2.25);
+ DoubleVector spVec2 = new SparseDoubleVector(10, 0);
+ DoubleVector spVec3 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec4 = new SparseDoubleVector(10, 1);
+
+ assertEquals(spVec3, spVec1.sqrt());
+ assertEquals(spVec2, spVec2.sqrt());
+ assertEquals(spVec4, spVec4.sqrt());
+ }
+
+ @Test
+ public void testSum() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 2.25);
+ DoubleVector spVec2 = new SparseDoubleVector(10, 0);
+ DoubleVector spVec3 = new SparseDoubleVector(10, 1.5);
+
+ assertEquals(22.5, spVec1.sum(), 0.00001);
+ assertEquals(0, spVec2.sum(), 0.000001);
+ assertEquals(15, spVec3.sum(), 0.000001);
+ }
+
+ @Test
+ public void testDot() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 2.25);
+ DoubleVector spVec2 = new SparseDoubleVector(10, 0);
+ DoubleVector spVec3 = new SparseDoubleVector(10, 1.5);
+ DoubleVector spVec4 = new SparseDoubleVector(10, 1);
+
+ assertEquals(spVec1.multiply(spVec3).sum(), spVec1.dot(spVec3), 0.000001);
+ assertEquals(spVec3.sum(), spVec3.dot(spVec4), 0.000001);
+ assertEquals(0, spVec1.dot(spVec2), 0.000001);
+ }
+
+ @Test
+ public void testSlice() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 2.25);
+ DoubleVector spVec2 = new SparseDoubleVector(10, 0);
+ DoubleVector spVec3 = new SparseDoubleVector(5, 2.25);
+ DoubleVector spVec4 = new SparseDoubleVector(5, 0);
+
+ spVec1.set(7, 100);
+ spVec2.set(2, 200);
+
+ assertEquals(spVec3, spVec1.sliceUnsafe(5));
+ assertFalse(spVec4.equals(spVec2.slice(5)));
+
+ assertFalse(spVec3.equals(spVec1.slice(5, 9)));
+ assertEquals(spVec4, spVec2.slice(5, 9));
+ }
+
+ @Test
+ public void testMaxMin() {
+ DoubleVector spVec1 = new SparseDoubleVector(10, 2.25);
+ DoubleVector spVec2 = new SparseDoubleVector(10, 0);
+
+ spVec1.set(7, 100);
+ spVec2.set(2, 200);
+
+ assertEquals(100, spVec1.max(), 0.000001);
+ assertEquals(0, spVec2.min(), 0.000001);
+ }
+
+}