You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2019/01/22 15:35:26 UTC
[spark] branch master updated: [SPARK-16838][PYTHON] Add PMML export for ML KMeans in PySpark
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 06792af [SPARK-16838][PYTHON] Add PMML export for ML KMeans in PySpark
06792af is described below
commit 06792afd4c9c719df4af34b0768a999271383330
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Tue Jan 22 09:34:59 2019 -0600
[SPARK-16838][PYTHON] Add PMML export for ML KMeans in PySpark
## What changes were proposed in this pull request?
Add PMML export support for ML KMeans to PySpark.
## How was this patch tested?
Add tests in ml.tests.PersistenceTest.
Closes #23592 from huaxingao/spark-16838.
Authored-by: Huaxin Gao <hu...@us.ibm.com>
Signed-off-by: Sean Owen <se...@databricks.com>
---
python/pyspark/ml/clustering.py | 2 +-
python/pyspark/ml/tests/test_persistence.py | 37 +++++++++++++++++++++++++++++
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 5a776ae..b9c6bdf 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -323,7 +323,7 @@ class KMeansSummary(ClusteringSummary):
return self._call_java("trainingCost")
-class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
"""
Model fitted by KMeans.
diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py
index 34d6870..63b0594 100644
--- a/python/pyspark/ml/tests/test_persistence.py
+++ b/python/pyspark/ml/tests/test_persistence.py
@@ -23,6 +23,7 @@ import unittest
from pyspark.ml import Transformer
from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \
OneVsRestModel
+from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import Binarizer, HashingTF, PCA
from pyspark.ml.linalg import Vectors
from pyspark.ml.param import Params
@@ -89,6 +90,42 @@ class PersistenceTest(SparkSessionTestCase):
except OSError:
pass
+ def test_kmeans(self):
+ kmeans = KMeans(k=2, seed=1)
+ path = tempfile.mkdtemp()
+ km_path = path + "/km"
+ kmeans.save(km_path)
+ kmeans2 = KMeans.load(km_path)
+ self.assertEqual(kmeans.uid, kmeans2.uid)
+ self.assertEqual(type(kmeans.uid), type(kmeans2.uid))
+ self.assertEqual(kmeans2.uid, kmeans2.k.parent,
+ "Loaded KMeans instance uid (%s) did not match Param's uid (%s)"
+ % (kmeans2.uid, kmeans2.k.parent))
+ self.assertEqual(kmeans._defaultParamMap[kmeans.k], kmeans2._defaultParamMap[kmeans2.k],
+ "Loaded KMeans instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+ def test_kmean_pmml_basic(self):
+ # Most of the validation is done in the Scala side, here we just check
+ # that we output text rather than parquet (e.g. that the format flag
+ # was respected).
+ data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
+ (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ kmeans = KMeans(k=2, seed=1)
+ model = kmeans.fit(df)
+ path = tempfile.mkdtemp()
+ km_path = path + "/km-pmml"
+ model.write().format("pmml").save(km_path)
+ pmml_text_list = self.sc.textFile(km_path).collect()
+ pmml_text = "\n".join(pmml_text_list)
+ self.assertIn("Apache Spark", pmml_text)
+ self.assertIn("PMML", pmml_text)
+
def _compare_params(self, m1, m2, param):
"""
Compare 2 ML Params instances for the given param, and assert both have the same param value
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org