You are viewing a plain-text version of this content. The canonical link to it was not preserved in this copy.
Posted to commits@spark.apache.org by sr...@apache.org on 2019/01/22 15:35:26 UTC

[spark] branch master updated: [SPARK-16838][PYTHON] Add PMML export for ML KMeans in PySpark

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 06792af  [SPARK-16838][PYTHON] Add PMML export for ML KMeans in PySpark
06792af is described below

commit 06792afd4c9c719df4af34b0768a999271383330
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Tue Jan 22 09:34:59 2019 -0600

    [SPARK-16838][PYTHON] Add PMML export for ML KMeans in PySpark
    
    ## What changes were proposed in this pull request?
    
    Add PMML export support for ML KMeans to PySpark.
    
    ## How was this patch tested?
    
    Add tests to PersistenceTest in ml.tests.test_persistence.
    
    Closes #23592 from huaxingao/spark-16838.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: Sean Owen <se...@databricks.com>
---
 python/pyspark/ml/clustering.py             |  2 +-
 python/pyspark/ml/tests/test_persistence.py | 37 +++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 5a776ae..b9c6bdf 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -323,7 +323,7 @@ class KMeansSummary(ClusteringSummary):
         return self._call_java("trainingCost")
 
 
-class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
     """
     Model fitted by KMeans.
 
diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py
index 34d6870..63b0594 100644
--- a/python/pyspark/ml/tests/test_persistence.py
+++ b/python/pyspark/ml/tests/test_persistence.py
@@ -23,6 +23,7 @@ import unittest
 from pyspark.ml import Transformer
 from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \
     OneVsRestModel
+from pyspark.ml.clustering import KMeans
 from pyspark.ml.feature import Binarizer, HashingTF, PCA
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.param import Params
@@ -89,6 +90,42 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
+    def test_kmeans(self):
+        kmeans = KMeans(k=2, seed=1)
+        path = tempfile.mkdtemp()
+        km_path = path + "/km"
+        kmeans.save(km_path)
+        kmeans2 = KMeans.load(km_path)
+        self.assertEqual(kmeans.uid, kmeans2.uid)
+        self.assertEqual(type(kmeans.uid), type(kmeans2.uid))
+        self.assertEqual(kmeans2.uid, kmeans2.k.parent,
+                         "Loaded KMeans instance uid (%s) did not match Param's uid (%s)"
+                         % (kmeans2.uid, kmeans2.k.parent))
+        self.assertEqual(kmeans._defaultParamMap[kmeans.k], kmeans2._defaultParamMap[kmeans2.k],
+                         "Loaded KMeans instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+    def test_kmean_pmml_basic(self):
+        # Most of the validation is done in the Scala side, here we just check
+        # that we output text rather than parquet (e.g. that the format flag
+        # was respected).
+        data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
+                (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
+        df = self.spark.createDataFrame(data, ["features"])
+        kmeans = KMeans(k=2, seed=1)
+        model = kmeans.fit(df)
+        path = tempfile.mkdtemp()
+        km_path = path + "/km-pmml"
+        model.write().format("pmml").save(km_path)
+        pmml_text_list = self.sc.textFile(km_path).collect()
+        pmml_text = "\n".join(pmml_text_list)
+        self.assertIn("Apache Spark", pmml_text)
+        self.assertIn("PMML", pmml_text)
+
     def _compare_params(self, m1, m2, param):
         """
         Compare 2 ML Params instances for the given param, and assert both have the same param value


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org