You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/02/01 23:13:36 UTC

spark git commit: [SPARK-5424][MLLIB] make the new ALS impl take generic ID types

Repository: spark
Updated Branches:
  refs/heads/master bdb0680d3 -> 4a171225b


[SPARK-5424][MLLIB] make the new ALS impl take generic ID types

This PR makes the ALS implementation take generic ID types, e.g., Long and String, and exposes it as a developer API.

TODO:
- [x] make sure that specialization works (validated in profiler)

srowen You may like this change:) I hit a Scala compiler bug with specialization. It compiles now but users and items must have the same type. I'm going to check whether specialization really works.

Author: Xiangrui Meng <me...@databricks.com>

Closes #4281 from mengxr/generic-als and squashes the following commits:

96072c3 [Xiangrui Meng] merge master
135f741 [Xiangrui Meng] minor update
c2db5e5 [Xiangrui Meng] make test pass
86588e1 [Xiangrui Meng] use a single ID type for both users and items
74f1f73 [Xiangrui Meng] compile but runtime error at test
e36469a [Xiangrui Meng] add classtags and make it compile
7a5aeb3 [Xiangrui Meng] UserType -> User, ItemType -> Item
c8ee0bc [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into generic-als
72b5006 [Xiangrui Meng] remove generic from pipeline interface
8bbaea0 [Xiangrui Meng] make ALS take generic IDs


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4a171225
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4a171225
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4a171225

Branch: refs/heads/master
Commit: 4a171225ba628192a5ae43a99dc50508cf12491c
Parents: bdb0680
Author: Xiangrui Meng <me...@databricks.com>
Authored: Sun Feb 1 14:13:31 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun Feb 1 14:13:31 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/recommendation/ALS.scala    | 213 +++++++++++--------
 .../spark/ml/recommendation/ALSSuite.scala      |  36 ++--
 2 files changed, 146 insertions(+), 103 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4a171225/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index aaad548..979a19d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -20,16 +20,19 @@ package org.apache.spark.ml.recommendation
 import java.{util => ju}
 
 import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.util.Sorting
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
 import org.netlib.util.intW
 
 import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -199,7 +202,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
     val ratings = dataset
       .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
       .map { row =>
-        new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+        Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
     val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
       numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
@@ -215,10 +218,19 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   }
 }
 
-private[recommendation] object ALS extends Logging {
+/**
+ * :: DeveloperApi ::
+ * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is
+ * exposed as a developer API for users who do need other ID types. But it is not recommended
+ * because it increases the shuffle size and memory requirement during training. For simplicity,
+ * users and items must have the same type. The number of distinct users/items should be smaller
+ * than 2 billion.
+ */
+@DeveloperApi
+object ALS extends Logging {
 
   /** Rating class for better code readability. */
-  private[recommendation] case class Rating(user: Int, item: Int, rating: Float)
+  case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
 
   /** Cholesky solver for least square problems. */
   private[recommendation] class CholeskySolver {
@@ -285,7 +297,7 @@ private[recommendation] object ALS extends Logging {
 
     /** Adds an observation. */
     def add(a: Array[Float], b: Float): this.type = {
-      require(a.size == k)
+      require(a.length == k)
       copyToDouble(a)
       blas.dspr(upper, k, 1.0, da, 1, ata)
       blas.daxpy(k, b.toDouble, da, 1, atb, 1)
@@ -297,7 +309,7 @@ private[recommendation] object ALS extends Logging {
      * Adds an observation with implicit feedback. Note that this does not increment the counter.
      */
     def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
-      require(a.size == k)
+      require(a.length == k)
       // Extension to the original paper to handle b < 0. confidence is a function of |b| instead
       // so that it is never negative.
       val confidence = 1.0 + alpha * math.abs(b)
@@ -313,8 +325,8 @@ private[recommendation] object ALS extends Logging {
     /** Merges another normal equation object. */
     def merge(other: NormalEquation): this.type = {
       require(other.k == k)
-      blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1)
-      blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1)
+      blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
+      blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
       n += other.n
       this
     }
@@ -330,15 +342,16 @@ private[recommendation] object ALS extends Logging {
   /**
    * Implementation of the ALS algorithm.
    */
-  private def train(
-      ratings: RDD[Rating],
+  def train[ID: ClassTag](
+      ratings: RDD[Rating[ID]],
       rank: Int = 10,
       numUserBlocks: Int = 10,
       numItemBlocks: Int = 10,
       maxIter: Int = 10,
       regParam: Double = 1.0,
       implicitPrefs: Boolean = false,
-      alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = {
+      alpha: Double = 1.0)(
+      implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
     val userPart = new HashPartitioner(numUserBlocks)
     val itemPart = new HashPartitioner(numItemBlocks)
     val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
@@ -441,16 +454,15 @@ private[recommendation] object ALS extends Logging {
    *
    * @see [[LocalIndexEncoder]]
    */
-  private[recommendation] case class InBlock(
-      srcIds: Array[Int],
+  private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
+      srcIds: Array[ID],
       dstPtrs: Array[Int],
       dstEncodedIndices: Array[Int],
       ratings: Array[Float]) {
     /** Size of the block. */
-    val size: Int = ratings.size
-
-    require(dstEncodedIndices.size == size)
-    require(dstPtrs.size == srcIds.size + 1)
+    def size: Int = ratings.length
+    require(dstEncodedIndices.length == size)
+    require(dstPtrs.length == srcIds.length + 1)
   }
 
   /**
@@ -460,7 +472,9 @@ private[recommendation] object ALS extends Logging {
    * @param rank rank
    * @return initialized factor blocks
    */
-  private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = {
+  private def initialize[ID](
+      inBlocks: RDD[(Int, InBlock[ID])],
+      rank: Int): RDD[(Int, FactorBlock)] = {
     // Choose a unit vector uniformly at random from the unit sphere, but from the
     // "first quadrant" where all elements are nonnegative. This can be done by choosing
     // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
@@ -468,7 +482,7 @@ private[recommendation] object ALS extends Logging {
     // (<1%) compared picking elements uniformly at random in [0,1].
     inBlocks.map { case (srcBlockId, inBlock) =>
       val random = new XORShiftRandom(srcBlockId)
-      val factors = Array.fill(inBlock.srcIds.size) {
+      val factors = Array.fill(inBlock.srcIds.length) {
         val factor = Array.fill(rank)(random.nextGaussian().toFloat)
         val nrm = blas.snrm2(rank, factor, 1)
         blas.sscal(rank, 1.0f / nrm, factor, 1)
@@ -481,26 +495,29 @@ private[recommendation] object ALS extends Logging {
   /**
    * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
    */
-  private[recommendation]
-  case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) {
+  private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag](
+      srcIds: Array[ID],
+      dstIds: Array[ID],
+      ratings: Array[Float]) {
     /** Size of the block. */
-    val size: Int = srcIds.size
-    require(dstIds.size == size)
-    require(ratings.size == size)
+    def size: Int = srcIds.length
+    require(dstIds.length == srcIds.length)
+    require(ratings.length == srcIds.length)
   }
 
   /**
    * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing.
    */
-  private[recommendation] class RatingBlockBuilder extends Serializable {
+  private[recommendation] class RatingBlockBuilder[@specialized(Int, Long) ID: ClassTag]
+    extends Serializable {
 
-    private val srcIds = mutable.ArrayBuilder.make[Int]
-    private val dstIds = mutable.ArrayBuilder.make[Int]
+    private val srcIds = mutable.ArrayBuilder.make[ID]
+    private val dstIds = mutable.ArrayBuilder.make[ID]
     private val ratings = mutable.ArrayBuilder.make[Float]
     var size = 0
 
     /** Adds a rating. */
-    def add(r: Rating): this.type = {
+    def add(r: Rating[ID]): this.type = {
       size += 1
       srcIds += r.user
       dstIds += r.item
@@ -509,8 +526,8 @@ private[recommendation] object ALS extends Logging {
     }
 
     /** Merges another [[RatingBlockBuilder]]. */
-    def merge(other: RatingBlock): this.type = {
-      size += other.srcIds.size
+    def merge(other: RatingBlock[ID]): this.type = {
+      size += other.srcIds.length
       srcIds ++= other.srcIds
       dstIds ++= other.dstIds
       ratings ++= other.ratings
@@ -518,8 +535,8 @@ private[recommendation] object ALS extends Logging {
     }
 
     /** Builds a [[RatingBlock]]. */
-    def build(): RatingBlock = {
-      RatingBlock(srcIds.result(), dstIds.result(), ratings.result())
+    def build(): RatingBlock[ID] = {
+      RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result())
     }
   }
 
@@ -532,10 +549,10 @@ private[recommendation] object ALS extends Logging {
    *
    * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
    */
-  private def partitionRatings(
-      ratings: RDD[Rating],
+  private def partitionRatings[ID: ClassTag](
+      ratings: RDD[Rating[ID]],
       srcPart: Partitioner,
-      dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = {
+      dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = {
 
      /* The implementation produces the same result as the following but generates less objects.
 
@@ -549,7 +566,7 @@ private[recommendation] object ALS extends Logging {
 
     val numPartitions = srcPart.numPartitions * dstPart.numPartitions
     ratings.mapPartitions { iter =>
-      val builders = Array.fill(numPartitions)(new RatingBlockBuilder)
+      val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID])
       iter.flatMap { r =>
         val srcBlockId = srcPart.getPartition(r.user)
         val dstBlockId = dstPart.getPartition(r.item)
@@ -570,7 +587,7 @@ private[recommendation] object ALS extends Logging {
         }
       }
     }.groupByKey().mapValues { blocks =>
-      val builder = new RatingBlockBuilder
+      val builder = new RatingBlockBuilder[ID]
       blocks.foreach(builder.merge)
       builder.build()
     }.setName("ratingBlocks")
@@ -580,9 +597,11 @@ private[recommendation] object ALS extends Logging {
    * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
    * @param encoder encoder for dst indices
    */
-  private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) {
+  private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
+      encoder: LocalIndexEncoder)(
+      implicit ord: Ordering[ID]) {
 
-    private val srcIds = mutable.ArrayBuilder.make[Int]
+    private val srcIds = mutable.ArrayBuilder.make[ID]
     private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
     private val ratings = mutable.ArrayBuilder.make[Float]
 
@@ -596,12 +615,12 @@ private[recommendation] object ALS extends Logging {
      */
     def add(
         dstBlockId: Int,
-        srcIds: Array[Int],
+        srcIds: Array[ID],
         dstLocalIndices: Array[Int],
         ratings: Array[Float]): this.type = {
-      val sz = srcIds.size
-      require(dstLocalIndices.size == sz)
-      require(ratings.size == sz)
+      val sz = srcIds.length
+      require(dstLocalIndices.length == sz)
+      require(ratings.length == sz)
       this.srcIds ++= srcIds
       this.ratings ++= ratings
       var j = 0
@@ -613,7 +632,7 @@ private[recommendation] object ALS extends Logging {
     }
 
     /** Builds a [[UncompressedInBlock]]. */
-    def build(): UncompressedInBlock = {
+    def build(): UncompressedInBlock[ID] = {
       new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
     }
   }
@@ -621,24 +640,25 @@ private[recommendation] object ALS extends Logging {
   /**
    * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
    */
-  private[recommendation] class UncompressedInBlock(
-      val srcIds: Array[Int],
+  private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag](
+      val srcIds: Array[ID],
       val dstEncodedIndices: Array[Int],
-      val ratings: Array[Float]) {
+      val ratings: Array[Float])(
+      implicit ord: Ordering[ID]) {
 
     /** Size the of block. */
-    def size: Int = srcIds.size
+    def length: Int = srcIds.length
 
     /**
      * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a
      * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
      * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
      */
-    def compress(): InBlock = {
-      val sz = size
+    def compress(): InBlock[ID] = {
+      val sz = length
       assert(sz > 0, "Empty in-link block should not exist.")
       sort()
-      val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int]
+      val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID]
       val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
       var preSrcId = srcIds(0)
       uniqueSrcIdsBuilder += preSrcId
@@ -659,7 +679,7 @@ private[recommendation] object ALS extends Logging {
       }
       dstCountsBuilder += curCount
       val uniqueSrcIds = uniqueSrcIdsBuilder.result()
-      val numUniqueSrdIds = uniqueSrcIds.size
+      val numUniqueSrdIds = uniqueSrcIds.length
       val dstCounts = dstCountsBuilder.result()
       val dstPtrs = new Array[Int](numUniqueSrdIds + 1)
       var sum = 0
@@ -673,51 +693,61 @@ private[recommendation] object ALS extends Logging {
     }
 
     private def sort(): Unit = {
-      val sz = size
+      val sz = length
       // Since there might be interleaved log messages, we insert a unique id for easy pairing.
       val sortId = Utils.random.nextInt()
       logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
       val start = System.nanoTime()
-      val sorter = new Sorter(new UncompressedInBlockSort)
-      sorter.sort(this, 0, size, Ordering[IntWrapper])
+      val sorter = new Sorter(new UncompressedInBlockSort[ID])
+      sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]])
       val duration = (System.nanoTime() - start) / 1e9
       logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
     }
   }
 
   /**
-   * A wrapper that holds a primitive integer key.
+   * A wrapper that holds a primitive key.
    *
    * @see [[UncompressedInBlockSort]]
    */
-  private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] {
-    override def compare(that: IntWrapper): Int = {
-      key.compare(that.key)
+  private class KeyWrapper[@specialized(Int, Long) ID: ClassTag](
+      implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] {
+
+    var key: ID = _
+
+    override def compare(that: KeyWrapper[ID]): Int = {
+      ord.compare(key, that.key)
+    }
+
+    def setKey(key: ID): this.type = {
+      this.key = key
+      this
     }
   }
 
   /**
    * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
    */
-  private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] {
+  private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag](
+      implicit ord: Ordering[ID])
+    extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] {
 
-    override def newKey(): IntWrapper = new IntWrapper()
+    override def newKey(): KeyWrapper[ID] = new KeyWrapper()
 
     override def getKey(
-        data: UncompressedInBlock,
+        data: UncompressedInBlock[ID],
         pos: Int,
-        reuse: IntWrapper): IntWrapper = {
+        reuse: KeyWrapper[ID]): KeyWrapper[ID] = {
       if (reuse == null) {
-        new IntWrapper(data.srcIds(pos))
+        new KeyWrapper().setKey(data.srcIds(pos))
       } else {
-        reuse.key = data.srcIds(pos)
-        reuse
+        reuse.setKey(data.srcIds(pos))
       }
     }
 
     override def getKey(
-        data: UncompressedInBlock,
-        pos: Int): IntWrapper = {
+        data: UncompressedInBlock[ID],
+        pos: Int): KeyWrapper[ID] = {
       getKey(data, pos, null)
     }
 
@@ -730,16 +760,16 @@ private[recommendation] object ALS extends Logging {
       data(pos1) = tmp
     }
 
-    override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = {
+    override def swap(data: UncompressedInBlock[ID], pos0: Int, pos1: Int): Unit = {
       swapElements(data.srcIds, pos0, pos1)
       swapElements(data.dstEncodedIndices, pos0, pos1)
       swapElements(data.ratings, pos0, pos1)
     }
 
     override def copyRange(
-        src: UncompressedInBlock,
+        src: UncompressedInBlock[ID],
         srcPos: Int,
-        dst: UncompressedInBlock,
+        dst: UncompressedInBlock[ID],
         dstPos: Int,
         length: Int): Unit = {
       System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
@@ -747,15 +777,15 @@ private[recommendation] object ALS extends Logging {
       System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
     }
 
-    override def allocate(length: Int): UncompressedInBlock = {
+    override def allocate(length: Int): UncompressedInBlock[ID] = {
       new UncompressedInBlock(
-        new Array[Int](length), new Array[Int](length), new Array[Float](length))
+        new Array[ID](length), new Array[Int](length), new Array[Float](length))
     }
 
     override def copyElement(
-        src: UncompressedInBlock,
+        src: UncompressedInBlock[ID],
         srcPos: Int,
-        dst: UncompressedInBlock,
+        dst: UncompressedInBlock[ID],
         dstPos: Int): Unit = {
       dst.srcIds(dstPos) = src.srcIds(srcPos)
       dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
@@ -771,19 +801,20 @@ private[recommendation] object ALS extends Logging {
    * @param dstPart partitioner for dst IDs
    * @return (in-blocks, out-blocks)
    */
-  private def makeBlocks(
+  private def makeBlocks[ID: ClassTag](
       prefix: String,
-      ratingBlocks: RDD[((Int, Int), RatingBlock)],
+      ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
       srcPart: Partitioner,
-      dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = {
+      dstPart: Partitioner)(
+      implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
     val inBlocks = ratingBlocks.map {
       case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
         // The implementation is a faster version of
         // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
         val start = System.nanoTime()
-        val dstIdSet = new OpenHashSet[Int](1 << 20)
+        val dstIdSet = new OpenHashSet[ID](1 << 20)
         dstIds.foreach(dstIdSet.add)
-        val sortedDstIds = new Array[Int](dstIdSet.size)
+        val sortedDstIds = new Array[ID](dstIdSet.size)
         var i = 0
         var pos = dstIdSet.nextPos(0)
         while (pos != -1) {
@@ -792,10 +823,10 @@ private[recommendation] object ALS extends Logging {
           i += 1
         }
         assert(i == dstIdSet.size)
-        ju.Arrays.sort(sortedDstIds)
-        val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size)
+        Sorting.quickSort(sortedDstIds)
+        val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length)
         i = 0
-        while (i < sortedDstIds.size) {
+        while (i < sortedDstIds.length) {
           dstIdToLocalIndex.update(sortedDstIds(i), i)
           i += 1
         }
@@ -806,7 +837,7 @@ private[recommendation] object ALS extends Logging {
     }.groupByKey(new HashPartitioner(srcPart.numPartitions))
         .mapValues { iter =>
       val builder =
-        new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions))
+        new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions))
       iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
         builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
       }
@@ -817,7 +848,7 @@ private[recommendation] object ALS extends Logging {
       val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
       var i = 0
       val seen = new Array[Boolean](dstPart.numPartitions)
-      while (i < srcIds.size) {
+      while (i < srcIds.length) {
         var j = dstPtrs(i)
         ju.Arrays.fill(seen, false)
         while (j < dstPtrs(i + 1)) {
@@ -851,16 +882,16 @@ private[recommendation] object ALS extends Logging {
    *
    * @return dst factors
    */
-  private def computeFactors(
+  private def computeFactors[ID](
       srcFactorBlocks: RDD[(Int, FactorBlock)],
       srcOutBlocks: RDD[(Int, OutBlock)],
-      dstInBlocks: RDD[(Int, InBlock)],
+      dstInBlocks: RDD[(Int, InBlock[ID])],
       rank: Int,
       regParam: Double,
       srcEncoder: LocalIndexEncoder,
       implicitPrefs: Boolean = false,
       alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
-    val numSrcBlocks = srcFactorBlocks.partitions.size
+    val numSrcBlocks = srcFactorBlocks.partitions.length
     val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
     val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
       case (srcBlockId, (srcOutBlock, srcFactors)) =>
@@ -868,18 +899,18 @@ private[recommendation] object ALS extends Logging {
           (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
         }
     }
-    val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size))
+    val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length))
     dstInBlocks.join(merged).mapValues {
       case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
         val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
         srcFactors.foreach { case (srcBlockId, factors) =>
           sortedSrcFactors(srcBlockId) = factors
         }
-        val dstFactors = new Array[Array[Float]](dstIds.size)
+        val dstFactors = new Array[Array[Float]](dstIds.length)
         var j = 0
         val ls = new NormalEquation(rank)
         val solver = new CholeskySolver // TODO: add NNLS solver
-        while (j < dstIds.size) {
+        while (j < dstIds.length) {
           ls.reset()
           if (implicitPrefs) {
             ls.merge(YtY.get)

http://git-wip-us.apache.org/repos/asf/spark/blob/4a171225/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 9da253c..07aff56 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -155,7 +155,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
   }
 
   test("RatingBlockBuilder") {
-    val emptyBuilder = new RatingBlockBuilder()
+    val emptyBuilder = new RatingBlockBuilder[Int]()
     assert(emptyBuilder.size === 0)
     val emptyBlock = emptyBuilder.build()
     assert(emptyBlock.srcIds.isEmpty)
@@ -179,12 +179,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
 
   test("UncompressedInBlock") {
     val encoder = new LocalIndexEncoder(10)
-    val uncompressed = new UncompressedInBlockBuilder(encoder)
+    val uncompressed = new UncompressedInBlockBuilder[Int](encoder)
       .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f))
       .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f))
       .build()
-    assert(uncompressed.size === 5)
-    val records = Seq.tabulate(uncompressed.size) { i =>
+    assert(uncompressed.length === 5)
+    val records = Seq.tabulate(uncompressed.length) { i =>
       val dstEncodedIndex = uncompressed.dstEncodedIndices(i)
       val dstBlockId = encoder.blockId(dstEncodedIndex)
       val dstLocalIndex = encoder.localIndex(dstEncodedIndex)
@@ -228,15 +228,15 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
       numItems: Int,
       rank: Int,
       noiseStd: Double = 0.0,
-      seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+      seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = {
     val trainingFraction = 0.6
     val testFraction = 0.3
     val totalFraction = trainingFraction + testFraction
     val random = new Random(seed)
     val userFactors = genFactors(numUsers, rank, random)
     val itemFactors = genFactors(numItems, rank, random)
-    val training = ArrayBuffer.empty[Rating]
-    val test = ArrayBuffer.empty[Rating]
+    val training = ArrayBuffer.empty[Rating[Int]]
+    val test = ArrayBuffer.empty[Rating[Int]]
     for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
       val x = random.nextDouble()
       if (x < totalFraction) {
@@ -268,7 +268,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
       numItems: Int,
       rank: Int,
       noiseStd: Double = 0.0,
-      seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+      seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = {
     // The assumption of the implicit feedback model is that unobserved ratings are more likely to
     // be negatives.
     val positiveFraction = 0.8
@@ -279,8 +279,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     val random = new Random(seed)
     val userFactors = genFactors(numUsers, rank, random)
     val itemFactors = genFactors(numItems, rank, random)
-    val training = ArrayBuffer.empty[Rating]
-    val test = ArrayBuffer.empty[Rating]
+    val training = ArrayBuffer.empty[Rating[Int]]
+    val test = ArrayBuffer.empty[Rating[Int]]
     for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
       val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
       val threshold = if (rating > 0) positiveFraction else negativeFraction
@@ -340,8 +340,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
    * @param targetRMSE target test RMSE
    */
   def testALS(
-      training: RDD[Rating],
-      test: RDD[Rating],
+      training: RDD[Rating[Int]],
+      test: RDD[Rating[Int]],
       rank: Int,
       maxIter: Int,
       regParam: Double,
@@ -432,4 +432,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true,
       targetRMSE = 0.3)
   }
+
+  test("using generic ID types") {
+    val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+
+    val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating))
+    val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4)
+    assert(longUserFactors.first()._1.getClass === classOf[Long])
+
+    val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating))
+    val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4)
+    assert(strUserFactors.first()._1.getClass === classOf[String])
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org