You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/09/08 23:35:13 UTC
[09/15] mahout git commit: NO-JIRA Trevors updates
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/core/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
new file mode 100644
index 0000000..ee54ad0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
@@ -0,0 +1,289 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Iterator;
+
+/**
+ * sparse matrix with general element values whose rows are accessible quickly. Implemented as a row
+ * array of either SequentialAccessSparseVectors or RandomAccessSparseVectors.
+ */
+public class SparseRowMatrix extends AbstractMatrix {
+ private Vector[] rowVectors;
+
+ private final boolean randomAccessRows;
+
+ private static final Logger log = LoggerFactory.getLogger(SparseRowMatrix.class);
+
+ /**
+ * Construct a sparse matrix starting with the provided row vectors.
+ *
+ * @param rows The number of rows in the result
+ * @param columns The number of columns in the result
+ * @param rowVectors a Vector[] array of rows
+ */
+ public SparseRowMatrix(int rows, int columns, Vector[] rowVectors) {
+ this(rows, columns, rowVectors, false, rowVectors instanceof RandomAccessSparseVector[]);
+ }
+
+ public SparseRowMatrix(int rows, int columns, boolean randomAccess) {
+ this(rows, columns, randomAccess
+ ? new RandomAccessSparseVector[rows]
+ : new SequentialAccessSparseVector[rows],
+ true,
+ randomAccess);
+ }
+
+ public SparseRowMatrix(int rows, int columns, Vector[] vectors, boolean shallowCopy, boolean randomAccess) {
+ super(rows, columns);
+ this.randomAccessRows = randomAccess;
+ this.rowVectors = vectors.clone();
+ for (int row = 0; row < rows; row++) {
+ if (vectors[row] == null) {
+ // TODO: this can't be right to change the argument
+ vectors[row] = randomAccess
+ ? new RandomAccessSparseVector(numCols(), 10)
+ : new SequentialAccessSparseVector(numCols(), 10);
+ }
+ this.rowVectors[row] = shallowCopy ? vectors[row] : vectors[row].clone();
+ }
+ }
+
+ /**
+ * Construct a matrix of the given cardinality, with rows defaulting to RandomAccessSparseVector
+ * implementation
+ *
+ * @param rows Number of rows in result
+ * @param columns Number of columns in result
+ */
+ public SparseRowMatrix(int rows, int columns) {
+ this(rows, columns, true);
+ }
+
+ @Override
+ public Matrix clone() {
+ SparseRowMatrix clone = (SparseRowMatrix) super.clone();
+ clone.rowVectors = new Vector[rowVectors.length];
+ for (int i = 0; i < rowVectors.length; i++) {
+ clone.rowVectors[i] = rowVectors[i].clone();
+ }
+ return clone;
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ return rowVectors[row] == null ? 0.0 : rowVectors[row].getQuick(column);
+ }
+
+ @Override
+ public Matrix like() {
+ return new SparseRowMatrix(rowSize(), columnSize(), randomAccessRows);
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new SparseRowMatrix(rows, columns, randomAccessRows);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ rowVectors[row].setQuick(column, value);
+ }
+
+ @Override
+ public int[] getNumNondefaultElements() {
+ int[] result = new int[2];
+ result[ROW] = rowVectors.length;
+ for (int row = 0; row < rowSize(); row++) {
+ result[COL] = Math.max(result[COL], rowVectors[row].getNumNondefaultElements());
+ }
+ return result;
+ }
+
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ if (offset[ROW] < 0) {
+ throw new IndexException(offset[ROW], rowVectors.length);
+ }
+ if (offset[ROW] + size[ROW] > rowVectors.length) {
+ throw new IndexException(offset[ROW] + size[ROW], rowVectors.length);
+ }
+ if (offset[COL] < 0) {
+ throw new IndexException(offset[COL], rowVectors[ROW].size());
+ }
+ if (offset[COL] + size[COL] > rowVectors[ROW].size()) {
+ throw new IndexException(offset[COL] + size[COL], rowVectors[ROW].size());
+ }
+ return new MatrixView(this, offset, size);
+ }
+
+ @Override
+ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+ int rows = rowSize();
+ if (rows != other.rowSize()) {
+ throw new CardinalityException(rows, other.rowSize());
+ }
+ int columns = columnSize();
+ if (columns != other.columnSize()) {
+ throw new CardinalityException(columns, other.columnSize());
+ }
+ for (int row = 0; row < rows; row++) {
+ try {
+ Iterator<Vector.Element> sparseRowIterator = ((SequentialAccessSparseVector) this.rowVectors[row])
+ .iterateNonZero();
+ if (function.isLikeMult()) { // TODO: is this a sufficient test?
+ // TODO: this may cause an exception if the row type is not compatible but it is currently guaranteed to be
+ // a SequentialAccessSparseVector, should "try" here just in case and Warn
+ // TODO: can we use iterateNonZero on both rows until the index is the same to get better speedup?
+
+ // TODO: SASVs have an iterateNonZero that returns zeros, this should not hurt but is far from optimal
+ // this might perform much better if SparseRowMatrix were backed by RandomAccessSparseVectors, which
+ // are backed by fastutil hashmaps and the iterateNonZero actually does only return nonZeros.
+ while (sparseRowIterator.hasNext()) {
+ Vector.Element element = sparseRowIterator.next();
+ int col = element.index();
+ setQuick(row, col, function.apply(element.get(), other.getQuick(row, col)));
+ }
+ } else {
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
+ }
+ }
+
+ } catch (ClassCastException e) {
+ // Warn and use default implementation
+ log.warn("Error casting the row to SequentialAccessSparseVector, this should never happen because" +
+ "SparseRomMatrix is always made of SequentialAccessSparseVectors. Proceeding with non-optimzed" +
+ "implementation.");
+ for (int col = 0; col < columns; col++) {
+ setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
+ }
+ }
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ if (rowSize() != other.size()) {
+ throw new CardinalityException(rowSize(), other.size());
+ }
+ if (column < 0 || column >= columnSize()) {
+ throw new IndexException(column, columnSize());
+ }
+ for (int row = 0; row < rowSize(); row++) {
+ rowVectors[row].setQuick(column, other.getQuick(row));
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new CardinalityException(columnSize(), other.size());
+ }
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ rowVectors[row].assign(other);
+ return this;
+ }
+
+ /**
+ * @param row an int row index
+ * @return a shallow view of the Vector at specified row (ie you may mutate the original matrix
+ * using this row)
+ */
+ @Override
+ public Vector viewRow(int row) {
+ if (row < 0 || row >= rowSize()) {
+ throw new IndexException(row, rowSize());
+ }
+ return rowVectors[row];
+ }
+
+ @Override
+ public Matrix transpose() {
+ SparseColumnMatrix scm = new SparseColumnMatrix(columns, rows);
+ for (int i = 0; i < rows; i++) {
+ Vector row = rowVectors[i];
+ if (row.getNumNonZeroElements() > 0) {
+ scm.assignColumn(i, row);
+ }
+ }
+ return scm;
+ }
+
+ @Override
+ public Matrix times(Matrix other) {
+ if (columnSize() != other.rowSize()) {
+ throw new CardinalityException(columnSize(), other.rowSize());
+ }
+
+ if (other instanceof SparseRowMatrix) {
+ SparseRowMatrix y = (SparseRowMatrix) other;
+ SparseRowMatrix result = (SparseRowMatrix) like(rowSize(), other.columnSize());
+
+ for (int i = 0; i < rows; i++) {
+ Vector row = rowVectors[i];
+ for (Vector.Element element : row.nonZeroes()) {
+ result.rowVectors[i].assign(y.rowVectors[element.index()], Functions.plusMult(element.get()));
+ }
+ }
+ return result;
+ } else {
+ if (other.viewRow(0).isDense()) {
+ // result is dense, but can be computed relatively cheaply
+ Matrix result = other.like(rowSize(), other.columnSize());
+
+ for (int i = 0; i < rows; i++) {
+ Vector row = rowVectors[i];
+ Vector r = new DenseVector(other.columnSize());
+ for (Vector.Element element : row.nonZeroes()) {
+ r.assign(other.viewRow(element.index()), Functions.plusMult(element.get()));
+ }
+ result.viewRow(i).assign(r);
+ }
+ return result;
+ } else {
+ // other is sparse, but not something we understand intimately
+ SparseRowMatrix result = (SparseRowMatrix) like(rowSize(), other.columnSize());
+
+ for (int i = 0; i < rows; i++) {
+ Vector row = rowVectors[i];
+ for (Vector.Element element : row.nonZeroes()) {
+ result.rowVectors[i].assign(other.viewRow(element.index()), Functions.plusMult(element.get()));
+ }
+ }
+ return result;
+ }
+ }
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.SPARSELIKE;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Swapper.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Swapper.java b/core/src/main/java/org/apache/mahout/math/Swapper.java
new file mode 100644
index 0000000..1ca3744
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Swapper.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear in all copies and
+that both that copyright notice and this permission notice appear in supporting documentation.
+CERN makes no representations about the suitability of this software for any purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math;
+
/**
 * Callback that exchanges two elements of some backing storage. The elements are identified only
 * by their integer positions; the storage itself is owned by the implementation.
 */
public interface Swapper {

  /**
   * Exchanges the element at position {@code first} with the element at position {@code second}.
   *
   * @param first position of one element to exchange
   * @param second position of the other element to exchange
   */
  void swap(int first, int second);
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/TransposedMatrixView.java b/core/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
new file mode 100644
index 0000000..ede6f35
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Matrix View backed by an {@link org.apache.mahout.math.function.IntIntFunction}
+ */
+public class TransposedMatrixView extends AbstractMatrix {
+
+ private Matrix m;
+
+ public TransposedMatrixView(Matrix m) {
+ super(m.numCols(), m.numRows());
+ this.m = m;
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ m.assignRow(column,other);
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ m.assignColumn(row,other);
+ return this;
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ return m.getQuick(column,row);
+ }
+
+ @Override
+ public Matrix like() {
+ return m.like(rows, columns);
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return m.like(rows,columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ m.setQuick(column, row, value);
+ }
+
+ @Override
+ public Vector viewRow(int row) {
+ return m.viewColumn(row);
+ }
+
+ @Override
+ public Vector viewColumn(int column) {
+ return m.viewRow(column);
+ }
+
+ @Override
+ public Matrix assign(double value) {
+ return m.assign(value);
+ }
+
+ @Override
+ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+ if (other instanceof TransposedMatrixView) {
+ m.assign(((TransposedMatrixView) other).m, function);
+ } else {
+ m.assign(new TransposedMatrixView(other), function);
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assign(Matrix other) {
+ if (other instanceof TransposedMatrixView) {
+ return m.assign(((TransposedMatrixView) other).m);
+ } else {
+ return m.assign(new TransposedMatrixView(other));
+ }
+ }
+
+ @Override
+ public Matrix assign(DoubleFunction function) {
+ return m.assign(function);
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return flavor;
+ }
+
+ private MatrixFlavor flavor = new MatrixFlavor() {
+ @Override
+ public BackEnum getBacking() {
+ return m.getFlavor().getBacking();
+ }
+
+ @Override
+ public TraversingStructureEnum getStructure() {
+ TraversingStructureEnum flavor = m.getFlavor().getStructure();
+ switch (flavor) {
+ case COLWISE:
+ return TraversingStructureEnum.ROWWISE;
+ case SPARSECOLWISE:
+ return TraversingStructureEnum.SPARSEROWWISE;
+ case ROWWISE:
+ return TraversingStructureEnum.COLWISE;
+ case SPARSEROWWISE:
+ return TraversingStructureEnum.SPARSECOLWISE;
+ default:
+ return flavor;
+ }
+ }
+
+ @Override
+ public boolean isDense() {
+ return m.getFlavor().isDense();
+ }
+ };
+
+ Matrix getDelegate() {
+ return m;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/UpperTriangular.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/UpperTriangular.java b/core/src/main/java/org/apache/mahout/math/UpperTriangular.java
new file mode 100644
index 0000000..29fa6a0
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/UpperTriangular.java
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
+/**
+ *
+ * Quick and dirty implementation of some {@link org.apache.mahout.math.Matrix} methods
+ * over packed upper triangular matrix.
+ *
+ */
+public class UpperTriangular extends AbstractMatrix {
+
+ private static final double EPSILON = 1.0e-12; // assume anything less than
+ // that to be 0 during
+ // non-upper assignments
+
+ private double[] values;
+
+ /**
+ * represents n x n upper triangular matrix
+ *
+ * @param n
+ */
+
+ public UpperTriangular(int n) {
+ super(n, n);
+ values = new double[n * (n + 1) / 2];
+ }
+
+ public UpperTriangular(double[] data, boolean shallow) {
+ this(elementsToMatrixSize(data != null ? data.length : 0));
+ if (data == null) {
+ throw new IllegalArgumentException("data");
+ }
+ values = shallow ? data : data.clone();
+ }
+
+ public UpperTriangular(Vector data) {
+ this(elementsToMatrixSize(data.size()));
+
+ for (Vector.Element el:data.nonZeroes()) {
+ values[el.index()] = el.get();
+ }
+ }
+
+ private static int elementsToMatrixSize(int dataSize) {
+ return (int) Math.round((-1 + Math.sqrt(1 + 8 * dataSize)) / 2);
+ }
+
+ // copy-constructor
+ public UpperTriangular(UpperTriangular mx) {
+ this(mx.values, false);
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new IndexException(columnSize(), other.size());
+ }
+ if (other.viewPart(column + 1, other.size() - column - 1).norm(1) > 1.0e-14) {
+ throw new IllegalArgumentException("Cannot set lower portion of triangular matrix to non-zero");
+ }
+ for (Vector.Element element : other.viewPart(0, column).all()) {
+ setQuick(element.index(), column, element.get());
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ if (columnSize() != other.size()) {
+ throw new IndexException(numCols(), other.size());
+ }
+ for (int i = 0; i < row; i++) {
+ if (Math.abs(other.getQuick(i)) > EPSILON) {
+ throw new IllegalArgumentException("non-triangular source");
+ }
+ }
+ for (int i = row; i < rows; i++) {
+ setQuick(row, i, other.get(i));
+ }
+ return this;
+ }
+
+ public Matrix assignNonZeroElementsInRow(int row, double[] other) {
+ System.arraycopy(other, row, values, getL(row, row), rows - row);
+ return this;
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ if (row > column) {
+ return 0;
+ }
+ int i = getL(row, column);
+ return values[i];
+ }
+
+ private int getL(int row, int col) {
+ /*
+ * each row starts with some zero elements that we don't store. this
+ * accumulates an offset of (row+1)*row/2
+ */
+ return col + row * numCols() - (row + 1) * row / 2;
+ }
+
+ @Override
+ public Matrix like() {
+ return like(rowSize(), columnSize());
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return new DenseMatrix(rows, columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ values[getL(row, column)] = value;
+ }
+
+ @Override
+ public int[] getNumNondefaultElements() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Matrix viewPart(int[] offset, int[] size) {
+ return new MatrixView(this, offset, size);
+ }
+
+ public double[] getData() {
+ return values;
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ // We kind of consider ourselves a vector-backed but dense matrix for mmul, etc. purposes.
+ return new MatrixFlavor.FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.VECTORBACKED, true);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Vector.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/Vector.java b/core/src/main/java/org/apache/mahout/math/Vector.java
new file mode 100644
index 0000000..c3b1dc9
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/Vector.java
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * The basic interface including numerous convenience functions <p> NOTE: All implementing classes must have a
+ * constructor that takes an int for cardinality and a no-arg constructor that can be used for marshalling the Writable
+ * instance <p> NOTE: Implementations may choose to reuse the Vector.Element in the Iterable methods
+ */
+public interface Vector extends Cloneable {
+
+ /** @return a formatted String suitable for output */
+ String asFormatString();
+
+ /**
+ * Assign the value to all elements of the receiver
+ *
+ * @param value a double value
+ * @return the modified receiver
+ */
+ Vector assign(double value);
+
+ /**
+ * Assign the values to the receiver
+ *
+ * @param values a double[] of values
+ * @return the modified receiver
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector assign(double[] values);
+
+ /**
+ * Assign the other vector values to the receiver
+ *
+ * @param other a Vector
+ * @return the modified receiver
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector assign(Vector other);
+
+ /**
+ * Apply the function to each element of the receiver
+ *
+ * @param function a DoubleFunction to apply
+ * @return the modified receiver
+ */
+ Vector assign(DoubleFunction function);
+
+ /**
+ * Apply the function to each element of the receiver and the corresponding element of the other argument
+ *
+ * @param other a Vector containing the second arguments to the function
+ * @param function a DoubleDoubleFunction to apply
+ * @return the modified receiver
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector assign(Vector other, DoubleDoubleFunction function);
+
+ /**
+ * Apply the function to each element of the receiver, using the y value as the second argument of the
+ * DoubleDoubleFunction
+ *
+ * @param f a DoubleDoubleFunction to be applied
+ * @param y a double value to be argument to the function
+ * @return the modified receiver
+ */
+ Vector assign(DoubleDoubleFunction f, double y);
+
+ /**
+ * Return the cardinality of the recipient (the maximum number of values)
+ *
+ * @return an int
+ */
+ int size();
+
+ /**
+ * true if this implementation should be considered dense -- that it explicitly
+ * represents every value
+ *
+ * @return true or false
+ */
+ boolean isDense();
+
+ /**
+ * true if this implementation should be considered to be iterable in index order in an efficient way.
+ * In particular this implies that {@link #all()} and {@link #nonZeroes()} ()} return elements
+ * in ascending order by index.
+ *
+ * @return true iff this implementation should be considered to be iterable in index order in an efficient way.
+ */
+ boolean isSequentialAccess();
+
+ /**
+ * Return a copy of the recipient
+ *
+ * @return a new Vector
+ */
+ @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException")
+ Vector clone();
+
+ Iterable<Element> all();
+
+ Iterable<Element> nonZeroes();
+
+ /**
+ * Return an object of Vector.Element representing an element of this Vector. Useful when designing new iterator
+ * types.
+ *
+ * @param index Index of the Vector.Element required
+ * @return The Vector.Element Object
+ */
+ Element getElement(int index);
+
+ /**
+ * Merge a set of (index, value) pairs into the vector.
+ * @param updates an ordered mapping of indices to values to be merged in.
+ */
+ void mergeUpdates(OrderedIntDoubleMapping updates);
+
+ /**
+ * A holder for information about a specific item in the Vector. <p>
+ * When using with an Iterator, the implementation
+ * may choose to reuse this element, so you may need to make a copy if you want to keep it
+ */
+ interface Element {
+
+ /** @return the value of this vector element. */
+ double get();
+
+ /** @return the index of this vector element. */
+ int index();
+
+ /** @param value Set the current element to value. */
+ void set(double value);
+ }
+
+ /**
+ * Return a new vector containing the values of the recipient divided by the argument
+ *
+ * @param x a double value
+ * @return a new Vector
+ */
+ Vector divide(double x);
+
+ /**
+ * Return the dot product of the recipient and the argument
+ *
+ * @param x a Vector
+ * @return a new Vector
+ * @throws CardinalityException if the cardinalities differ
+ */
+ double dot(Vector x);
+
+ /**
+ * Return the value at the given index
+ *
+ * @param index an int index
+ * @return the double at the index
+ * @throws IndexException if the index is out of bounds
+ */
+ double get(int index);
+
+ /**
+ * Return the value at the given index, without checking bounds
+ *
+ * @param index an int index
+ * @return the double at the index
+ */
+ double getQuick(int index);
+
+ /**
+ * Return an empty vector of the same underlying class as the receiver
+ *
+ * @return a Vector
+ */
+ Vector like();
+
+ /**
+ * Return a new empty vector of the same underlying class as the receiver with given cardinality
+ *
+ * @param cardinality - size of vector
+ * @return {@link Vector}
+ */
+ Vector like(int cardinality);
+
+ /**
+ * Return a new vector containing the element by element difference of the recipient and the argument
+ *
+ * @param x a Vector
+ * @return a new Vector
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector minus(Vector x);
+
+ /**
+ * Return a new vector containing the normalized (L_2 norm) values of the recipient
+ *
+ * @return a new Vector
+ */
+ Vector normalize();
+
+ /**
+ * Return a new Vector containing the normalized (L_power norm) values of the recipient. <p>
+ * See
+ * http://en.wikipedia.org/wiki/Lp_space <p>
+ * Technically, when {@code 0 < power < 1}, we don't have a norm, just a metric,
+ * but we'll overload this here. <p>
+ * Also supports {@code power == 0} (number of non-zero elements) and power = {@link
+ * Double#POSITIVE_INFINITY} (max element). Again, see the Wikipedia page for more info
+ *
+ * @param power The power to use. Must be >= 0. May also be {@link Double#POSITIVE_INFINITY}. See the Wikipedia link
+ * for more on this.
+ * @return a new Vector x such that norm(x, power) == 1
+ */
+ Vector normalize(double power);
+
+ /**
+ * Return a new vector containing the log(1 + entry)/ L_2 norm values of the recipient
+ *
+ * @return a new Vector
+ */
+ Vector logNormalize();
+
+ /**
+ * Return a new Vector with a normalized value calculated as log_power(1 + entry)/ L_power norm. <p>
+ *
+ * @param power The power to use. Must be > 1. Cannot be {@link Double#POSITIVE_INFINITY}.
+ * @return a new Vector
+ */
+ Vector logNormalize(double power);
+
+ /**
+ * Return the k-norm of the vector. <p/> See http://en.wikipedia.org/wiki/Lp_space <p>
+ * Technically, when {@code 0 > power < 1}, we don't have a norm, just a metric, but we'll overload this here. Also supports power == 0 (number of
+ * non-zero elements) and power = {@link Double#POSITIVE_INFINITY} (max element). Again, see the Wikipedia page for
+ * more info.
+ *
+ * @param power The power to use.
+ * @see #normalize(double)
+ */
+ double norm(double power);
+
+ /** @return The minimum value in the Vector */
+ double minValue();
+
+ /** @return The index of the minimum value */
+ int minValueIndex();
+
+ /** @return The maximum value in the Vector */
+ double maxValue();
+
+ /** @return The index of the maximum value */
+ int maxValueIndex();
+
+ /**
+ * Return a new vector containing the sum of each value of the recipient and the argument
+ *
+ * @param x a double
+ * @return a new Vector
+ */
+ Vector plus(double x);
+
+ /**
+ * Return a new vector containing the element by element sum of the recipient and the argument
+ *
+ * @param x a Vector
+ * @return a new Vector
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector plus(Vector x);
+
+ /**
+ * Set the value at the given index
+ *
+ * @param index an int index into the receiver
+ * @param value a double value to set
+ * @throws IndexException if the index is out of bounds
+ */
+ void set(int index, double value);
+
+ /**
+ * Set the value at the given index, without checking bounds
+ *
+ * @param index an int index into the receiver
+ * @param value a double value to set
+ */
+ void setQuick(int index, double value);
+
+ /**
+ * Increment the value at the given index by the given value.
+ *
+ * @param index an int index into the receiver
+ * @param increment sets the value at the given index to value + increment;
+ */
+ void incrementQuick(int index, double increment);
+
+ /**
+ * Return the number of values in the recipient which are not the default value. For instance, for a
+ * sparse vector, this would be the number of non-zero values.
+ *
+ * @return an int
+ */
+ int getNumNondefaultElements();
+
+ /**
+ * Return the number of non zero elements in the vector.
+ *
+ * @return an int
+ */
+ int getNumNonZeroElements();
+
+ /**
+ * Return a new vector containing the product of each value of the recipient and the argument
+ *
+ * @param x a double argument
+ * @return a new Vector
+ */
+ Vector times(double x);
+
+ /**
+ * Return a new vector containing the element-wise product of the recipient and the argument
+ *
+ * @param x a Vector argument
+ * @return a new Vector
+ * @throws CardinalityException if the cardinalities differ
+ */
+ Vector times(Vector x);
+
+ /**
+ * Return a new vector containing the subset of the recipient
+ *
+ * @param offset an int offset into the receiver
+ * @param length the cardinality of the desired result
+ * @return a new Vector
+ * @throws CardinalityException if the length is greater than the cardinality of the receiver
+ * @throws IndexException if the offset is negative or the offset+length is outside of the receiver
+ */
+ Vector viewPart(int offset, int length);
+
+ /**
+ * Return the sum of all the elements of the receiver
+ *
+ * @return a double
+ */
+ double zSum();
+
+ /**
+ * Return the cross product of the receiver and the other vector
+ *
+ * @param other another Vector
+ * @return a Matrix
+ */
+ Matrix cross(Vector other);
+
+ /*
+ * Need stories for these but keeping them here for now.
+ */
+ // void getNonZeros(IntArrayList jx, DoubleArrayList values);
+ // void foreachNonZero(IntDoubleFunction f);
+ // DoubleDoubleFunction map);
+ // NewVector assign(Vector y, DoubleDoubleFunction function, IntArrayList
+ // nonZeroIndexes);
+
+ /**
+ * Examples speak louder than words: aggregate(plus, pow(2)) is another way to say
+ * getLengthSquared(), aggregate(max, abs) is norm(Double.POSITIVE_INFINITY). To sum all of the positive values,
+ * aggregate(plus, max(0)).
+ * @param aggregator used to combine the current value of the aggregation with the result of map.apply(nextValue)
+ * @param map a function to apply to each element of the vector in turn before passing to the aggregator
+ * @return the final aggregation
+ */
+ double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map);
+
+ /**
+ * <p>Generalized inner product - take two vectors, iterate over them both, using the combiner to combine together
+ * (and possibly map in some way) each pair of values, which are then aggregated with the previous accumulated
+ * value in the combiner.</p>
+ * <p>
+ * Example: dot(other) could be expressed as aggregate(other, Plus, Times), and kernelized inner products (which
+ * are symmetric on the indices) work similarly.
+ * @param other a vector to aggregate in combination with
+ * @param aggregator function we're aggregating with; fa
+ * @param combiner function we're combining with; fc
+ * @return the final aggregation; {@code if r0 = fc(this[0], other[0]), ri = fa(r_{i-1}, fc(this[i], other[i]))
+ * for all i > 0}
+ */
+ double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner);
+
+ /**
+ * Return the sum of squares of all elements in the vector. Square root of
+ * this value is the length of the vector.
+ */
+ double getLengthSquared();
+
+ /**
+ * Get the square of the distance between this vector and the other vector.
+ */
+ double getDistanceSquared(Vector v);
+
+ /**
+ * Gets an estimate of the cost (in number of operations) it takes to lookup a random element in this vector.
+ */
+ double getLookupCost();
+
+ /**
+ * Gets an estimate of the cost (in number of operations) it takes to advance an iterator through the nonzero
+ * elements of this vector.
+ */
+ double getIteratorAdvanceCost();
+
+ /**
+ * Return true iff adding a new (nonzero) element takes constant time for this vector.
+ */
+ boolean isAddConstantTime();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java b/core/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java
new file mode 100644
index 0000000..4d3a80f
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java
@@ -0,0 +1,481 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.set.OpenIntHashSet;
+
+import java.util.Iterator;
+
+/**
+ * Abstract class encapsulating different algorithms that perform the Vector operations aggregate().
+ * x.aggregte(y, fa, fc), for x and y Vectors and fa, fc DoubleDouble functions:
+ * - applies the function fc to every element in x and y, fc(xi, yi)
+ * - constructs a result iteratively, r0 = fc(x0, y0), ri = fc(r_{i-1}, fc(xi, yi)).
+ * This works essentially like a map/reduce functional combo.
+ *
+ * The names of variables, methods and classes used here follow the following conventions:
+ * The vector being assigned to (the left hand side) is called this or x.
+ * The right hand side is called that or y.
+ * The aggregating (reducing) function to be applied is called fa.
+ * The combining (mapping) function to be applied is called fc.
+ *
+ * The different algorithms take into account the different characteristics of vector classes:
+ * - whether the vectors support sequential iteration (isSequential())
+ * - what the lookup cost is (getLookupCost())
+ * - what the iterator advancement cost is (getIteratorAdvanceCost())
+ *
+ * The names of the actual classes (they're nested in VectorBinaryAssign) describe the used for assignment.
+ * The most important optimization is iterating just through the nonzeros (only possible if f(0, 0) = 0).
+ * There are 4 main possibilities:
+ * - iterating through the nonzeros of just one vector and looking up the corresponding elements in the other
+ * - iterating through the intersection of nonzeros (those indices where both vectors have nonzero values)
+ * - iterating through the union of nonzeros (those indices where at least one of the vectors has a nonzero value)
+ * - iterating through all the elements in some way (either through both at the same time, both one after the other,
+ * looking up both, looking up just one).
+ *
+ * The internal details are not important and a particular algorithm should generally not be called explicitly.
+ * The best one will be selected through assignBest(), which is itself called through Vector.assign().
+ *
+ * See https://docs.google.com/document/d/1g1PjUuvjyh2LBdq2_rKLIcUiDbeOORA1sCJiSsz-JVU/edit# for a more detailed
+ * explanation.
+ */
+public abstract class VectorBinaryAggregate {
+ public static final VectorBinaryAggregate[] OPERATIONS = {
+ new AggregateNonzerosIterateThisLookupThat(),
+ new AggregateNonzerosIterateThatLookupThis(),
+
+ new AggregateIterateIntersection(),
+
+ new AggregateIterateUnionSequential(),
+ new AggregateIterateUnionRandom(),
+
+ new AggregateAllIterateSequential(),
+ new AggregateAllIterateThisLookupThat(),
+ new AggregateAllIterateThatLookupThis(),
+ new AggregateAllLoop(),
+ };
+
+ /**
+ * Returns true iff we can use this algorithm to apply fc to x and y component-wise and aggregate the result using fa.
+ */
+ public abstract boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc);
+
+ /**
+ * Estimates the cost of using this algorithm to compute the aggregation. The algorithm is assumed to be valid.
+ */
+ public abstract double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc);
+
+ /**
+ * Main method that applies fc to x and y component-wise aggregating the results with fa. It returns the result of
+ * the aggregation.
+ */
+ public abstract double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc);
+
+ /**
+ * The best operation is the least expensive valid one.
+ */
+ public static VectorBinaryAggregate getBestOperation(Vector x, Vector y, DoubleDoubleFunction fa,
+ DoubleDoubleFunction fc) {
+ int bestOperationIndex = -1;
+ double bestCost = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < OPERATIONS.length; ++i) {
+ if (OPERATIONS[i].isValid(x, y, fa, fc)) {
+ double cost = OPERATIONS[i].estimateCost(x, y, fa, fc);
+ if (cost < bestCost) {
+ bestCost = cost;
+ bestOperationIndex = i;
+ }
+ }
+ }
+ return OPERATIONS[bestOperationIndex];
+ }
+
+ /**
+ * This is the method that should be used when aggregating. It selects the best algorithm and applies it.
+ */
+ public static double aggregateBest(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return getBestOperation(x, y, fa, fc).aggregate(x, y, fa, fc);
+ }
+
+ public static class AggregateNonzerosIterateThisLookupThat extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return fa.isLikeRightPlus() && (fa.isAssociativeAndCommutative() || x.isSequentialAccess())
+ && fc.isLikeLeftMult();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost();
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+ if (!xi.hasNext()) {
+ return 0;
+ }
+ Vector.Element xe = xi.next();
+ double result = fc.apply(xe.get(), y.getQuick(xe.index()));
+ while (xi.hasNext()) {
+ xe = xi.next();
+ result = fa.apply(result, fc.apply(xe.get(), y.getQuick(xe.index())));
+ }
+ return result;
+ }
+ }
+
+ public static class AggregateNonzerosIterateThatLookupThis extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return fa.isLikeRightPlus() && (fa.isAssociativeAndCommutative() || y.isSequentialAccess())
+ && fc.isLikeRightMult();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost() * x.getLookupCost();
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+ if (!yi.hasNext()) {
+ return 0;
+ }
+ Vector.Element ye = yi.next();
+ double result = fc.apply(x.getQuick(ye.index()), ye.get());
+ while (yi.hasNext()) {
+ ye = yi.next();
+ result = fa.apply(result, fc.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ return result;
+ }
+ }
+
+ public static class AggregateIterateIntersection extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return fa.isLikeRightPlus() && fc.isLikeMult() && x.isSequentialAccess() && y.isSequentialAccess();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return Math.min(x.getNumNondefaultElements() * x.getIteratorAdvanceCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+ Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+ Vector.Element xe = null;
+ Vector.Element ye = null;
+ boolean advanceThis = true;
+ boolean advanceThat = true;
+ boolean validResult = false;
+ double result = 0;
+ while (true) {
+ if (advanceThis) {
+ if (xi.hasNext()) {
+ xe = xi.next();
+ } else {
+ break;
+ }
+ }
+ if (advanceThat) {
+ if (yi.hasNext()) {
+ ye = yi.next();
+ } else {
+ break;
+ }
+ }
+ if (xe.index() == ye.index()) {
+ double thisResult = fc.apply(xe.get(), ye.get());
+ if (validResult) {
+ result = fa.apply(result, thisResult);
+ } else {
+ result = thisResult;
+ validResult = true;
+ }
+ advanceThis = true;
+ advanceThat = true;
+ } else {
+ if (xe.index() < ye.index()) { // f(x, 0) = 0
+ advanceThis = true;
+ advanceThat = false;
+ } else { // f(0, y) = 0
+ advanceThis = false;
+ advanceThat = true;
+ }
+ }
+ }
+ return result;
+ }
+ }
+
+ public static class AggregateIterateUnionSequential extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return fa.isLikeRightPlus() && !fc.isDensifying()
+ && x.isSequentialAccess() && y.isSequentialAccess();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+ Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+ Vector.Element xe = null;
+ Vector.Element ye = null;
+ boolean advanceThis = true;
+ boolean advanceThat = true;
+ boolean validResult = false;
+ double result = 0;
+ while (true) {
+ if (advanceThis) {
+ if (xi.hasNext()) {
+ xe = xi.next();
+ } else {
+ xe = null;
+ }
+ }
+ if (advanceThat) {
+ if (yi.hasNext()) {
+ ye = yi.next();
+ } else {
+ ye = null;
+ }
+ }
+ double thisResult;
+ if (xe != null && ye != null) { // both vectors have nonzero elements
+ if (xe.index() == ye.index()) {
+ thisResult = fc.apply(xe.get(), ye.get());
+ advanceThis = true;
+ advanceThat = true;
+ } else {
+ if (xe.index() < ye.index()) { // f(x, 0)
+ thisResult = fc.apply(xe.get(), 0);
+ advanceThis = true;
+ advanceThat = false;
+ } else {
+ thisResult = fc.apply(0, ye.get());
+ advanceThis = false;
+ advanceThat = true;
+ }
+ }
+ } else if (xe != null) { // just the first one still has nonzeros
+ thisResult = fc.apply(xe.get(), 0);
+ advanceThis = true;
+ advanceThat = false;
+ } else if (ye != null) { // just the second one has nonzeros
+ thisResult = fc.apply(0, ye.get());
+ advanceThis = false;
+ advanceThat = true;
+ } else { // we're done, both are empty
+ break;
+ }
+ if (validResult) {
+ result = fa.apply(result, thisResult);
+ } else {
+ result = thisResult;
+ validResult = true;
+ }
+ }
+ return result;
+ }
+ }
+
+ public static class AggregateIterateUnionRandom extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return fa.isLikeRightPlus() && !fc.isDensifying()
+ && (fa.isAssociativeAndCommutative() || (x.isSequentialAccess() && y.isSequentialAccess()));
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost());
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ OpenIntHashSet visited = new OpenIntHashSet();
+ Iterator<Vector.Element> xi = x.nonZeroes().iterator();
+ boolean validResult = false;
+ double result = 0;
+ double thisResult;
+ while (xi.hasNext()) {
+ Vector.Element xe = xi.next();
+ thisResult = fc.apply(xe.get(), y.getQuick(xe.index()));
+ if (validResult) {
+ result = fa.apply(result, thisResult);
+ } else {
+ result = thisResult;
+ validResult = true;
+ }
+ visited.add(xe.index());
+ }
+ Iterator<Vector.Element> yi = y.nonZeroes().iterator();
+ while (yi.hasNext()) {
+ Vector.Element ye = yi.next();
+ if (!visited.contains(ye.index())) {
+ thisResult = fc.apply(x.getQuick(ye.index()), ye.get());
+ if (validResult) {
+ result = fa.apply(result, thisResult);
+ } else {
+ result = thisResult;
+ validResult = true;
+ }
+ }
+ }
+ return result;
+ }
+ }
+
+ public static class AggregateAllIterateSequential extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return x.isSequentialAccess() && y.isSequentialAccess() && !x.isDense() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return Math.max(x.size() * x.getIteratorAdvanceCost(), y.size() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ Iterator<Vector.Element> xi = x.all().iterator();
+ Iterator<Vector.Element> yi = y.all().iterator();
+ boolean validResult = false;
+ double result = 0;
+ while (xi.hasNext() && yi.hasNext()) {
+ Vector.Element xe = xi.next();
+ double thisResult = fc.apply(xe.get(), yi.next().get());
+ if (validResult) {
+ result = fa.apply(result, thisResult);
+ } else {
+ result = thisResult;
+ validResult = true;
+ }
+ }
+ return result;
+ }
+ }
+
+ public static class AggregateAllIterateThisLookupThat extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return (fa.isAssociativeAndCommutative() || x.isSequentialAccess())
+ && !x.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return x.size() * x.getIteratorAdvanceCost() * y.getLookupCost();
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ Iterator<Vector.Element> xi = x.all().iterator();
+ boolean validResult = false;
+ double result = 0;
+ while (xi.hasNext()) {
+ Vector.Element xe = xi.next();
+ double thisResult = fc.apply(xe.get(), y.getQuick(xe.index()));
+ if (validResult) {
+ result = fa.apply(result, thisResult);
+ } else {
+ result = thisResult;
+ validResult = true;
+ }
+ }
+ return result;
+ }
+ }
+
+ public static class AggregateAllIterateThatLookupThis extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return (fa.isAssociativeAndCommutative() || y.isSequentialAccess())
+ && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return y.size() * y.getIteratorAdvanceCost() * x.getLookupCost();
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ Iterator<Vector.Element> yi = y.all().iterator();
+ boolean validResult = false;
+ double result = 0;
+ while (yi.hasNext()) {
+ Vector.Element ye = yi.next();
+ double thisResult = fc.apply(x.getQuick(ye.index()), ye.get());
+ if (validResult) {
+ result = fa.apply(result, thisResult);
+ } else {
+ result = thisResult;
+ validResult = true;
+ }
+ }
+ return result;
+ }
+ }
+
+ public static class AggregateAllLoop extends VectorBinaryAggregate {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return true;
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ return x.size() * x.getLookupCost() * y.getLookupCost();
+ }
+
+ @Override
+ public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) {
+ double result = fc.apply(x.getQuick(0), y.getQuick(0));
+ int s = x.size();
+ for (int i = 1; i < s; ++i) {
+ result = fa.apply(result, fc.apply(x.getQuick(i), y.getQuick(i)));
+ }
+ return result;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/VectorBinaryAssign.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/VectorBinaryAssign.java b/core/src/main/java/org/apache/mahout/math/VectorBinaryAssign.java
new file mode 100644
index 0000000..f24d552
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/VectorBinaryAssign.java
@@ -0,0 +1,667 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.set.OpenIntHashSet;
+
+import java.util.Iterator;
+
+/**
+ * Abstract class encapsulating different algorithms that perform the Vector operations assign().
+ * x.assign(y, f), for x and y Vectors and f a DoubleDouble function:
+ * - applies the function f to every element in x and y, f(xi, yi)
+ * - assigns xi = f(xi, yi) for all indices i
+ *
+ * The names of variables, methods and classes used here follow the following conventions:
+ * The vector being assigned to (the left hand side) is called this or x.
+ * The right hand side is called that or y.
+ * The function to be applied is called f.
+ *
+ * The different algorithms take into account the different characteristics of vector classes:
+ * - whether the vectors support sequential iteration (isSequentialAccess())
+ * - whether the vectors support constant-time additions (isAddConstantTime())
+ * - what the lookup cost is (getLookupCost())
+ * - what the iterator advancement cost is (getIteratorAdvanceCost())
+ *
+ * The names of the actual classes (they're nested in VectorBinaryAssign) describe the algorithm used for assignment.
+ * The most important optimization is iterating just through the nonzeros (only possible if f(0, 0) = 0).
+ * There are 4 main possibilities:
+ * - iterating through the nonzeros of just one vector and looking up the corresponding elements in the other
+ * - iterating through the intersection of nonzeros (those indices where both vectors have nonzero values)
+ * - iterating through the union of nonzeros (those indices where at least one of the vectors has a nonzero value)
+ * - iterating through all the elements in some way (either through both at the same time, both one after the other,
+ * looking up both, looking up just one).
+ * Then, there are two additional sub-possibilities:
+ * - if a new value can be added to x in constant time (isAddConstantTime()), the *Inplace updates are used
+ * - otherwise (really just for SequentialAccessSparseVectors right now), the *Merge updates are used, where
+ * a sorted list of (index, value) pairs is merged into the vector at the end.
+ *
+ * The internal details are not important and a particular algorithm should generally not be called explicitly.
+ * The best one will be selected through assignBest(), which is itself called through Vector.assign().
+ *
+ * See https://docs.google.com/document/d/1g1PjUuvjyh2LBdq2_rKLIcUiDbeOORA1sCJiSsz-JVU/edit# for a more detailed
+ * explanation.
+ */
+public abstract class VectorBinaryAssign {
+ /**
+ * All candidate assignment algorithms. getBestOperation() only switches to a later entry when its estimated
+ * cost is strictly lower, so the order of this array breaks ties between equally cheap valid algorithms.
+ */
+ public static final VectorBinaryAssign[] OPERATIONS = {
+ new AssignNonzerosIterateThisLookupThat(),
+ new AssignNonzerosIterateThatLookupThisMergeUpdates(),
+ new AssignNonzerosIterateThatLookupThisInplaceUpdates(),
+
+ new AssignIterateIntersection(),
+
+ new AssignIterateUnionSequentialMergeUpdates(),
+ new AssignIterateUnionSequentialInplaceUpdates(),
+ new AssignIterateUnionRandomMergeUpdates(),
+ new AssignIterateUnionRandomInplaceUpdates(),
+
+ new AssignAllIterateSequentialMergeUpdates(),
+ new AssignAllIterateSequentialInplaceUpdates(),
+ new AssignAllIterateThisLookupThatMergeUpdates(),
+ new AssignAllIterateThisLookupThatInplaceUpdates(),
+ new AssignAllIterateThatLookupThisMergeUpdates(),
+ new AssignAllIterateThatLookupThisInplaceUpdates(),
+ new AssignAllLoopMergeUpdates(),
+ new AssignAllLoopInplaceUpdates(),
+ };
+
+ /**
+ * Returns true iff we can use this algorithm to apply f to x and y component-wise and assign the result to x.
+ *
+ * @param x the vector being assigned to
+ * @param y the right hand side vector
+ * @param f the function applied component-wise as f(xi, yi)
+ * @return whether this algorithm produces the correct assignment for these arguments
+ */
+ public abstract boolean isValid(Vector x, Vector y, DoubleDoubleFunction f);
+
+ /**
+ * Estimates the cost of using this algorithm to compute the assignment. The algorithm is assumed to be valid.
+ * getBestOperation() compares these estimates across algorithms to choose one.
+ *
+ * @param x the vector being assigned to
+ * @param y the right hand side vector
+ * @param f the function applied component-wise
+ * @return the estimated cost
+ */
+ public abstract double estimateCost(Vector x, Vector y, DoubleDoubleFunction f);
+
+ /**
+ * Main method that applies f to x and y component-wise assigning the results to x. It returns the modified vector,
+ * x.
+ *
+ * @param x the vector being assigned to; modified by this call
+ * @param y the right hand side vector
+ * @param f the function applied component-wise as xi = f(xi, yi)
+ * @return x
+ */
+ public abstract Vector assign(Vector x, Vector y, DoubleDoubleFunction f);
+
+  /**
+   * Returns the least expensive valid assignment algorithm for the given arguments.
+   *
+   * <p>Some valid algorithms may report no finite, comparable cost (for example NaN or positive infinity). If no
+   * valid algorithm reports such a cost, the first valid one in {@link #OPERATIONS} is returned. This avoids
+   * failing with an out-of-bounds index.</p>
+   *
+   * @param x the vector being assigned to
+   * @param y the right hand side vector
+   * @param f the function applied component-wise
+   * @return the chosen algorithm
+   * @throws IllegalStateException if no algorithm is valid for the arguments
+   */
+  public static VectorBinaryAssign getBestOperation(Vector x, Vector y, DoubleDoubleFunction f) {
+    int bestOperationIndex = -1;
+    int firstValidIndex = -1;
+    double bestCost = Double.POSITIVE_INFINITY;
+    for (int i = 0; i < OPERATIONS.length; ++i) {
+      if (OPERATIONS[i].isValid(x, y, f)) {
+        if (firstValidIndex < 0) {
+          firstValidIndex = i;
+        }
+        double cost = OPERATIONS[i].estimateCost(x, y, f);
+        if (cost < bestCost) {
+          bestCost = cost;
+          bestOperationIndex = i;
+        }
+      }
+    }
+    if (bestOperationIndex < 0) {
+      // No valid algorithm reported a comparable cost; fall back to the first valid one, if any.
+      bestOperationIndex = firstValidIndex;
+    }
+    if (bestOperationIndex < 0) {
+      throw new IllegalStateException("No valid VectorBinaryAssign operation for the given arguments");
+    }
+    return OPERATIONS[bestOperationIndex];
+  }
+
+ /**
+ * This is the method that should be used when assigning. It selects the best algorithm and applies it.
+ * Note that it does NOT invalidate the cached length of the Vector and should only be used through the wrapprs
+ * in AbstractVector.
+ */
+ public static Vector assignBest(Vector x, Vector y, DoubleDoubleFunction f) {
+ return getBestOperation(x, y, f).assign(x, y, f);
+ }
+
+ /**
+ * If f(0, y) = 0, the zeros in x don't matter and we can simply iterate through the nonzeros of x.
+ * To get the corresponding element of y, we perform a lookup.
+ * There are no *Merge or *Inplace versions because in this case x cannot become more dense because of f, meaning
+ * all changes will occur at indices whose values are already nonzero.
+ */
+ public static class AssignNonzerosIterateThisLookupThat extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return f.isLikeLeftMult();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (Element xe : x.nonZeroes()) {
+ xe.set(f.apply(xe.get(), y.getQuick(xe.index())));
+ }
+ return x;
+ }
+ }
+
+ /**
+ * If f(x, 0) = x, the zeros in y don't matter and we can simply iterate through the nonzeros of y.
+ * We get the corresponding element of x through a lookup and update x inplace.
+ */
+ public static class AssignNonzerosIterateThatLookupThisInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return f.isLikeRightPlus();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost() * x.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (Element ye : y.nonZeroes()) {
+ x.setQuick(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ return x;
+ }
+ }
+
+  /**
+   * Visits only the nonzeros of y and looks up the matching element of x. The results are collected in a batch
+   * that is merged into x at the end.
+   * Valid when f(x, 0) = x (f.isLikeRightPlus()): positions where y is zero are left untouched. It is used for
+   * vectors x without constant-time additions. y must support sequential access so the updates are produced in
+   * index order.
+   */
+  public static class AssignNonzerosIterateThatLookupThisMergeUpdates extends VectorBinaryAssign {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+      return f.isLikeRightPlus() && y.isSequentialAccess() && !x.isAddConstantTime();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+      // One iterator step over y and one lookup into x (the vector being read via getQuick) per nonzero of y.
+      return y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost();
+    }
+
+    @Override
+    public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+      OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+      for (Element ye : y.nonZeroes()) {
+        updates.set(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+      }
+      x.mergeUpdates(updates);
+      return x;
+    }
+  }
+
+  /**
+   * Updates only the indices where both x and y are nonzero, walking the two nonzero iterators in parallel.
+   * Valid when f(x, 0) = x (f.isLikeRightPlus()) and f(0, y) = 0 (f.isLikeLeftMult()). Wherever only one side is
+   * nonzero, x therefore already holds the correct value. Both vectors must support sequential access so their
+   * nonzeros arrive in index order.
+   */
+  public static class AssignIterateIntersection extends VectorBinaryAssign {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+      return f.isLikeLeftMult() && f.isLikeRightPlus() && x.isSequentialAccess() && y.isSequentialAccess();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+      double thisSideCost = x.getNumNondefaultElements() * x.getIteratorAdvanceCost();
+      double thatSideCost = y.getNumNondefaultElements() * y.getIteratorAdvanceCost();
+      return Math.min(thisSideCost, thatSideCost);
+    }
+
+    @Override
+    public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+      Iterator<Element> thisIt = x.nonZeroes().iterator();
+      Iterator<Element> thatIt = y.nonZeroes().iterator();
+      Element thisElem = null;
+      Element thatElem = null;
+      boolean stepThis = true;
+      boolean stepThat = true;
+      for (;;) {
+        if (stepThis) {
+          if (!thisIt.hasNext()) {
+            break;
+          }
+          thisElem = thisIt.next();
+        }
+        if (stepThat) {
+          if (!thatIt.hasNext()) {
+            break;
+          }
+          thatElem = thatIt.next();
+        }
+        int thisIndex = thisElem.index();
+        int thatIndex = thatElem.index();
+        if (thisIndex == thatIndex) {
+          thisElem.set(f.apply(thisElem.get(), thatElem.get()));
+        }
+        // Advance whichever side is behind, or both on a match. A lone x element keeps its value since
+        // f(x, 0) = x, and a lone y element leaves x at zero since f(0, y) = 0.
+        stepThis = thisIndex <= thatIndex;
+        stepThat = thisIndex >= thatIndex;
+      }
+      return x;
+    }
+  }
+
+  /**
+   * Visits every index where x or y is nonzero by walking both nonzero iterators in parallel, merge-style.
+   * Valid when f(0, 0) = 0 (f is not densifying). Existing nonzeros of x are updated through their elements.
+   * Values at indices where only y is nonzero are collected and merged into x at the end. This variant is meant for
+   * x without constant-time additions. Both vectors must support sequential access.
+   */
+  public static class AssignIterateUnionSequentialMergeUpdates extends VectorBinaryAssign {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+      return !f.isDensifying() && x.isSequentialAccess() && y.isSequentialAccess() && !x.isAddConstantTime();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+      double thisSideCost = x.getNumNondefaultElements() * x.getIteratorAdvanceCost();
+      double thatSideCost = y.getNumNondefaultElements() * y.getIteratorAdvanceCost();
+      return Math.max(thisSideCost, thatSideCost);
+    }
+
+    @Override
+    public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+      Iterator<Element> thisIt = x.nonZeroes().iterator();
+      Iterator<Element> thatIt = y.nonZeroes().iterator();
+      Element thisElem = null;
+      Element thatElem = null;
+      boolean stepThis = true;
+      boolean stepThat = true;
+      OrderedIntDoubleMapping pending = new OrderedIntDoubleMapping(false);
+      for (;;) {
+        if (stepThis) {
+          thisElem = thisIt.hasNext() ? thisIt.next() : null;
+        }
+        if (stepThat) {
+          thatElem = thatIt.hasNext() ? thatIt.next() : null;
+        }
+        if (thisElem == null && thatElem == null) {
+          break; // both sides are exhausted
+        }
+        // Consume the side(s) holding the smallest pending index; both on a tie.
+        stepThis = thatElem == null || (thisElem != null && thisElem.index() <= thatElem.index());
+        stepThat = thisElem == null || (thatElem != null && thisElem.index() >= thatElem.index());
+        if (stepThis && stepThat) {
+          thisElem.set(f.apply(thisElem.get(), thatElem.get()));
+        } else if (stepThis) {
+          thisElem.set(f.apply(thisElem.get(), 0)); // f(x, 0)
+        } else {
+          pending.set(thatElem.index(), f.apply(0, thatElem.get())); // f(0, y): new entry, merged later
+        }
+      }
+      x.mergeUpdates(pending);
+      return x;
+    }
+  }
+
+  /**
+   * Visits every index where x or y is nonzero by walking both nonzero iterators in parallel, merge-style, and
+   * writes every result directly into x. Valid when f(0, 0) = 0 (f is not densifying) and x supports
+   * constant-time additions. Both vectors must support sequential access.
+   */
+  public static class AssignIterateUnionSequentialInplaceUpdates extends VectorBinaryAssign {
+
+    @Override
+    public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+      return !f.isDensifying() && x.isSequentialAccess() && y.isSequentialAccess() && x.isAddConstantTime();
+    }
+
+    @Override
+    public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+      double thisSideCost = x.getNumNondefaultElements() * x.getIteratorAdvanceCost();
+      double thatSideCost = y.getNumNondefaultElements() * y.getIteratorAdvanceCost();
+      return Math.max(thisSideCost, thatSideCost);
+    }
+
+    @Override
+    public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+      Iterator<Element> thisIt = x.nonZeroes().iterator();
+      Iterator<Element> thatIt = y.nonZeroes().iterator();
+      Element thisElem = null;
+      Element thatElem = null;
+      boolean stepThis = true;
+      boolean stepThat = true;
+      for (;;) {
+        if (stepThis) {
+          thisElem = thisIt.hasNext() ? thisIt.next() : null;
+        }
+        if (stepThat) {
+          thatElem = thatIt.hasNext() ? thatIt.next() : null;
+        }
+        if (thisElem == null && thatElem == null) {
+          break; // both sides are exhausted
+        }
+        // Consume the side(s) holding the smallest pending index; both on a tie.
+        stepThis = thatElem == null || (thisElem != null && thisElem.index() <= thatElem.index());
+        stepThat = thisElem == null || (thatElem != null && thisElem.index() >= thatElem.index());
+        if (stepThis && stepThat) {
+          thisElem.set(f.apply(thisElem.get(), thatElem.get()));
+        } else if (stepThis) {
+          thisElem.set(f.apply(thisElem.get(), 0)); // f(x, 0)
+        } else {
+          // f(0, y): the index lies behind x's iterator (or x is exhausted), so it is written directly
+          x.setQuick(thatElem.index(), f.apply(0, thatElem.get()));
+        }
+      }
+      return x;
+    }
+  }
+
+ /**
+ * If f(0, 0) = 0 we can iterate through the nonzeros in either x or y.
+ * In this case, we iterate through the nozeros of x and y alternatively (this works even when one of them
+ * doesn't support sequential access). Since we're merging the results into x, when iterating through y, the
+ * order of iteration matters and y must support sequential access.
+ */
+ public static class AssignIterateUnionRandomMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !f.isDensifying() && !x.isAddConstantTime() && y.isSequentialAccess();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OpenIntHashSet visited = new OpenIntHashSet();
+ for (Element xe : x.nonZeroes()) {
+ xe.set(f.apply(xe.get(), y.getQuick(xe.index())));
+ visited.add(xe.index());
+ }
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ for (Element ye : y.nonZeroes()) {
+ if (!visited.contains(ye.index())) {
+ updates.set(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ /**
+ * If f(0, 0) = 0 we can iterate through the nonzeros in either x or y.
+ * In this case, we iterate through the nozeros of x and y alternatively (this works even when one of them
+ * doesn't support sequential access). Because updates to x are inplace, neither x, nor y need to support
+ * sequential access.
+ */
+ public static class AssignIterateUnionRandomInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !f.isDensifying() && x.isAddConstantTime();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.getNumNondefaultElements() * x.getIteratorAdvanceCost() * y.getLookupCost(),
+ y.getNumNondefaultElements() * y.getIteratorAdvanceCost() * x.getLookupCost());
+ }
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OpenIntHashSet visited = new OpenIntHashSet();
+ for (Element xe : x.nonZeroes()) {
+ xe.set(f.apply(xe.get(), y.getQuick(xe.index())));
+ visited.add(xe.index());
+ }
+ for (Element ye : y.nonZeroes()) {
+ if (!visited.contains(ye.index())) {
+ x.setQuick(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ }
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateSequentialMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isSequentialAccess() && y.isSequentialAccess() && !x.isAddConstantTime() && !x.isDense() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.size() * x.getIteratorAdvanceCost(), y.size() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ Iterator<Vector.Element> xi = x.all().iterator();
+ Iterator<Vector.Element> yi = y.all().iterator();
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ while (xi.hasNext() && yi.hasNext()) {
+ Element xe = xi.next();
+ updates.set(xe.index(), f.apply(xe.get(), yi.next().get()));
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateSequentialInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isSequentialAccess() && y.isSequentialAccess() && x.isAddConstantTime()
+ && !x.isDense() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return Math.max(x.size() * x.getIteratorAdvanceCost(), y.size() * y.getIteratorAdvanceCost());
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ Iterator<Vector.Element> xi = x.all().iterator();
+ Iterator<Vector.Element> yi = y.all().iterator();
+ while (xi.hasNext() && yi.hasNext()) {
+ Element xe = xi.next();
+ x.setQuick(xe.index(), f.apply(xe.get(), yi.next().get()));
+ }
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateThisLookupThatMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !x.isAddConstantTime() && !x.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.size() * x.getIteratorAdvanceCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ for (Element xe : x.all()) {
+ updates.set(xe.index(), f.apply(xe.get(), y.getQuick(xe.index())));
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateThisLookupThatInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isAddConstantTime() && !x.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.size() * x.getIteratorAdvanceCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (Element xe : x.all()) {
+ x.setQuick(xe.index(), f.apply(xe.get(), y.getQuick(xe.index())));
+ }
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateThatLookupThisMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !x.isAddConstantTime() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return y.size() * y.getIteratorAdvanceCost() * x.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ for (Element ye : y.all()) {
+ updates.set(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ public static class AssignAllIterateThatLookupThisInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isAddConstantTime() && !y.isDense();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return y.size() * y.getIteratorAdvanceCost() * x.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (Element ye : y.all()) {
+ x.setQuick(ye.index(), f.apply(x.getQuick(ye.index()), ye.get()));
+ }
+ return x;
+ }
+ }
+
+ public static class AssignAllLoopMergeUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return !x.isAddConstantTime();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.size() * x.getLookupCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(false);
+ for (int i = 0; i < x.size(); ++i) {
+ updates.set(i, f.apply(x.getQuick(i), y.getQuick(i)));
+ }
+ x.mergeUpdates(updates);
+ return x;
+ }
+ }
+
+ public static class AssignAllLoopInplaceUpdates extends VectorBinaryAssign {
+
+ @Override
+ public boolean isValid(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.isAddConstantTime();
+ }
+
+ @Override
+ public double estimateCost(Vector x, Vector y, DoubleDoubleFunction f) {
+ return x.size() * x.getLookupCost() * y.getLookupCost();
+ }
+
+ @Override
+ public Vector assign(Vector x, Vector y, DoubleDoubleFunction f) {
+ for (int i = 0; i < x.size(); ++i) {
+ x.setQuick(i, f.apply(x.getQuick(i), y.getQuick(i)));
+ }
+ return x;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/VectorIterable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/mahout/math/VectorIterable.java b/core/src/main/java/org/apache/mahout/math/VectorIterable.java
new file mode 100644
index 0000000..8414fdb
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/math/VectorIterable.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.util.Iterator;
+
+/**
+ * An iterable collection of matrix slices that also exposes its dimensions and matrix-vector products.
+ */
+public interface VectorIterable extends Iterable<MatrixSlice> {
+
+  /**
+   * Iterates over all rows, in order.
+   *
+   * @return an iterator over every row, in order
+   */
+  Iterator<MatrixSlice> iterateAll();
+
+  /**
+   * Iterates over the non-empty rows only, in arbitrary order.
+   *
+   * @return an iterator over the non-empty rows, with no ordering guarantee
+   */
+  Iterator<MatrixSlice> iterateNonEmpty();
+
+  /**
+   * Returns the number of slices exposed by this iterable.
+   * NOTE(review): presumably equal to {@link #numRows()} for row-sliced implementations; confirm.
+   *
+   * @return the number of slices
+   */
+  int numSlices();
+
+  /**
+   * @return the number of rows
+   */
+  int numRows();
+
+  /**
+   * @return the number of columns
+   */
+  int numCols();
+
+  /**
+   * Returns a new vector, with cardinality equal to {@link #numRows()} of this matrix, that is the matrix
+   * product of the recipient and the argument.
+   *
+   * @param v a vector with cardinality equal to {@link #numCols()} of the recipient
+   * @return a new vector (typically a DenseVector)
+   * @throws CardinalityException if {@code this.numCols() != v.size()}
+   */
+  Vector times(Vector v);
+
+  /**
+   * Convenience method for producing {@code this.transpose().times(this.times(v))}. It can be implemented
+   * with only one pass over the matrix and without the {@code transpose()} call, which can be expensive if
+   * the matrix is sparse.
+   *
+   * @param v a vector with cardinality equal to {@link #numCols()} of the recipient
+   * @return a new vector (typically a DenseVector) with cardinality equal to that of the argument
+   * @throws CardinalityException if {@code this.numCols() != v.size()}
+   */
+  Vector timesSquared(Vector v);
+
+}