You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2015/12/20 10:08:23 UTC
spark git commit: [SPARK-10158][PYSPARK][MLLIB] ALS better error message when using Long IDs
Repository: spark
Updated Branches:
refs/heads/master 284e29a87 -> ce1798b3a
[SPARK-10158][PYSPARK][MLLIB] ALS better error message when using Long IDs
Added a catch for the exception thrown when casting Long to Int while PySpark ALS Ratings are serialized. It is easy to accidentally use Long IDs for user/product; previously, this would fail with a somewhat cryptic "ClassCastException: java.lang.Long cannot be cast to java.lang.Integer." Now, when this happens, a more descriptive error is shown, e.g. "PickleException: Ratings id 1205640308657491975 exceeds max integer value of 2147483647."
Author: Bryan Cutler <bj...@us.ibm.com>
Closes #9361 from BryanCutler/als-pyspark-long-id-error-SPARK-10158.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ce1798b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ce1798b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ce1798b3
Branch: refs/heads/master
Commit: ce1798b3af8de326bf955b51ed955a924b019b4e
Parents: 284e29a
Author: Bryan Cutler <bj...@us.ibm.com>
Authored: Sun Dec 20 09:08:23 2015 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Dec 20 09:08:23 2015 +0000
----------------------------------------------------------------------
.../spark/mllib/api/python/PythonMLLibAPI.scala | 12 +++++++++++-
python/pyspark/mllib/tests.py | 17 +++++++++++++++++
2 files changed, 28 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ce1798b3/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 29160a1..f6826dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1438,9 +1438,19 @@ private[spark] object SerDe extends Serializable {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
+ new Rating(ratingsIdCheckLong(args(0)), ratingsIdCheckLong(args(1)),
args(2).asInstanceOf[Double])
}
+
+ private def ratingsIdCheckLong(obj: Object): Int = {
+ try {
+ obj.asInstanceOf[Int]
+ } catch {
+ case ex: ClassCastException =>
+ throw new PickleException(s"Ratings id ${obj.toString} exceeds " +
+ s"max integer value of ${Int.MaxValue}", ex)
+ }
+ }
}
var initialized = false
http://git-wip-us.apache.org/repos/asf/spark/blob/ce1798b3/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f8e8e0e..6ed03e3 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -54,6 +54,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
+from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -1539,6 +1540,22 @@ class MLUtilsTests(MLlibTestCase):
shutil.rmtree(load_vectors_path)
+class ALSTests(MLlibTestCase):
+
+ def test_als_ratings_serialize(self):
+ r = Rating(7, 1123, 3.14)
+ jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r)))
+ nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr)))
+ self.assertEqual(r.user, nr.user)
+ self.assertEqual(r.product, nr.product)
+ self.assertAlmostEqual(r.rating, nr.rating, 2)
+
+ def test_als_ratings_id_long_error(self):
+ r = Rating(1205640308657491975, 50233468418, 1.0)
+ # rating user id exceeds max int value, should fail when pickled
+ self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+
+
if __name__ == "__main__":
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org