You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by dl...@apache.org on 2015/06/11 02:09:10 UTC
[3/4] mahout git commit: MAHOUT-1660 MAHOUT-1713 MAHOUT-1714
MAHOUT-1715 MAHOUT-1716 MAHOUT-1717 MAHOUT-1718 MAHOUT-1719 MAHOUT-1720
MAHOUT-1721 MAHOUT-1722 MAHOUT-1723 MAHOUT-1724 MAHOUT-1725 MAHOUT-1726
MAHOUT-1727 MAHOUT-1728 MAHOUT-1729 MAHOUT-1730 M
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala
index 97e06cf..7091c53 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOps.scala
@@ -16,18 +16,28 @@
*/
package org.apache.mahout.math.scalabindings
+import org.apache.mahout.math.function.Functions
import org.apache.mahout.math.{Vector, Matrix}
import scala.collection.JavaConversions._
import RLikeOps._
class RLikeMatrixOps(m: Matrix) extends MatrixOps(m) {
+ /** Structure-optimized mmul */
+ def %*%(that: Matrix) = MMul(m, that, None)
+
+ def :%*%(that:Matrix) = %*%(that)
+
+ def %*%:(that: Matrix) = that :%*% m
+
/**
- * matrix-matrix multiplication
- * @param that
- * @return
+ * The "legacy" matrix-matrix multiplication.
+ *
+ * @param that right hand operand
+ * @return matrix multiplication result
+ * @deprecated use %*%
*/
- def %*%(that: Matrix) = m.times(that)
+ def %***%(that: Matrix) = m.times(that)
/**
* matrix-vector multiplication
@@ -65,13 +75,16 @@ class RLikeMatrixOps(m: Matrix) extends MatrixOps(m) {
* @param that
*/
def *=(that: Matrix) = {
- m.zip(that).foreach(t => t._1.vector *= t._2.vector)
+ m.assign(that, Functions.MULT)
m
}
+ /** A *=: B is equivalent to B *= A. Included for completeness. */
+ def *=:(that: Matrix) = m *= that
+
/** Elementwise deletion */
def /=(that: Matrix) = {
- m.zip(that).foreach(t => t._1.vector() /= t._2.vector)
+ m.zip(that).foreach(t ⇒ t._1.vector() /= t._2.vector)
m
}
@@ -80,15 +93,63 @@ class RLikeMatrixOps(m: Matrix) extends MatrixOps(m) {
m
}
+ /** 5.0 *=: A is equivalent to A *= 5.0. Included for completeness. */
+ def *=:(that: Double) = m *= that
+
def /=(that: Double) = {
- m.foreach(_.vector() /= that)
+ m ::= { x ⇒ x / that }
m
}
/** 1.0 /=: A is equivalent to A = 1.0/A in R */
def /=:(that: Double) = {
- m.foreach(that /=: _.vector())
+ m := { x ⇒ that / x } // no zero guard: 0.0 /=: A must still overwrite A with 0.0 / A rather than return A unchanged
m
}
+
+ def ^=(that: Double) = {
+ m ::= { x ⇒ math.pow(x, that) }
+ m
+ }
+
+ def ^(that: Double) = m.cloned ^= that
+
+ def cbind(that: Matrix): Matrix = {
+ require(m.nrow == that.nrow)
+ if (m.ncol > 0) {
+ if (that.ncol > 0) {
+ val mx = m.like(m.nrow, m.ncol + that.ncol)
+ mx(::, 0 until m.ncol) := m
+ mx(::, m.ncol until mx.ncol) := that
+ mx
+ } else m
+ } else that
+ }
+
+ def cbind(that: Double): Matrix = {
+ val mx = m.like(m.nrow, m.ncol + 1)
+ mx(::, 0 until m.ncol) := m
+ if (that != 0.0) mx(::, m.ncol) := that
+ mx
+ }
+
+ def rbind(that: Matrix): Matrix = {
+ require(m.ncol == that.ncol)
+ if (m.nrow > 0) {
+ if (that.nrow > 0) {
+ val mx = m.like(m.nrow + that.nrow, m.ncol)
+ mx(0 until m.nrow, ::) := m
+ mx(m.nrow until mx.nrow, ::) := that
+ mx
+ } else m
+ } else that
+ }
+
+ def rbind(that: Double): Matrix = {
+ val mx = m.like(m.nrow + 1, m.ncol)
+ mx(0 until m.nrow, ::) := m
+ if (that != 0.0) mx(m.nrow, ::) := that
+ mx
+ }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala
index ba32304..e10a01b 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeOps.scala
@@ -24,13 +24,13 @@ import org.apache.mahout.math.{Vector, MatrixTimesOps, Matrix}
*/
object RLikeOps {
- implicit def double2Scalar(x:Double) = new DoubleScalarOps(x)
+ implicit def double2Scalar(x:Double) = new RLikeDoubleScalarOps(x)
implicit def v2vOps(v: Vector) = new RLikeVectorOps(v)
implicit def el2elOps(el: Vector.Element) = new ElementOps(el)
- implicit def times2timesOps(m: MatrixTimesOps) = new RLikeTimesOps(m)
+ implicit def el2Double(el:Vector.Element) = el.get()
implicit def m2mOps(m: Matrix) = new RLikeMatrixOps(m)
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeTimesOps.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeTimesOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeTimesOps.scala
deleted file mode 100644
index 51f0f63..0000000
--- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeTimesOps.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.math.scalabindings
-
-import org.apache.mahout.math.{Matrix, MatrixTimesOps}
-
-class RLikeTimesOps(m: MatrixTimesOps) {
-
- def :%*%(that: Matrix) = m.timesRight(that)
-
- def %*%:(that: Matrix) = m.timesLeft(that)
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala
index d2198bd..38a55d6 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/RLikeVectorOps.scala
@@ -17,7 +17,7 @@
package org.apache.mahout.math.scalabindings
-import org.apache.mahout.math.Vector
+import org.apache.mahout.math.{Matrix, Vector}
import org.apache.mahout.math.function.Functions
import RLikeOps._
@@ -67,5 +67,32 @@ class RLikeVectorOps(_v: Vector) extends VectorOps(_v) {
/** Elementwise right-associative / */
def /:(that: Vector) = that.cloned /= v
+ def ^=(that: Double) = v.assign(Functions.POW, that)
+
+ def ^=(that: Vector) = v.assign(that, Functions.POW)
+
+ def ^(that: Double) = v.cloned ^= that
+
+ def ^(that: Vector) = v.cloned ^= that
+
+ def c(that: Vector) = { // R-like concatenation: the elements of `v` followed by the elements of `that`
+ if (v.length > 0) {
+ if (that.length > 0) {
+ val cv = v.like(v.length + that.length)
+ cv(0 until v.length) := v // copy this vector into the head (was `:= cv`, the new vector assigned to its own view)
+ cv(v.length until cv.length) := that
+ cv
+ } else v // nothing to append: `v` itself is returned, not a copy
+ } else that // `v` is empty: `that` itself is returned, not a copy
+ }
+
+ def c(that: Double) = {
+ val cv = v.like(v.length + 1)
+ cv(0 until v.length) := v
+ cv(v.length) = that
+ cv
+ }
+
+ def mean = sum / length
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala
index c20354d..ef9c494 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/VectorOps.scala
@@ -38,8 +38,13 @@ class VectorOps(private[scalabindings] val v: Vector) {
def update(r: Range, that: Vector) = apply(r) := that
+ /** R-like synonyms for java methods on vectors */
def sum = v.zSum()
+ def min = v.minValue()
+
+ def max = v.maxValue()
+
def :=(that: Vector): Vector = {
// assign op in Mahout requires same
@@ -58,11 +63,30 @@ class VectorOps(private[scalabindings] val v: Vector) {
def :=(that: Double): Vector = v.assign(that)
+ /** Functional assignment for a function with index and x */
def :=(f: (Int, Double) => Double): Vector = {
for (i <- 0 until length) v(i) = f(i, v(i))
v
}
+ /** Functional assignment for a function with just x (e.g. v := math.exp _) */
+ def :=(f:(Double)=>Double):Vector = {
+ for (i <- 0 until length) v(i) = f(v(i))
+ v
+ }
+
+ /** Sparse iteration functional assignment using function receiving index and x */
+ def ::=(f: (Int, Double) => Double): Vector = {
+ for (el <- v.nonZeroes) el := f(el.index, el.get)
+ v
+ }
+
+ /** Sparse iteration functional assignment using a function receiving just x */
+ def ::=(f: (Double) => Double): Vector = {
+ for (el <- v.nonZeroes) el := f(el.get)
+ v
+ }
+
def equiv(that: Vector) =
length == that.length &&
v.all.view.zip(that.all).forall(t => t._1.get == t._2.get)
@@ -121,21 +145,26 @@ class VectorOps(private[scalabindings] val v: Vector) {
}
class ElementOps(private[scalabindings] val el: Vector.Element) {
+ import RLikeOps._
+
+ def update(v: Double): Double = { el.set(v); v }
+
+ def :=(that: Double) = update(that)
- def apply = el.get()
+ def *(that: Vector.Element): Double = el.get * that.get // unwrap both elements; `this * that` resolved back to this same overload
- def update(v: Double) = el.set(v)
+ def *(that: Vector): Vector = el.get * that
- def :=(v: Double) = el.set(v)
+ def +(that: Vector.Element): Double = el.get + that.get // same unwrapping as `*` to avoid self-recursion
- def +(that: Double) = el.get() + that
+ def +(that: Vector) :Vector = el.get + that
- def -(that: Double) = el.get() - that
+ def /(that: Vector.Element): Double = el.get / that.get // same unwrapping as `*` to avoid self-recursion
- def :-(that: Double) = that - el.get()
+ def /(that:Vector):Vector = el.get / that
- def /(that: Double) = el.get() / that
+ def -(that: Vector.Element): Double = el.get - that.get // same unwrapping as `*` to avoid self-recursion
- def :/(that: Double) = that / el.get()
+ def -(that: Vector) :Vector = el.get - that
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala
index 36f5103..20dc9cd 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/package.scala
@@ -18,12 +18,15 @@
package org.apache.mahout.math
import org.apache.mahout.math.solver.EigenDecomposition
+import collection._
+import JavaConversions._
/**
* Mahout matrices and vectors' scala syntactic sugar
*/
package object scalabindings {
+
// Reserved "ALL" range
final val `::`: Range = null
@@ -125,7 +128,6 @@ package object scalabindings {
val data = for (r <- rows) yield {
r match {
case n: Number => Array(n.doubleValue())
- case t: Product => t.productIterator.map(_.asInstanceOf[Number].doubleValue()).toArray
case t: Vector => Array.tabulate(t.length)(t(_))
case t: Array[Double] => t
case t: Iterable[_] =>
@@ -138,6 +140,7 @@ package object scalabindings {
}
return m
}
+ case t: Product => t.productIterator.map(_.asInstanceOf[Number].doubleValue()).toArray
case t: Array[Array[Double]] => if (rows.size == 1)
return new DenseMatrix(t)
else
@@ -164,7 +167,7 @@ package object scalabindings {
* (0,5)::(9,3)::Nil,
* (2,3.5)::(7,8)::Nil
* )
- *
+ *
* }}}
*
* @param rows
@@ -172,11 +175,18 @@ package object scalabindings {
*/
def sparse(rows: Vector*): SparseRowMatrix = {
- import MatrixOps._
+ import RLikeOps._
val nrow = rows.size
val ncol = rows.map(_.size()).max
val m = new SparseRowMatrix(nrow, ncol)
- m := rows
+ m := rows.map { row =>
+ if (row.length < ncol) {
+ val newRow = row.like(ncol)
+ newRow(0 until row.length) := row
+ newRow
+ }
+ else row
+ }
m
}
@@ -249,23 +259,23 @@ package object scalabindings {
(qrdec.getQ, qrdec.getR)
}
- /**
- * Solution <tt>X</tt> of <tt>A*X = B</tt> using QR-Decomposition, where <tt>A</tt> is a square, non-singular matrix.
+ /**
+ * Solution <tt>X</tt> of <tt>A*X = B</tt> using QR-Decomposition, where <tt>A</tt> is a square, non-singular matrix.
*
* @param a
* @param b
* @return (X)
*/
def solve(a: Matrix, b: Matrix): Matrix = {
- import MatrixOps._
- if (a.nrow != a.ncol) {
- throw new IllegalArgumentException("supplied matrix A is not square")
- }
- val qr = new QRDecomposition(a cloned)
- if (!qr.hasFullRank) {
- throw new IllegalArgumentException("supplied matrix A is singular")
- }
- qr.solve(b)
+ import MatrixOps._
+ if (a.nrow != a.ncol) {
+ throw new IllegalArgumentException("supplied matrix A is not square")
+ }
+ val qr = new QRDecomposition(a cloned)
+ if (!qr.hasFullRank) {
+ throw new IllegalArgumentException("supplied matrix A is singular")
+ }
+ qr.solve(b)
}
/**
@@ -293,5 +303,46 @@ package object scalabindings {
x(::, 0)
}
+ ///////////////////////////////////////////////////////////
+ // Elementwise unary functions. Actually this requires creating clones to avoid side effects. For
+ // efficiency reasons one may want to actually do in-place expression assignments instead, e.g.
+ //
+ // m := exp _
+
+ import RLikeOps._
+ import scala.math._
+
+ def mexp(m: Matrix): Matrix = m.cloned := exp _
+
+ def vexp(v: Vector): Vector = v.cloned := exp _
+
+ def mlog(m: Matrix): Matrix = m.cloned := log _
+
+ def vlog(v: Vector): Vector = v.cloned := log _
+
+ def mabs(m: Matrix): Matrix = m.cloned ::= (abs(_: Double))
+
+ def vabs(v: Vector): Vector = v.cloned ::= (abs(_: Double))
+
+ def msqrt(m: Matrix): Matrix = m.cloned ::= sqrt _
+
+ def vsqrt(v: Vector): Vector = v.cloned ::= sqrt _
+
+ def msignum(m: Matrix): Matrix = m.cloned ::= (signum(_: Double))
+
+ def vsignum(v: Vector): Vector = v.cloned ::= (signum(_: Double))
+
+ //////////////////////////////////////////////////////////
+ // operation funcs
+
+
+ /** Matrix-matrix unary func */
+ type MMUnaryFunc = (Matrix, Option[Matrix]) => Matrix
+ /** Binary matrix-matrix operations which may save result in-place, optionally */
+ type MMBinaryFunc = (Matrix, Matrix, Option[Matrix]) => Matrix
+ type MVBinaryFunc = (Matrix, Vector, Option[Matrix]) => Matrix
+ type VMBinaryFunc = (Vector, Matrix, Option[Matrix]) => Matrix
+ type MDBinaryFunc = (Matrix, Double, Option[Matrix]) => Matrix
+
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/main/scala/org/apache/mahout/util/IOUtilsScala.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/util/IOUtilsScala.scala b/math-scala/src/main/scala/org/apache/mahout/util/IOUtilsScala.scala
new file mode 100644
index 0000000..b61bea4
--- /dev/null
+++ b/math-scala/src/main/scala/org/apache/mahout/util/IOUtilsScala.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.util
+
+import org.apache.mahout.logging._
+import collection._
+import java.io.Closeable
+
+object IOUtilsScala {
+
+ private final implicit val log = getLog(IOUtilsScala.getClass)
+
+ /**
+ * Try to close every resource in the sequence, in order of the sequence.
+ *
+ * Report all encountered exceptions to logging.
+ *
+ * Rethrow last exception only (if any)
+ * @param closeables
+ */
+ def close(closeables: Seq[Closeable]) = {
+
+ var lastThr: Option[Throwable] = None
+ closeables.foreach { c =>
+ try {
+ c.close()
+ } catch {
+ case t: Throwable =>
+ error(t.getMessage, t)
+ lastThr = Some(t)
+ }
+ }
+
+ // Rethrow most recent close exception (can throw only one)
+ lastThr.foreach(throw _)
+ }
+
+ /**
+ * Same as [[IOUtilsScala.close( )]] but do not re-throw any exceptions.
+ * @param closeables
+ */
+ def closeQuietly(closeables: Seq[Closeable]) = {
+ try {
+ close(closeables)
+ } catch {
+ case t: Throwable => // NOP
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala
index 849db68..bb42121 100644
--- a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala
+++ b/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeOpsSuiteBase.scala
@@ -46,6 +46,26 @@ trait DrmLikeOpsSuiteBase extends DistributedMahoutSuite with Matchers {
}
+ test("allReduceBlock") {
+
+ val mxA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6))
+ val drmA = drmParallelize(mxA, numPartitions = 2)
+
+ try {
+ val mxB = drmA.allreduceBlock { case (keys, block) ⇒
+ block(::, 0 until 2).t %*% block(::, 2 until 3)
+ }
+
+ val mxControl = mxA(::, 0 until 2).t %*% mxA(::, 2 until 3)
+
+ (mxB - mxControl).norm should be < 1e-10
+
+ } catch {
+ case e: UnsupportedOperationException ⇒ // Some engines may not support this, so ignore.
+ }
+
+ }
+
test("col range") {
val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6))
val A = drmParallelize(m = inCoreA, numPartitions = 2)
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala
index 6c9313c..f215fb7 100644
--- a/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala
+++ b/math-scala/src/test/scala/org/apache/mahout/math/drm/DrmLikeSuiteBase.scala
@@ -68,9 +68,8 @@ trait DrmLikeSuiteBase extends DistributedMahoutSuite with Matchers {
inCoreEmpty.nrow shouldBe 100
inCoreEmpty.ncol shouldBe 50
+ }
- }
-
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala
index 2e6204d..b46ee30 100644
--- a/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala
+++ b/math-scala/src/test/scala/org/apache/mahout/math/drm/RLikeDrmOpsSuiteBase.scala
@@ -24,7 +24,13 @@ import scalabindings._
import RLikeOps._
import RLikeDrmOps._
import decompositions._
-import org.apache.mahout.math.drm.logical.{OpAtB, OpAtA, OpAtx}
+import org.apache.mahout.math.drm.logical._
+import org.apache.mahout.math.drm.logical.OpAtx
+import org.apache.mahout.math.drm.logical.OpAtB
+import org.apache.mahout.math.drm.logical.OpAtA
+import org.apache.mahout.math.drm.logical.OpAewUnaryFuncFusion
+
+import scala.util.Random
/** Common engine tests for distributed R-like DRM operations */
trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers {
@@ -188,10 +194,13 @@ trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers {
val A = drmParallelize(inCoreA, numPartitions = 2)
.mapBlock()({
- case (keys, block) => keys.map(_.toString) -> block
+ case (keys, block) ⇒ keys.map(_.toString) → block
})
- val B = A + 1.0
+ // Dense-A' x sparse-B used to produce error. We sparsify B here to test this as well.
+ val B = (A + 1.0).mapBlock() { case (keys, block) ⇒
+ keys → (new SparseRowMatrix(block.nrow, block.ncol) := block)
+ }
val C = A.t %*% B
@@ -204,6 +213,25 @@ trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers {
}
+ test ("C = A %*% B.t") {
+
+ val inCoreA = dense((1, 2), (3, 4), (-3, -5))
+
+ val A = drmParallelize(inCoreA, numPartitions = 2)
+
+ val B = A + 1.0
+
+ val C = A %*% B.t
+
+ mahoutCtx.optimizerRewrite(C) should equal(OpABt[Int](A, B))
+
+ val inCoreC = C.collect
+ val inCoreControlC = inCoreA %*% (inCoreA + 1.0).t
+
+ (inCoreC - inCoreControlC).norm should be < 1E-10
+
+ }
+
test("C = A %*% inCoreB") {
val inCoreA = dense((1, 2, 3), (3, 4, 5), (4, 5, 6), (5, 6, 7))
@@ -503,6 +531,24 @@ trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers {
}
+ test("B = 1 cbind A") {
+ val inCoreA = dense((1, 2), (3, 4))
+ val control = dense((1, 1, 2), (1, 3, 4))
+
+ val drmA = drmParallelize(inCoreA, numPartitions = 2)
+
+ (control - (1 cbind drmA) ).norm should be < 1e-10
+ }
+
+ test("B = A cbind 1") {
+ val inCoreA = dense((1, 2), (3, 4))
+ val control = dense((1, 2, 1), (3, 4, 1))
+
+ val drmA = drmParallelize(inCoreA, numPartitions = 2)
+
+ (control - (drmA cbind 1) ).norm should be < 1e-10
+ }
+
test("B = A + 1.0") {
val inCoreA = dense((1, 2), (2, 3), (3, 4))
val controlB = inCoreA + 1.0
@@ -547,4 +593,46 @@ trait RLikeDrmOpsSuiteBase extends DistributedMahoutSuite with Matchers {
(10 * drmA - (10 *: drmA)).norm shouldBe 0
}
+
+ test("A * A -> sqr(A) rewrite ") {
+ val mxA = dense(
+ (1, 2, 3),
+ (3, 4, 5),
+ (7, 8, 9)
+ )
+
+ val mxAAControl = mxA * mxA
+
+ val drmA = drmParallelize(mxA, 2)
+ val drmAA = drmA * drmA
+
+ val optimized = drmAA.context.engine.optimizerRewrite(drmAA)
+ println(s"optimized:$optimized")
+ optimized.isInstanceOf[OpAewUnaryFunc[Int]] shouldBe true
+
+ (mxAAControl -= drmAA).norm should be < 1e-10
+ }
+
+ test("B = 1 + 2 * (A * A) ew unary function fusion") {
+ val mxA = dense(
+ (1, 2, 3),
+ (3, 0, 5)
+ )
+ val controlB = mxA.cloned := { (x) => 1 + 2 * x * x}
+
+ val drmA = drmParallelize(mxA, 2)
+
+ // We need to use parentheses, otherwise optimizer will see it as (2A) * (A) and that would not
+ // be rewritten as 2 * sqr(A). It is not that clever (yet) to try commutativity optimizations.
+ val drmB = 1 + 2 * (drmA * drmA)
+
+ val optimized = mahoutCtx.engine.optimizerRewrite(drmB)
+ println(s"optimizer rewritten:$optimized")
+ optimized.isInstanceOf[OpAewUnaryFuncFusion[Int]] shouldBe true
+
+ (controlB - drmB).norm should be < 1e-10
+
+ }
+
+
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala
index d7b22d9..5c8a310 100644
--- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala
+++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala
@@ -24,6 +24,8 @@ import org.apache.mahout.test.MahoutSuite
import org.apache.mahout.math.{RandomAccessSparseVector, SequentialAccessSparseVector, Matrices}
import org.apache.mahout.common.RandomUtils
+import scala.util.Random
+
class MatrixOpsSuite extends FunSuite with MahoutSuite {
@@ -93,12 +95,40 @@ class MatrixOpsSuite extends FunSuite with MahoutSuite {
val e = eye(5)
- printf("I(5)=\n%s\n", e)
+ println(s"I(5)=\n$e")
a(0 to 1, 1 to 2) = dense((3, 2), (2, 3))
a(0 to 1, 1 to 2) := dense((3, 2), (2, 3))
+ println(s"a=$a")
+
+ a(0 to 1, 1 to 2) := { _ => 45}
+ println(s"a=$a")
+
+// a(0 to 1, 1 to 2) ::= { _ => 44}
+ println(s"a=$a")
+
+ // Sparse assignment to a sparse block
+ val c = sparse(0 -> 1 :: Nil, 2 -> 2 :: Nil, 1 -> 5 :: Nil)
+ val d = c.cloned
+
+ println(s"d=$d")
+ d.ncol shouldBe 3
+ d(::, 1 to 2) ::= { _ => 4}
+ println(s"d=$d")
+ d(::, 1 to 2).sum shouldBe 8
+
+ d ::= {_ => 5}
+ d.sum shouldBe 15
+
+ val f = c.cloned.t
+ f ::= {_ => 6}
+ f.sum shouldBe 18
+
+ val g = c.cloned
+ g(::, 1 until g.nrow) ::= { x => if (x <= 0) 0.0 else 1.0}
+ g.sum shouldBe 3
}
test("sparse") {
@@ -182,4 +212,5 @@ class MatrixOpsSuite extends FunSuite with MahoutSuite {
}
+
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala
index a943c5f..79d2899 100644
--- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala
+++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala
@@ -17,9 +17,16 @@
package org.apache.mahout.math.scalabindings
+import java.util
+
+import org.apache.log4j.Level
+import org.apache.mahout.math._
import org.scalatest.FunSuite
import RLikeOps._
import org.apache.mahout.test.MahoutSuite
+import org.apache.mahout.logging._
+import scala.collection.JavaConversions._
+import scala.util.Random
class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite {
@@ -63,6 +70,10 @@ class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite {
}
+ test("Uniform view") {
+ val mxUnif = Matrices.symmetricUniformView(5000000, 5000000, 1234)
+ }
+
/** Test dsl overloads over scala operations over matrices */
test ("scalarOps") {
val a = dense(
@@ -77,4 +88,269 @@ class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite {
}
+ test("Multiplication experimental performance") {
+
+ getLog(MMul.getClass).setLevel(Level.DEBUG)
+
+ val d = 300
+ val n = 3
+
+ // Dense row-wise
+ val mxAd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) + 1
+ val mxBd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) - 1
+
+ val rnd = new Random(1234)
+
+ // Sparse rows
+ val mxAsr = (new SparseRowMatrix(d,
+ d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() + 1 else 0.0 }) cloned
+ val mxBsr = (new SparseRowMatrix(d,
+ d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() - 1 else 0.0 }) cloned
+
+ // Hanging sparse rows
+ val mxAs = (new SparseMatrix(d, d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() + 1 else 0.0 }) cloned
+ val mxBs = (new SparseMatrix(d, d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() - 1 else 0.0 }) cloned
+
+ // DIAGONAL
+ val mxD = diagv(dvec(Array.tabulate(d)(_ => rnd.nextGaussian())))
+
+ /** Evaluates the by-name block `op` once and returns the elapsed wall-clock time in milliseconds. */
+ def time(op: => Unit): Long = {
+ val ms = System.currentTimeMillis()
+ op
+ System.currentTimeMillis() - ms
+ }
+
+ /**
+ * Multiplies `mxA` by `mxB` `n` times via `Matrix.times` ("current") and `n` times via `MMul`
+ * ("experimental"), then asserts that the results of the two variants agree to within 1e-10 in norm.
+ *
+ * @return pair (average "current" time, average "experimental" time) per run, as measured by `time`
+ */
+ def getMmulAvgs(mxA: Matrix, mxB: Matrix, n: Int) = {
+
+ // Assigned from inside the timed closures; each ends up holding the result of its last run.
+ var control: Matrix = null
+ var mmulVal: Matrix = null
+
+ val current = Stream.range(0, n).map { _ => time {control = mxA.times(mxB)} }.sum.toDouble / n
+ val experimental = Stream.range(0, n).map { _ => time {mmulVal = MMul(mxA, mxB, None)} }.sum.toDouble / n
+ (control - mmulVal).norm should be < 1e-10
+ current -> experimental
+ }
+
+ // Dense matrix tests.
+ println(s"Ad %*% Bd: ${getMmulAvgs(mxAd, mxBd, n)}")
+ println(s"Ad' %*% Bd: ${getMmulAvgs(mxAd.t, mxBd, n)}")
+ println(s"Ad %*% Bd': ${getMmulAvgs(mxAd, mxBd.t, n)}")
+ println(s"Ad' %*% Bd': ${getMmulAvgs(mxAd.t, mxBd.t, n)}")
+ println(s"Ad'' %*% Bd'': ${getMmulAvgs(mxAd.t.t, mxBd.t.t, n)}")
+ println
+
+ // Sparse row matrix tests.
+ println(s"Asr %*% Bsr: ${getMmulAvgs(mxAsr, mxBsr, n)}")
+ println(s"Asr' %*% Bsr: ${getMmulAvgs(mxAsr.t, mxBsr, n)}")
+ println(s"Asr %*% Bsr': ${getMmulAvgs(mxAsr, mxBsr.t, n)}")
+ println(s"Asr' %*% Bsr': ${getMmulAvgs(mxAsr.t, mxBsr.t, n)}")
+ println(s"Asr'' %*% Bsr'': ${getMmulAvgs(mxAsr.t.t, mxBsr.t.t, n)}")
+ println
+
+ // Sparse matrix tests.
+ println(s"Asm %*% Bsm: ${getMmulAvgs(mxAs, mxBs, n)}")
+ println(s"Asm' %*% Bsm: ${getMmulAvgs(mxAs.t, mxBs, n)}")
+ println(s"Asm %*% Bsm': ${getMmulAvgs(mxAs, mxBs.t, n)}")
+ println(s"Asm' %*% Bsm': ${getMmulAvgs(mxAs.t, mxBs.t, n)}")
+ println(s"Asm'' %*% Bsm'': ${getMmulAvgs(mxAs.t.t, mxBs.t.t, n)}")
+ println
+
+ // Mixed sparse matrix tests.
+ println(s"Asm %*% Bsr: ${getMmulAvgs(mxAs, mxBsr, n)}")
+ println(s"Asm' %*% Bsr: ${getMmulAvgs(mxAs.t, mxBsr, n)}")
+ println(s"Asm %*% Bsr': ${getMmulAvgs(mxAs, mxBsr.t, n)}")
+ println(s"Asm' %*% Bsr': ${getMmulAvgs(mxAs.t, mxBsr.t, n)}")
+ println(s"Asm'' %*% Bsr'': ${getMmulAvgs(mxAs.t.t, mxBsr.t.t, n)}")
+ println
+
+ println(s"Asr %*% Bsm: ${getMmulAvgs(mxAsr, mxBs, n)}")
+ println(s"Asr' %*% Bsm: ${getMmulAvgs(mxAsr.t, mxBs, n)}")
+ println(s"Asr %*% Bsm': ${getMmulAvgs(mxAsr, mxBs.t, n)}")
+ println(s"Asr' %*% Bsm': ${getMmulAvgs(mxAsr.t, mxBs.t, n)}")
+ println(s"Asr'' %*% Bsm'': ${getMmulAvgs(mxAsr.t.t, mxBs.t.t, n)}")
+ println
+
+ // Mixed dense/sparse
+ println(s"Ad %*% Bsr: ${getMmulAvgs(mxAd, mxBsr, n)}")
+ println(s"Ad' %*% Bsr: ${getMmulAvgs(mxAd.t, mxBsr, n)}")
+ println(s"Ad %*% Bsr': ${getMmulAvgs(mxAd, mxBsr.t, n)}")
+ println(s"Ad' %*% Bsr': ${getMmulAvgs(mxAd.t, mxBsr.t, n)}")
+ println(s"Ad'' %*% Bsr'': ${getMmulAvgs(mxAd.t.t, mxBsr.t.t, n)}")
+ println
+
+ println(s"Asr %*% Bd: ${getMmulAvgs(mxAsr, mxBd, n)}")
+ println(s"Asr' %*% Bd: ${getMmulAvgs(mxAsr.t, mxBd, n)}")
+ println(s"Asr %*% Bd': ${getMmulAvgs(mxAsr, mxBd.t, n)}")
+ println(s"Asr' %*% Bd': ${getMmulAvgs(mxAsr.t, mxBd.t, n)}")
+ println(s"Asr'' %*% Bd'': ${getMmulAvgs(mxAsr.t.t, mxBd.t.t, n)}")
+ println
+
+ println(s"Ad %*% Bsm: ${getMmulAvgs(mxAd, mxBs, n)}")
+ println(s"Ad' %*% Bsm: ${getMmulAvgs(mxAd.t, mxBs, n)}")
+ println(s"Ad %*% Bsm': ${getMmulAvgs(mxAd, mxBs.t, n)}")
+ println(s"Ad' %*% Bsm': ${getMmulAvgs(mxAd.t, mxBs.t, n)}")
+ println(s"Ad'' %*% Bsm'': ${getMmulAvgs(mxAd.t.t, mxBs.t.t, n)}")
+ println
+
+ println(s"Asm %*% Bd: ${getMmulAvgs(mxAs, mxBd, n)}")
+ println(s"Asm' %*% Bd: ${getMmulAvgs(mxAs.t, mxBd, n)}")
+ println(s"Asm %*% Bd': ${getMmulAvgs(mxAs, mxBd.t, n)}")
+ println(s"Asm' %*% Bd': ${getMmulAvgs(mxAs.t, mxBd.t, n)}")
+ println(s"Asm'' %*% Bd'': ${getMmulAvgs(mxAs.t.t, mxBd.t.t, n)}")
+ println
+
+ // Diagonal cases
+ println(s"Ad %*% D: ${getMmulAvgs(mxAd, mxD, n)}")
+ println(s"Asr %*% D: ${getMmulAvgs(mxAsr, mxD, n)}")
+ println(s"Asm %*% D: ${getMmulAvgs(mxAs, mxD, n)}")
+ println(s"D %*% Ad: ${getMmulAvgs(mxD, mxAd, n)}")
+ println(s"D %*% Asr: ${getMmulAvgs(mxD, mxAsr, n)}")
+ println(s"D %*% Asm: ${getMmulAvgs(mxD, mxAs, n)}")
+ println
+
+ println(s"Ad' %*% D: ${getMmulAvgs(mxAd.t, mxD, n)}")
+ println(s"Asr' %*% D: ${getMmulAvgs(mxAsr.t, mxD, n)}")
+ println(s"Asm' %*% D: ${getMmulAvgs(mxAs.t, mxD, n)}")
+ println(s"D %*% Ad': ${getMmulAvgs(mxD, mxAd.t, n)}")
+ println(s"D %*% Asr': ${getMmulAvgs(mxD, mxAsr.t, n)}")
+ println(s"D %*% Asm': ${getMmulAvgs(mxD, mxAs.t, n)}")
+ println
+
+ // Self-squared cases
+ println(s"Ad %*% Ad': ${getMmulAvgs(mxAd, mxAd.t, n)}")
+ println(s"Ad' %*% Ad: ${getMmulAvgs(mxAd.t, mxAd, n)}")
+ println(s"Ad' %*% Ad'': ${getMmulAvgs(mxAd.t, mxAd.t.t, n)}")
+ println(s"Ad'' %*% Ad': ${getMmulAvgs(mxAd.t.t, mxAd.t, n)}")
+
+ }
+
+
+ test("elementwise experimental performance") {
+
+ val d = 500
+ val n = 3
+
+ // Dense row-wise
+ val mxAd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) + 1
+ val mxBd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) - 1
+
+ val rnd = new Random(1234)
+
+ // Sparse rows
+ val mxAsr = (new SparseRowMatrix(d,
+ d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() + 1 else 0.0 }) cloned
+ val mxBsr = (new SparseRowMatrix(d,
+ d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() - 1 else 0.0 }) cloned
+
+ // Hanging sparse rows
+ val mxAs = (new SparseMatrix(d, d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() + 1 else 0.0 }) cloned
+ val mxBs = (new SparseMatrix(d, d) := { _ => if (rnd.nextDouble() < 0.1) rnd.nextGaussian() - 1 else 0.0 }) cloned
+
+ // DIAGONAL
+ val mxD = diagv(dvec(Array.tabulate(d)(_ => rnd.nextGaussian())))
+
+ /** Evaluates the by-name block `op` once and returns the elapsed wall-clock time in milliseconds. */
+ def time(op: => Unit): Long = {
+ val ms = System.currentTimeMillis()
+ op
+ System.currentTimeMillis() - ms
+ }
+
+ /**
+ * Times `mxA + mxB` in two series of `n` runs each and asserts that the results of the two series
+ * agree to within 1e-10 in norm.
+ *
+ * NOTE(review): the "current" and "experimental" closures evaluate the very same expression
+ * (`mxA + mxB`), so the comparison below cannot detect a wrong result and both averages measure
+ * the same code path. Presumably one of them was meant to call a different (legacy)
+ * implementation -- confirm and fix.
+ *
+ * @return pair (average time of the first series, average time of the second series) per run,
+ * as measured by `time`
+ */
+ def getEWAvgs(mxA: Matrix, mxB: Matrix, n: Int) = {
+
+ // Assigned from inside the timed closures; each ends up holding the result of its last run.
+ var control: Matrix = null
+ var mmulVal: Matrix = null
+
+ val current = Stream.range(0, n).map { _ => time {control = mxA + mxB} }.sum.toDouble / n
+ val experimental = Stream.range(0, n).map { _ => time {mmulVal = mxA + mxB} }.sum.toDouble / n
+ (control - mmulVal).norm should be < 1e-10
+ current -> experimental
+ }
+
+ // Dense matrix tests.
+ println(s"Ad + Bd: ${getEWAvgs(mxAd, mxBd, n)}")
+ println(s"Ad' + Bd: ${getEWAvgs(mxAd.t, mxBd, n)}")
+ println(s"Ad + Bd': ${getEWAvgs(mxAd, mxBd.t, n)}")
+ println(s"Ad' + Bd': ${getEWAvgs(mxAd.t, mxBd.t, n)}")
+ println(s"Ad'' + Bd'': ${getEWAvgs(mxAd.t.t, mxBd.t.t, n)}")
+ println
+
+ // Sparse row matrix tests.
+ println(s"Asr + Bsr: ${getEWAvgs(mxAsr, mxBsr, n)}")
+ println(s"Asr' + Bsr: ${getEWAvgs(mxAsr.t, mxBsr, n)}")
+ println(s"Asr + Bsr': ${getEWAvgs(mxAsr, mxBsr.t, n)}")
+ println(s"Asr' + Bsr': ${getEWAvgs(mxAsr.t, mxBsr.t, n)}")
+ println(s"Asr'' + Bsr'': ${getEWAvgs(mxAsr.t.t, mxBsr.t.t, n)}")
+ println
+
+ // Sparse matrix tests.
+ println(s"Asm + Bsm: ${getEWAvgs(mxAs, mxBs, n)}")
+ println(s"Asm' + Bsm: ${getEWAvgs(mxAs.t, mxBs, n)}")
+ println(s"Asm + Bsm': ${getEWAvgs(mxAs, mxBs.t, n)}")
+ println(s"Asm' + Bsm': ${getEWAvgs(mxAs.t, mxBs.t, n)}")
+ println(s"Asm'' + Bsm'': ${getEWAvgs(mxAs.t.t, mxBs.t.t, n)}")
+ println
+
+ // Mixed sparse matrix tests.
+ println(s"Asm + Bsr: ${getEWAvgs(mxAs, mxBsr, n)}")
+ println(s"Asm' + Bsr: ${getEWAvgs(mxAs.t, mxBsr, n)}")
+ println(s"Asm + Bsr': ${getEWAvgs(mxAs, mxBsr.t, n)}")
+ println(s"Asm' + Bsr': ${getEWAvgs(mxAs.t, mxBsr.t, n)}")
+ println(s"Asm'' + Bsr'': ${getEWAvgs(mxAs.t.t, mxBsr.t.t, n)}")
+ println
+
+ println(s"Asr + Bsm: ${getEWAvgs(mxAsr, mxBs, n)}")
+ println(s"Asr' + Bsm: ${getEWAvgs(mxAsr.t, mxBs, n)}")
+ println(s"Asr + Bsm': ${getEWAvgs(mxAsr, mxBs.t, n)}")
+ println(s"Asr' + Bsm': ${getEWAvgs(mxAsr.t, mxBs.t, n)}")
+ println(s"Asr'' + Bsm'': ${getEWAvgs(mxAsr.t.t, mxBs.t.t, n)}")
+ println
+
+ // Mixed dense/sparse
+ println(s"Ad + Bsr: ${getEWAvgs(mxAd, mxBsr, n)}")
+ println(s"Ad' + Bsr: ${getEWAvgs(mxAd.t, mxBsr, n)}")
+ println(s"Ad + Bsr': ${getEWAvgs(mxAd, mxBsr.t, n)}")
+ println(s"Ad' + Bsr': ${getEWAvgs(mxAd.t, mxBsr.t, n)}")
+ println(s"Ad'' + Bsr'': ${getEWAvgs(mxAd.t.t, mxBsr.t.t, n)}")
+ println
+
+ println(s"Asr + Bd: ${getEWAvgs(mxAsr, mxBd, n)}")
+ println(s"Asr' + Bd: ${getEWAvgs(mxAsr.t, mxBd, n)}")
+ println(s"Asr + Bd': ${getEWAvgs(mxAsr, mxBd.t, n)}")
+ println(s"Asr' + Bd': ${getEWAvgs(mxAsr.t, mxBd.t, n)}")
+ println(s"Asr'' + Bd'': ${getEWAvgs(mxAsr.t.t, mxBd.t.t, n)}")
+ println
+
+ println(s"Ad + Bsm: ${getEWAvgs(mxAd, mxBs, n)}")
+ println(s"Ad' + Bsm: ${getEWAvgs(mxAd.t, mxBs, n)}")
+ println(s"Ad + Bsm': ${getEWAvgs(mxAd, mxBs.t, n)}")
+ println(s"Ad' + Bsm': ${getEWAvgs(mxAd.t, mxBs.t, n)}")
+ println(s"Ad'' + Bsm'': ${getEWAvgs(mxAd.t.t, mxBs.t.t, n)}")
+ println
+
+ println(s"Asm + Bd: ${getEWAvgs(mxAs, mxBd, n)}")
+ println(s"Asm' + Bd: ${getEWAvgs(mxAs.t, mxBd, n)}")
+ println(s"Asm + Bd': ${getEWAvgs(mxAs, mxBd.t, n)}")
+ println(s"Asm' + Bd': ${getEWAvgs(mxAs.t, mxBd.t, n)}")
+ println(s"Asm'' + Bd'': ${getEWAvgs(mxAs.t.t, mxBd.t.t, n)}")
+ println
+
+ // Diagonal cases
+ println(s"Ad + D: ${getEWAvgs(mxAd, mxD, n)}")
+ println(s"Asr + D: ${getEWAvgs(mxAsr, mxD, n)}")
+ println(s"Asm + D: ${getEWAvgs(mxAs, mxD, n)}")
+ println(s"D + Ad: ${getEWAvgs(mxD, mxAd, n)}")
+ println(s"D + Asr: ${getEWAvgs(mxD, mxAsr, n)}")
+ println(s"D + Asm: ${getEWAvgs(mxD, mxAs, n)}")
+ println
+
+ println(s"Ad' + D: ${getEWAvgs(mxAd.t, mxD, n)}")
+ println(s"Asr' + D: ${getEWAvgs(mxAsr.t, mxD, n)}")
+ println(s"Asm' + D: ${getEWAvgs(mxAs.t, mxD, n)}")
+ println(s"D + Ad': ${getEWAvgs(mxD, mxAd.t, n)}")
+ println(s"D + Asr': ${getEWAvgs(mxD, mxAsr.t, n)}")
+ println(s"D + Asm': ${getEWAvgs(mxD, mxAs.t, n)}")
+ println
+
+ }
+
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala
index 037f562..d264514 100644
--- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala
+++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala
@@ -18,10 +18,12 @@
package org.apache.mahout.math.scalabindings
import org.scalatest.FunSuite
-import org.apache.mahout.math.{RandomAccessSparseVector, Vector}
+import org.apache.mahout.math.{SequentialAccessSparseVector, RandomAccessSparseVector, Vector}
import RLikeOps._
import org.apache.mahout.test.MahoutSuite
+import scala.util.Random
+
/** VectorOps Suite */
class VectorOpsSuite extends FunSuite with MahoutSuite {
@@ -79,4 +81,19 @@ class VectorOpsSuite extends FunSuite with MahoutSuite {
}
+ // Checks that the sparse functional assignment `::=` over a vector range rewrites the stored
+ // non-zero elements of a SequentialAccessSparseVector.
+ test("sparse assignment") {
+
+ val svec = new SequentialAccessSparseVector(30)
+ svec(1) = -0.5
+ svec(3) = 0.5
+ println(svec)
+
+ // Both non-zero elements (indices 1 and 3) fall inside the assigned range and are mapped to 0.
+ svec(1 until svec.length) ::= ( _ => 0)
+ println(svec)
+
+ // The initial values (-0.5 and 0.5) already sum to 0, so the sum alone cannot tell whether the
+ // assignment did anything; check the affected elements individually as well.
+ svec(1) shouldBe 0
+ svec(3) shouldBe 0
+ svec.sum shouldBe 0
+
+ }
+
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java b/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
index e752422..a823d0b 100644
--- a/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
@@ -19,13 +19,16 @@ package org.apache.mahout.math;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Maps;
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
import org.apache.mahout.math.function.*;
import java.util.Iterator;
import java.util.Map;
/**
- * A few universal implementations of convenience functions
+ * A few universal implementations of convenience functions for a JVM-backed matrix.
*/
public abstract class AbstractMatrix implements Matrix {
@@ -57,19 +60,24 @@ public abstract class AbstractMatrix implements Matrix {
@Override
public Iterator<MatrixSlice> iterateAll() {
return new AbstractIterator<MatrixSlice>() {
- private int slice;
+ private int row;
@Override
protected MatrixSlice computeNext() {
- if (slice >= numSlices()) {
+ if (row >= numRows()) {
return endOfData();
}
- int i = slice++;
+ int i = row++;
return new MatrixSlice(viewRow(i), i);
}
};
}
+ /**
+ * Default implementation: returns the same iterator as {@link #iterator()}.
+ */
+ @Override
+ public Iterator<MatrixSlice> iterateNonEmpty() {
+ return iterator();
+ }
+
/**
* Abstracted out for the iterator
*
@@ -813,4 +821,12 @@ public abstract class AbstractMatrix implements Matrix {
return returnString + ("}");
}
}
+
+ /**
+ * Default implementation: flavor is not supported at this level.
+ *
+ * @throws UnsupportedOperationException always; subclasses that support flavors override this method
+ */
+ @Override
+ public MatrixFlavor getFlavor() {
+ throw new UnsupportedOperationException("Flavor support not implemented for this matrix.");
+ }
+
+ ////////////// Matrix flavor trait ///////////////////
+
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/ConstantVector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/ConstantVector.java b/math/src/main/java/org/apache/mahout/math/ConstantVector.java
index 86ab82b..847bf85 100644
--- a/math/src/main/java/org/apache/mahout/math/ConstantVector.java
+++ b/math/src/main/java/org/apache/mahout/math/ConstantVector.java
@@ -132,6 +132,11 @@ public class ConstantVector extends AbstractVector {
return new DenseVector(size());
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return a new {@link DenseVector} of the requested cardinality
+ */
+ @Override
+ public Vector like(int cardinality) {
+ return new DenseVector(cardinality);
+ }
+
/**
* Set the value at the given index, without checking bounds
*
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/DelegatingVector.java b/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
index a1fd291..0b2e36b 100644
--- a/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
+++ b/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
@@ -310,6 +310,11 @@ public class DelegatingVector implements Vector, LengthCachingVector {
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return a new {@link DelegatingVector} wrapping {@code delegate.like(cardinality)}
+ */
@Override
+ public Vector like(int cardinality) {
+ return new DelegatingVector(delegate.like(cardinality));
+ }
+
+ @Override
public void setQuick(int index, double value) {
delegate.setQuick(index, value);
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DenseMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/DenseMatrix.java b/math/src/main/java/org/apache/mahout/math/DenseMatrix.java
index 7f52c00..5c1ee12 100644
--- a/math/src/main/java/org/apache/mahout/math/DenseMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/DenseMatrix.java
@@ -17,6 +17,9 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
import java.util.Arrays;
/** Matrix of doubles implemented using a 2-d array */
@@ -175,5 +178,9 @@ public class DenseMatrix extends AbstractMatrix {
}
return new DenseVector(values[row], true);
}
-
+
+ /** @return {@link MatrixFlavor#DENSELIKE} */
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.DENSELIKE;
+ }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java b/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java
index e9cf3f1..7252b9b 100644
--- a/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/DenseSymmetricMatrix.java
@@ -17,6 +17,8 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
/**
* Economy packaging for a dense symmetric in-core matrix.
*/
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DenseVector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/DenseVector.java b/math/src/main/java/org/apache/mahout/math/DenseVector.java
index 5b3dea7..3633e58 100644
--- a/math/src/main/java/org/apache/mahout/math/DenseVector.java
+++ b/math/src/main/java/org/apache/mahout/math/DenseVector.java
@@ -136,6 +136,11 @@ public class DenseVector extends AbstractVector {
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return a new {@link DenseVector} of the requested cardinality
+ */
@Override
+ public Vector like(int cardinality) {
+ return new DenseVector(cardinality);
+ }
+
+ @Override
public void setQuick(int index, double value) {
invalidateCachedLength();
values[index] = value;
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
index 3e20a4a..070fad2 100644
--- a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
@@ -17,6 +17,9 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
import java.util.Iterator;
import java.util.NoSuchElementException;
@@ -223,6 +226,11 @@ public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps {
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return a new {@link DenseVector} of the requested cardinality
+ */
@Override
+ public Vector like(int cardinality) {
+ return new DenseVector(cardinality);
+ }
+
+ @Override
public void setQuick(int index, double value) {
if (index == this.index) {
diagonal.set(this.index, value);
@@ -361,4 +369,10 @@ public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps {
}
return m;
}
+
+ /** @return {@link MatrixFlavor#DIAGONALLIKE} */
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.DIAGONALLIKE;
+ }
+
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java b/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java
index ba09aa8..56600cd 100644
--- a/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/FileBasedSparseBinaryMatrix.java
@@ -437,6 +437,11 @@ public final class FileBasedSparseBinaryMatrix extends AbstractMatrix {
return new RandomAccessSparseVector(size());
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return a new {@link RandomAccessSparseVector} of the requested cardinality
+ */
+ @Override
+ public Vector like(int cardinality) {
+ return new RandomAccessSparseVector(cardinality);
+ }
+
/**
* Copy the vector for fast operations.
*
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java b/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java
index 2a13611..9028e23 100644
--- a/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java
+++ b/math/src/main/java/org/apache/mahout/math/FunctionalMatrixView.java
@@ -17,6 +17,9 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
import org.apache.mahout.math.function.IntIntFunction;
/**
@@ -29,6 +32,7 @@ class FunctionalMatrixView extends AbstractMatrix {
*/
private IntIntFunction gf;
private boolean denseLike;
+ private MatrixFlavor flavor;
public FunctionalMatrixView(int rows, int columns, IntIntFunction gf) {
this(rows, columns, gf, false);
@@ -42,6 +46,7 @@ class FunctionalMatrixView extends AbstractMatrix {
super(rows, columns);
this.gf = gf;
this.denseLike = denseLike;
+ flavor = new MatrixFlavor.FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.BLOCKIFIED, denseLike);
}
@Override
@@ -87,4 +92,8 @@ class FunctionalMatrixView extends AbstractMatrix {
return new MatrixVectorView(this, 0, column, 1, 0, denseLike);
}
+ /** @return the flavor instance created in the constructor from this view's {@code denseLike} flag */
+ @Override
+ public MatrixFlavor getFlavor() {
+ return flavor;
+ }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/Matrices.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/Matrices.java b/math/src/main/java/org/apache/mahout/math/Matrices.java
index 4a0c50c..fc45a16 100644
--- a/math/src/main/java/org/apache/mahout/math/Matrices.java
+++ b/math/src/main/java/org/apache/mahout/math/Matrices.java
@@ -17,7 +17,9 @@
package org.apache.mahout.math;
+import com.google.common.base.Preconditions;
import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.IntIntFunction;
@@ -63,16 +65,14 @@ public final class Matrices {
* @return transposed view of original matrix
*/
public static final Matrix transposedView(final Matrix m) {
- IntIntFunction tf = new IntIntFunction() {
- @Override
- public double apply(int row, int col) {
- return m.getQuick(col, row);
- }
- };
- // TODO: Matrix api does not support denseLike() interrogation.
- // so our guess has to be rough here.
- return functionalMatrixView(m.numCols(), m.numRows(), tf, m instanceof DenseMatrix);
+ Preconditions.checkArgument(!(m instanceof SparseColumnMatrix));
+
+ if (m instanceof TransposedMatrixView) {
+ return ((TransposedMatrixView) m).getDelegate();
+ } else {
+ return new TransposedMatrixView(m);
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/Matrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/Matrix.java b/math/src/main/java/org/apache/mahout/math/Matrix.java
index afdbac5..47ba5cf 100644
--- a/math/src/main/java/org/apache/mahout/math/Matrix.java
+++ b/math/src/main/java/org/apache/mahout/math/Matrix.java
@@ -17,6 +17,7 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.MatrixFlavor;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.VectorFunction;
@@ -403,4 +404,10 @@ public interface Matrix extends Cloneable, VectorIterable {
* @return A vector that shares storage with the original matrix.
*/
Vector viewDiagonal();
+
+ /**
+ * Get the matrix structural flavor (operation performance hints). This is an optional operation
+ * and may throw {@link java.lang.UnsupportedOperationException}.
+ */
+ MatrixFlavor getFlavor();
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java b/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
index 074d7a6..52ae722 100644
--- a/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
+++ b/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
@@ -211,6 +211,11 @@ public class MatrixVectorView extends AbstractVector {
return matrix.like(size(), 1).viewColumn(0);
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return column 0 of a new {@code cardinality x 1} matrix obtained from {@code matrix.like(cardinality, 1)}
+ */
+ @Override
+ public Vector like(int cardinality) {
+ return matrix.like(cardinality, 1).viewColumn(0);
+ }
+
/**
* Set the value at the given index, without checking bounds
*
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/MatrixView.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/MatrixView.java b/math/src/main/java/org/apache/mahout/math/MatrixView.java
index e2f7f48..86760d5 100644
--- a/math/src/main/java/org/apache/mahout/math/MatrixView.java
+++ b/math/src/main/java/org/apache/mahout/math/MatrixView.java
@@ -17,6 +17,8 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+
/** Implements subset view of a Matrix */
public class MatrixView extends AbstractMatrix {
@@ -151,4 +153,8 @@ public class MatrixView extends AbstractMatrix {
return new VectorView(matrix.viewRow(row + offset[ROW]), offset[COL], columnSize());
}
+ /** @return the flavor reported by the viewed {@code matrix} */
+ @Override
+ public MatrixFlavor getFlavor() {
+ return matrix.getFlavor();
+ }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/NamedVector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/NamedVector.java b/math/src/main/java/org/apache/mahout/math/NamedVector.java
index 0bf49c8..d4fa609 100644
--- a/math/src/main/java/org/apache/mahout/math/NamedVector.java
+++ b/math/src/main/java/org/apache/mahout/math/NamedVector.java
@@ -177,6 +177,11 @@ public class NamedVector implements Vector {
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return a new {@link NamedVector} wrapping {@code delegate.like(cardinality)} and carrying this vector's name
+ */
@Override
+ public Vector like(int cardinality) {
+ return new NamedVector(delegate.like(cardinality), name);
+ }
+
+ @Override
public Vector minus(Vector x) {
return delegate.minus(x);
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java b/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java
index f34f2b0..a76f78c 100644
--- a/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java
+++ b/math/src/main/java/org/apache/mahout/math/PermutedVectorView.java
@@ -204,6 +204,11 @@ public class PermutedVectorView extends AbstractVector {
return vector.like();
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return the result of {@code vector.like(cardinality)} on the underlying vector
+ */
+ @Override
+ public Vector like(int cardinality) {
+ return vector.like(cardinality);
+ }
+
/**
* Set the value at the given index, without checking bounds
*
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
index dbe5d3a..3efac7e 100644
--- a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
+++ b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
@@ -142,6 +142,11 @@ public class RandomAccessSparseVector extends AbstractVector {
}
+ /**
+ * Creates a new {@link RandomAccessSparseVector} of the requested cardinality, passing the current
+ * {@code values.size()} as the second constructor argument (presumably the initial capacity -- confirm).
+ *
+ * @param cardinality size of the vector to create
+ */
@Override
+ public Vector like(int cardinality) {
+ return new RandomAccessSparseVector(cardinality, values.size());
+ }
+
+ @Override
public int getNumNondefaultElements() {
return values.size();
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java b/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
index 331662c..f7d67a7 100644
--- a/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
+++ b/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
@@ -180,6 +180,11 @@ public class SequentialAccessSparseVector extends AbstractVector {
}
+ /**
+ * @param cardinality size of the vector to create
+ * @return a new {@link SequentialAccessSparseVector} of the requested cardinality
+ */
@Override
+ public Vector like(int cardinality) {
+ return new SequentialAccessSparseVector(cardinality);
+ }
+
+ @Override
public int getNumNondefaultElements() {
return values.getNumMappings();
}
@@ -214,6 +219,8 @@ public class SequentialAccessSparseVector extends AbstractVector {
@Override
public Iterator<Element> iterateNonZero() {
+
+ // TODO: this is a bug, since NonDefaultIterator does not honor the non-zero contract.
return new NonDefaultIterator();
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java
index f62d553..eeffc78 100644
--- a/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/SparseColumnMatrix.java
@@ -17,9 +17,13 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
/**
* sparse matrix with general element values whose columns are accessible quickly. Implemented as a column array of
* SparseVectors.
+ *
+ * @deprecated tons of inconsistencies. Use a transposed view of SparseRowMatrix for fast column-wise iteration.
*/
public class SparseColumnMatrix extends AbstractMatrix {
@@ -31,11 +35,19 @@ public class SparseColumnMatrix extends AbstractMatrix {
* @param columns a RandomAccessSparseVector[] array of columns
* @param columnVectors
*/
- public SparseColumnMatrix(int rows, int columns, RandomAccessSparseVector[] columnVectors) {
+ public SparseColumnMatrix(int rows, int columns, Vector[] columnVectors) {
+ this(rows, columns, columnVectors, false);
+ }
+
+ public SparseColumnMatrix(int rows, int columns, Vector[] columnVectors, boolean shallow) {
super(rows, columns);
- this.columnVectors = columnVectors.clone();
- for (int col = 0; col < columnSize(); col++) {
- this.columnVectors[col] = this.columnVectors[col].clone();
+ if (shallow) {
+ this.columnVectors = columnVectors;
+ } else {
+ this.columnVectors = columnVectors.clone();
+ for (int col = 0; col < columnSize(); col++) {
+ this.columnVectors[col] = this.columnVectors[col].clone();
+ }
}
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/SparseMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/SparseMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseMatrix.java
index 88e15a0..bf4f1a0 100644
--- a/math/src/main/java/org/apache/mahout/math/SparseMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/SparseMatrix.java
@@ -18,6 +18,8 @@
package org.apache.mahout.math;
import com.google.common.collect.AbstractIterator;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.IntObjectProcedure;
@@ -40,11 +42,23 @@ public class SparseMatrix extends AbstractMatrix {
* @param columns
* @param rowVectors
*/
- public SparseMatrix(int rows, int columns, Map<Integer, RandomAccessSparseVector> rowVectors) {
+ public SparseMatrix(int rows, int columns, Map<Integer, Vector> rowVectors) {
+ this(rows, columns, rowVectors, false);
+ }
+
+ public SparseMatrix(int rows, int columns, Map<Integer, Vector> rowVectors, boolean shallow) {
+
+ // Why is this passed in as a map? Iterating it is pretty inefficient compared to a simple list...
super(rows, columns);
this.rowVectors = new OpenIntObjectHashMap<Vector>();
- for (Map.Entry<Integer, RandomAccessSparseVector> entry : rowVectors.entrySet()) {
- this.rowVectors.put(entry.getKey(), entry.getValue().clone());
+ if (shallow) {
+ for (Map.Entry<Integer, Vector> entry : rowVectors.entrySet()) {
+ this.rowVectors.put(entry.getKey(), entry.getValue());
+ }
+ } else {
+ for (Map.Entry<Integer, Vector> entry : rowVectors.entrySet()) {
+ this.rowVectors.put(entry.getKey(), entry.getValue().clone());
+ }
}
}
@@ -66,7 +80,11 @@ public class SparseMatrix extends AbstractMatrix {
}
@Override
- public Iterator<MatrixSlice> iterator() {
+ public int numSlices() {
+ return rowVectors.size();
+ }
+
+ public Iterator<MatrixSlice> iterateNonEmpty() {
final IntArrayList keys = new IntArrayList(rowVectors.size());
rowVectors.keys(keys);
return new AbstractIterator<MatrixSlice>() {
@@ -221,4 +239,8 @@ public class SparseMatrix extends AbstractMatrix {
return rowVectors.keys();
}
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.SPARSEROWLIKE;
+ }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
index 3021f3b..6e06769 100644
--- a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
@@ -17,6 +17,8 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
import org.apache.mahout.math.function.Functions;
/**
@@ -226,4 +228,9 @@ public class SparseRowMatrix extends AbstractMatrix {
}
}
}
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return MatrixFlavor.SPARSELIKE;
+ }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/TransposedMatrixView.java b/math/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
new file mode 100644
index 0000000..c67cb47
--- /dev/null
+++ b/math/src/main/java/org/apache/mahout/math/TransposedMatrixView.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Transposed view of a backing {@link Matrix}: element (row, column) of the view maps to (column, row) of the backing matrix.
+ */
+class TransposedMatrixView extends AbstractMatrix {
+
+ private Matrix m;
+
+ public TransposedMatrixView(Matrix m) {
+ super(m.numCols(), m.numRows());
+ this.m = m;
+ }
+
+ @Override
+ public Matrix assignColumn(int column, Vector other) {
+ m.assignRow(column,other);
+ return this;
+ }
+
+ @Override
+ public Matrix assignRow(int row, Vector other) {
+ m.assignColumn(row,other);
+ return this;
+ }
+
+ @Override
+ public double getQuick(int row, int column) {
+ return m.getQuick(column,row);
+ }
+
+ @Override
+ public Matrix like() {
+ return m.like(rows, columns);
+ }
+
+ @Override
+ public Matrix like(int rows, int columns) {
+ return m.like(rows,columns);
+ }
+
+ @Override
+ public void setQuick(int row, int column, double value) {
+ m.setQuick(column, row, value);
+ }
+
+ @Override
+ public Vector viewRow(int row) {
+ return m.viewColumn(row);
+ }
+
+ @Override
+ public Vector viewColumn(int column) {
+ return m.viewRow(column);
+ }
+
+ @Override
+ public Matrix assign(double value) {
+ return m.assign(value);
+ }
+
+ @Override
+ public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+ if (other instanceof TransposedMatrixView) {
+ m.assign(((TransposedMatrixView) other).m, function);
+ } else {
+ m.assign(new TransposedMatrixView(other), function);
+ }
+ return this;
+ }
+
+ @Override
+ public Matrix assign(Matrix other) {
+ if (other instanceof TransposedMatrixView) {
+ return m.assign(((TransposedMatrixView) other).m);
+ } else {
+ return m.assign(new TransposedMatrixView(other));
+ }
+ }
+
+ @Override
+ public Matrix assign(DoubleFunction function) {
+ return m.assign(function);
+ }
+
+ @Override
+ public MatrixFlavor getFlavor() {
+ return flavor;
+ }
+
+ private MatrixFlavor flavor = new MatrixFlavor() {
+ @Override
+ public BackEnum getBacking() {
+ return m.getFlavor().getBacking();
+ }
+
+ @Override
+ public TraversingStructureEnum getStructure() {
+ TraversingStructureEnum flavor = m.getFlavor().getStructure();
+ switch (flavor) {
+ case COLWISE:
+ return TraversingStructureEnum.ROWWISE;
+ case SPARSECOLWISE:
+ return TraversingStructureEnum.SPARSEROWWISE;
+ case ROWWISE:
+ return TraversingStructureEnum.COLWISE;
+ case SPARSEROWWISE:
+ return TraversingStructureEnum.SPARSECOLWISE;
+ default:
+ return flavor;
+ }
+ }
+
+ @Override
+ public boolean isDense() {
+ return m.getFlavor().isDense();
+ }
+ };
+
+ Matrix getDelegate() {
+ return m;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/UpperTriangular.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/UpperTriangular.java b/math/src/main/java/org/apache/mahout/math/UpperTriangular.java
index a0cb3cd..29fa6a0 100644
--- a/math/src/main/java/org/apache/mahout/math/UpperTriangular.java
+++ b/math/src/main/java/org/apache/mahout/math/UpperTriangular.java
@@ -17,6 +17,10 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.flavor.BackEnum;
+import org.apache.mahout.math.flavor.MatrixFlavor;
+import org.apache.mahout.math.flavor.TraversingStructureEnum;
+
/**
*
* Quick and dirty implementation of some {@link org.apache.mahout.math.Matrix} methods
@@ -148,4 +152,9 @@ public class UpperTriangular extends AbstractMatrix {
return values;
}
+ @Override
+ public MatrixFlavor getFlavor() {
+ // We kind of consider ourselves a vector-backed but dense matrix for mmul, etc. purposes.
+ return new MatrixFlavor.FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.VECTORBACKED, true);
+ }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/Vector.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/Vector.java b/math/src/main/java/org/apache/mahout/math/Vector.java
index 0d1a003..4480b0a 100644
--- a/math/src/main/java/org/apache/mahout/math/Vector.java
+++ b/math/src/main/java/org/apache/mahout/math/Vector.java
@@ -190,6 +190,14 @@ public interface Vector extends Cloneable {
Vector like();
/**
+ * Return a new empty vector of the same underlying class as the receiver with given cardinality
+ *
+ * @param cardinality
+ * @return
+ */
+ Vector like(int cardinality);
+
+ /**
* Return a new vector containing the element by element difference of the recipient and the argument
*
* @param x a Vector
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/VectorIterable.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/VectorIterable.java b/math/src/main/java/org/apache/mahout/math/VectorIterable.java
index 451c589..8414fdb 100644
--- a/math/src/main/java/org/apache/mahout/math/VectorIterable.java
+++ b/math/src/main/java/org/apache/mahout/math/VectorIterable.java
@@ -21,8 +21,12 @@ import java.util.Iterator;
public interface VectorIterable extends Iterable<MatrixSlice> {
+ /* Iterate all rows in order */
Iterator<MatrixSlice> iterateAll();
+ /* Iterate all non-empty rows in arbitrary order */
+ Iterator<MatrixSlice> iterateNonEmpty();
+
int numSlices();
int numRows();
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/VectorView.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/VectorView.java b/math/src/main/java/org/apache/mahout/math/VectorView.java
index b503712..d61a038 100644
--- a/math/src/main/java/org/apache/mahout/math/VectorView.java
+++ b/math/src/main/java/org/apache/mahout/math/VectorView.java
@@ -69,6 +69,11 @@ public class VectorView extends AbstractVector {
}
@Override
+ public Vector like(int cardinality) {
+ return vector.like(cardinality);
+ }
+
+ @Override
public double getQuick(int index) {
return vector.getQuick(offset + index);
}
@@ -122,7 +127,7 @@ public class VectorView extends AbstractVector {
while (it.hasNext()) {
Element el = it.next();
if (isInView(el.index()) && el.get() != 0) {
- Element decorated = vector.getElement(el.index());
+ Element decorated = el; /* vector.getElement(el.index()); */
return new DecoratorElement(decorated);
}
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/flavor/BackEnum.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/flavor/BackEnum.java b/math/src/main/java/org/apache/mahout/math/flavor/BackEnum.java
new file mode 100644
index 0000000..1782f04
--- /dev/null
+++ b/math/src/main/java/org/apache/mahout/math/flavor/BackEnum.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.flavor;
+
+/**
+ * Matrix backends
+ */
+public enum BackEnum {
+ JVMMEM,
+ NETLIB_BLAS
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java b/math/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java
new file mode 100644
index 0000000..2b5c444
--- /dev/null
+++ b/math/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.flavor;
+
+/** A set of matrix structure properties that I denote as "flavor" (by analogy to quarks) */
+public interface MatrixFlavor {
+
+ /**
+ * Which system backs the matrix -- such as java memory, lapack/atlas, Magma etc.
+ */
+ BackEnum getBacking();
+
+ /**
+ * Structure flavors
+ */
+ TraversingStructureEnum getStructure() ;
+
+ boolean isDense();
+
+ /**
+ * This is the default flavor for {@link org.apache.mahout.math.DenseMatrix}-like structures.
+ */
+ static final MatrixFlavor DENSELIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.ROWWISE, true);
+ /**
+ * This is the default flavor for {@link org.apache.mahout.math.SparseRowMatrix}-like structures.
+ */
+ static final MatrixFlavor SPARSELIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.ROWWISE, false);
+
+ /**
+ * This is the default flavor for {@link org.apache.mahout.math.SparseMatrix}-like structures, i.e. sparse matrix blocks,
+ * where some, perhaps most, rows may be missing entirely.
+ */
+ static final MatrixFlavor SPARSEROWLIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.SPARSEROWWISE, false);
+
+ /**
+ * This is the default flavor for {@link org.apache.mahout.math.DiagonalMatrix} and the like.
+ */
+ static final MatrixFlavor DIAGONALLIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.VECTORBACKED, false);
+
+ static final class FlavorImpl implements MatrixFlavor {
+ private BackEnum pBacking;
+ private TraversingStructureEnum pStructure;
+ private boolean pDense;
+
+ public FlavorImpl(BackEnum backing, TraversingStructureEnum structure, boolean dense) {
+ pBacking = backing;
+ pStructure = structure;
+ pDense = dense;
+ }
+
+ @Override
+ public BackEnum getBacking() {
+ return pBacking;
+ }
+
+ @Override
+ public TraversingStructureEnum getStructure() {
+ return pStructure;
+ }
+
+ @Override
+ public boolean isDense() {
+ return pDense;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java b/math/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java
new file mode 100644
index 0000000..13c2cf4
--- /dev/null
+++ b/math/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.flavor;
+
+/** STRUCTURE HINT */
+public enum TraversingStructureEnum {
+
+ UNKNOWN,
+
+ /**
+ * Backing vectors are directly available as row views.
+ */
+ ROWWISE,
+
+ /**
+ * Column vectors are directly available as column views.
+ */
+ COLWISE,
+
+ /**
+ * Only some row-wise vectors are really present (can use iterateNonEmpty). Corresponds to
+ * {@link org.apache.mahout.math.SparseMatrix}.
+ */
+ SPARSEROWWISE,
+
+ SPARSECOLWISE,
+
+ SPARSEHASH,
+
+ VECTORBACKED,
+
+ BLOCKIFIED
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/math/src/test/java/org/apache/mahout/math/MatricesTest.java
----------------------------------------------------------------------
diff --git a/math/src/test/java/org/apache/mahout/math/MatricesTest.java b/math/src/test/java/org/apache/mahout/math/MatricesTest.java
index 1b6169e..9405429 100644
--- a/math/src/test/java/org/apache/mahout/math/MatricesTest.java
+++ b/math/src/test/java/org/apache/mahout/math/MatricesTest.java
@@ -65,8 +65,8 @@ public class MatricesTest extends MahoutTestCase {
m.set(1, 1, 33.0);
Matrix mt = Matrices.transposedView(m);
- assertTrue(!mt.viewColumn(0).isDense());
- assertTrue(!mt.viewRow(0).isDense());
+ assertTrue(mt.viewColumn(0).isDense() == m.viewRow(0).isDense());
+ assertTrue(mt.viewRow(0).isDense() == m.viewColumn(0).isDense());
m = new DenseMatrix(10,10);
m.set(1, 1, 33.0);
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java b/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
index 1a6ff16..de5e216 100644
--- a/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
@@ -133,6 +133,11 @@ public class DistributedRowMatrix implements VectorIterable, Configurable {
}
@Override
+ public Iterator<MatrixSlice> iterateNonEmpty() {
+ return iterator();
+ }
+
+ @Override
public Iterator<MatrixSlice> iterateAll() {
try {
Path pathPattern = rowPath;
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
index 7033efe..af79cb4 100644
--- a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
@@ -586,6 +586,11 @@ public class GivensThinSolver {
}
@Override
+ public Vector like(int cardinality) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
public void setQuick(int index, double value) {
viewed.setQuick(rowNum, index, value);
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
----------------------------------------------------------------------
diff --git a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala b/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
index 5ffc18c..4d0615a 100644
--- a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
+++ b/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
@@ -12,13 +12,14 @@ class MahoutSparkILoop extends SparkILoop {
private val postInitScript =
"import org.apache.mahout.math._" ::
- "import scalabindings._" ::
- "import RLikeOps._" ::
- "import drm._" ::
- "import RLikeDrmOps._" ::
- "import org.apache.mahout.sparkbindings._" ::
- "import collection.JavaConversions._" ::
- Nil
+ "import scalabindings._" ::
+ "import RLikeOps._" ::
+ "import drm._" ::
+ "import RLikeDrmOps._" ::
+ "import decompositions._" ::
+ "import org.apache.mahout.sparkbindings._" ::
+ "import collection.JavaConversions._" ::
+ Nil
override protected def postInitialization() {
super.postInitialization()
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/pom.xml
----------------------------------------------------------------------
diff --git a/spark/pom.xml b/spark/pom.xml
index 33e0d1b..7155115 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -119,6 +119,22 @@
</executions>
</plugin>
+ <!-- create test jar so other modules can reuse the math test utility classes.
+ DO NOT REMOVE! Testing framework is useful in subordinate/contrib projects!
+ Please contact @dlyubimov.
+ -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <executions>
+ <execution>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ <phase>package</phase>
+ </execution>
+ </executions>
+ </plugin>
</plugins>
</build>
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala b/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala
index 5bbccb1..0aba319 100644
--- a/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala
+++ b/spark/src/main/scala/org/apache/mahout/common/DrmMetadata.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.mahout.common
import scala.reflect.ClassTag
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala b/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala
index f5f87d7..c949f92 100644
--- a/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala
+++ b/spark/src/main/scala/org/apache/mahout/common/HDFSUtil.scala
@@ -17,10 +17,12 @@
package org.apache.mahout.common
+import org.apache.spark.SparkContext
+
/** High level Hadoop version-specific hdfs manipulations we need in context of our operations. */
trait HDFSUtil {
/** Read DRM header information off (H)DFS. */
- def readDrmHeader(path:String):DrmMetadata
+ def readDrmHeader(path:String)(implicit sc:SparkContext):DrmMetadata
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/8a6b805a/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala b/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala
index 047104a..399508d 100644
--- a/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala
+++ b/spark/src/main/scala/org/apache/mahout/common/Hadoop1HDFSUtil.scala
@@ -17,10 +17,10 @@
package org.apache.mahout.common
-
import org.apache.hadoop.io.{Writable, SequenceFile}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.SparkContext
import collection._
import JavaConversions._
@@ -30,14 +30,16 @@ import JavaConversions._
*/
object Hadoop1HDFSUtil extends HDFSUtil {
- /**
- * Read the header of a sequence file and determine the Key and Value type
- * @param path
- * @return
- */
- def readDrmHeader(path: String): DrmMetadata = {
+
+ /** Read DRM header information off (H)DFS. */
+ override def readDrmHeader(path: String)(implicit sc: SparkContext): DrmMetadata = {
+
val dfsPath = new Path(path)
- val fs = dfsPath.getFileSystem(new Configuration())
+
+ val fs = dfsPath.getFileSystem(sc.hadoopConfiguration)
+
+ // Apparently getFileSystem() doesn't set conf??
+ fs.setConf(sc.hadoopConfiguration)
val partFilePath:Path = fs.listStatus(dfsPath)