You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by pa...@apache.org on 2017/11/18 20:48:36 UTC

mahout git commit: MAHOUT-2019 SparseRowMatrix speedup and fix for the change to Scala 2.11 made by the build script

Repository: mahout
Updated Branches:
  refs/heads/master d9b32f308 -> 800a9ed6d


MAHOUT-2019 SparseRowMatrix speedup and fix for the change to Scala 2.11 made by the build script


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/800a9ed6
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/800a9ed6
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/800a9ed6

Branch: refs/heads/master
Commit: 800a9ed6d7e015aa82b9eb7624bb441b71a8f397
Parents: d9b32f3
Author: pferrel <pa...@occamsmachete.com>
Authored: Sat Nov 18 12:29:06 2017 -0800
Committer: pferrel <pa...@occamsmachete.com>
Committed: Sat Nov 18 12:34:07 2017 -0800

----------------------------------------------------------------------
 .../org/apache/mahout/math/SparseRowMatrix.java | 53 ++++++++++++++++++++
 1 file changed, 53 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/800a9ed6/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
index 6e06769..ee54ad0 100644
--- a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
@@ -19,7 +19,12 @@ package org.apache.mahout.math;
 
 import org.apache.mahout.math.flavor.MatrixFlavor;
 import org.apache.mahout.math.flavor.TraversingStructureEnum;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
 import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Iterator;
 
 /**
  * sparse matrix with general element values whose rows are accessible quickly. Implemented as a row
@@ -30,6 +35,8 @@ public class SparseRowMatrix extends AbstractMatrix {
 
   private final boolean randomAccessRows;
 
+  private static final Logger log = LoggerFactory.getLogger(SparseRowMatrix.class);
+
   /**
    * Construct a sparse matrix starting with the provided row vectors.
    *
@@ -133,6 +140,52 @@ public class SparseRowMatrix extends AbstractMatrix {
   }
 
   @Override
+  public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+    // Sparse-aware element-wise update:
+    // this[r][c] = function(this[r][c], other[r][c]) for every cell.
+    //
+    // other    - right-hand operand; must have the same shape as this matrix
+    // function - binary function applied cell by cell
+    // returns this matrix, modified in place
+    // throws CardinalityException if the row or column counts differ
+    int rows = rowSize();
+    if (rows != other.rowSize()) {
+      throw new CardinalityException(rows, other.rowSize());
+    }
+    int columns = columnSize();
+    if (columns != other.columnSize()) {
+      throw new CardinalityException(columns, other.columnSize());
+    }
+
+    // Mult-like functions send a zero left operand to zero (TODO: confirm isLikeMult() is a
+    // sufficient test), so only the stored non-zeros of each row need to be visited.
+    // NOTE(review): zeros are skipped even where other holds NaN/Inf (0 * NaN would be NaN).
+    // Rows go through the Vector API instead of a cast to SequentialAccessSparseVector:
+    // rows are not guaranteed to be SASVs (presumably RandomAccessSparseVectors when
+    // randomAccessRows is set), and the old cast fallback logged a warning on every row.
+    boolean sparseOnly = function.isLikeMult();
+    for (int row = 0; row < rows; row++) {
+      Vector rowVector = this.rowVectors[row];
+      if (sparseOnly) {
+        for (Vector.Element element : rowVector.nonZeroes()) {
+          // Write through the element instead of setQuick(): setQuick() with a 0 can drop
+          // the entry from the row's backing storage mid-iteration, which makes the
+          // iterator skip the next non-zero (e.g. a Hadamard product with a sparse other).
+          element.set(function.apply(element.get(), other.getQuick(row, element.index())));
+        }
+      } else {
+        // A general function may turn a zero into a non-zero (e.g. plus), so every
+        // cell has to be visited.
+        for (int col = 0; col < columns; col++) {
+          setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
+        }
+      }
+    }
+
+    return this;
+  }
+
+  @Override
   public Matrix assignColumn(int column, Vector other) {
     if (rowSize() != other.size()) {
       throw new CardinalityException(rowSize(), other.size());