You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2015/12/20 10:08:23 UTC

spark git commit: [SPARK-10158][PYSPARK][MLLIB] ALS better error message when using Long IDs

Repository: spark
Updated Branches:
  refs/heads/master 284e29a87 -> ce1798b3a


[SPARK-10158][PYSPARK][MLLIB] ALS better error message when using Long IDs

Added a catch for the Long-to-Int ClassCastException thrown when PySpark ALS Ratings are serialized.  It is easy to accidentally use Long IDs for user/product; previously, doing so failed with a somewhat cryptic "ClassCastException: java.lang.Long cannot be cast to java.lang.Integer."  Now a more descriptive error is shown instead, e.g. "PickleException: Ratings id 1205640308657491975 exceeds max integer value of 2147483647."

Author: Bryan Cutler <bj...@us.ibm.com>

Closes #9361 from BryanCutler/als-pyspark-long-id-error-SPARK-10158.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ce1798b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ce1798b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ce1798b3

Branch: refs/heads/master
Commit: ce1798b3af8de326bf955b51ed955a924b019b4e
Parents: 284e29a
Author: Bryan Cutler <bj...@us.ibm.com>
Authored: Sun Dec 20 09:08:23 2015 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Dec 20 09:08:23 2015 +0000

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala    | 12 +++++++++++-
 python/pyspark/mllib/tests.py                      | 17 +++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ce1798b3/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 29160a1..f6826dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1438,9 +1438,19 @@ private[spark] object SerDe extends Serializable {
       if (args.length != 3) {
         throw new PickleException("should be 3")
       }
-      new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
+      new Rating(ratingsIdCheckLong(args(0)), ratingsIdCheckLong(args(1)),
         args(2).asInstanceOf[Double])
     }
+
+    private def ratingsIdCheckLong(obj: Object): Int = {
+      try {
+        obj.asInstanceOf[Int]
+      } catch {
+        case ex: ClassCastException =>
+          throw new PickleException(s"Ratings id ${obj.toString} exceeds " +
+            s"max integer value of ${Int.MaxValue}", ex)
+      }
+    }
   }
 
   var initialized = false

http://git-wip-us.apache.org/repos/asf/spark/blob/ce1798b3/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f8e8e0e..6ed03e3 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -54,6 +54,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
     DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
 from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
+from pyspark.mllib.recommendation import Rating
 from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
@@ -1539,6 +1540,22 @@ class MLUtilsTests(MLlibTestCase):
             shutil.rmtree(load_vectors_path)
 
 
+class ALSTests(MLlibTestCase):
+
+    def test_als_ratings_serialize(self):
+        r = Rating(7, 1123, 3.14)
+        jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r)))
+        nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr)))
+        self.assertEqual(r.user, nr.user)
+        self.assertEqual(r.product, nr.product)
+        self.assertAlmostEqual(r.rating, nr.rating, 2)
+
+    def test_als_ratings_id_long_error(self):
+        r = Rating(1205640308657491975, 50233468418, 1.0)
+        # rating ids exceed max int value, so SerDe.loads should fail when unpickling them
+        self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+
+
 if __name__ == "__main__":
     if not _have_scipy:
         print("NOTE: Skipping SciPy tests as it does not seem to be installed")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org