Posted to commits@spark.apache.org by jk...@apache.org on 2017/08/12 06:57:16 UTC
spark git commit: [SPARK-17025][ML][PYTHON] Persistence for Pipelines with Python-only Stages
Repository: spark
Updated Branches:
refs/heads/master b0bdfce9c -> 35db3b9fe
[SPARK-17025][ML][PYTHON] Persistence for Pipelines with Python-only Stages
## What changes were proposed in this pull request?
Implemented a Python-only persistence framework for pipelines containing stages that cannot be saved through the JVM persistence path (i.e., stages that are not JavaMLWritable).
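The user-facing effect is that Pipeline.write() keeps returning the Java-backed writer when every stage is JavaMLWritable, and otherwise falls back to the new Python PipelineWriter. A minimal sketch of that dispatch from the caller's side (Binarizer is a real JVM-backed stage; the Python-only stage in the comments is hypothetical):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import Binarizer

    # Every stage here is JavaMLWritable, so write() returns a JavaMLWriter
    # and the pipeline is saved in the usual JVM format.
    jvm_pl = Pipeline(stages=[Binarizer(threshold=0.5, inputCol="x", outputCol="y")])
    jvm_pl.write().save("/tmp/jvm_pipeline")

    # If any stage were Python-only (MLWritable but not JavaMLWritable),
    # write() would instead return the new PipelineWriter, which saves
    # metadata plus each stage under stages/IDX_UID:
    # mixed_pl = Pipeline(stages=[SomePythonOnlyStage(), ...])
    # mixed_pl.write().save("/tmp/python_pipeline")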
## How was this patch tested?
Created a custom Python-only UnaryTransformer, included it in a Pipeline, and saved/loaded the pipeline. The loaded pipeline was compared against the original using _compare_pipelines() in tests.py.
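The shape of that test, sketched with illustrative names (ShiftTransformer and its shift param are stand-ins, not part of the patch; UnaryTransformer and the DefaultParamsReadable/DefaultParamsWritable mixins are the real pyspark.ml APIs the test relies on; assumes an active SparkSession):

    from pyspark.ml import Pipeline
    from pyspark.ml.base import UnaryTransformer
    from pyspark.ml.param import Param, Params, TypeConverters
    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
    from pyspark.sql.types import DoubleType

    class ShiftTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
        """Python-only transformer that adds a constant to a Double column."""

        shift = Param(Params._dummy(), "shift", "amount added to the input value",
                      typeConverter=TypeConverters.toFloat)

        def __init__(self):
            super(ShiftTransformer, self).__init__()
            self._setDefault(shift=1.0)

        def setShift(self, value):
            return self._set(shift=value)

        def createTransformFunc(self):
            shift = self.getOrDefault(self.shift)
            return lambda v: v + shift

        def outputDataType(self):
            return DoubleType()

        def validateInputType(self, inputType):
            if inputType != DoubleType():
                raise TypeError("Requires Double, got: {}".format(inputType))

    # A Pipeline with a Python-only stage now round-trips through save/load.
    pl = Pipeline(stages=[ShiftTransformer().setInputCol("in").setOutputCol("out")])
    pl.save("/tmp/python_only_pipeline")   # dispatches to the new PipelineWriter
    loaded = Pipeline.load("/tmp/python_only_pipeline")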
Author: Ajay Saini <aj...@gmail.com>
Closes #18888 from ajaysaini725/PythonPipelines.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35db3b9f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35db3b9f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35db3b9f
Branch: refs/heads/master
Commit: 35db3b9fe38dadfb8afb0b0857c09f83196398be
Parents: b0bdfce
Author: Ajay Saini <aj...@gmail.com>
Authored: Fri Aug 11 23:57:08 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Aug 11 23:57:08 2017 -0700
----------------------------------------------------------------------
python/pyspark/ml/pipeline.py | 156 +++++++++++++++++++++++++++++++++++--
python/pyspark/ml/tests.py | 35 ++++++++-
2 files changed, 183 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/35db3b9f/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a8dc76b..0975302 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -16,6 +16,7 @@
#
import sys
+import os
if sys.version > '3':
basestring = str
@@ -23,7 +24,7 @@ if sys.version > '3':
from pyspark import since, keyword_only, SparkContext
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc
@@ -130,13 +131,16 @@ class Pipeline(Estimator, MLReadable, MLWritable):
@since("2.0.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
+ allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages())
+ if allStagesAreJava:
+ return JavaMLWriter(self)
+ return PipelineWriter(self)
@classmethod
@since("2.0.0")
def read(cls):
"""Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
+ return PipelineReader(cls)
@classmethod
def _from_java(cls, java_stage):
@@ -172,6 +176,76 @@ class Pipeline(Estimator, MLReadable, MLWritable):
@inherit_doc
+class PipelineWriter(MLWriter):
+ """
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
+ """
+
+ def __init__(self, instance):
+ super(PipelineWriter, self).__init__()
+ self.instance = instance
+
+ def saveImpl(self, path):
+ stages = self.instance.getStages()
+ PipelineSharedReadWrite.validateStages(stages)
+ PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineReader(MLReader):
+ """
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
+ """
+
+ def __init__(self, cls):
+ super(PipelineReader, self).__init__()
+ self.cls = cls
+
+ def load(self, path):
+ metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+ return JavaMLReader(self.cls).load(path)
+ else:
+ uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+ return Pipeline(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
+class PipelineModelWriter(MLWriter):
+ """
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
+ """
+
+ def __init__(self, instance):
+ super(PipelineModelWriter, self).__init__()
+ self.instance = instance
+
+ def saveImpl(self, path):
+ stages = self.instance.stages
+ PipelineSharedReadWrite.validateStages(stages)
+ PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineModelReader(MLReader):
+ """
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
+ """
+
+ def __init__(self, cls):
+ super(PipelineModelReader, self).__init__()
+ self.cls = cls
+
+ def load(self, path):
+ metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+ return JavaMLReader(self.cls).load(path)
+ else:
+ uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+ return PipelineModel(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
class PipelineModel(Model, MLReadable, MLWritable):
"""
Represents a compiled pipeline with transformers and fitted models.
@@ -204,13 +278,16 @@ class PipelineModel(Model, MLReadable, MLWritable):
@since("2.0.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
+ allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages)
+ if allStagesAreJava:
+ return JavaMLWriter(self)
+ return PipelineModelWriter(self)
@classmethod
@since("2.0.0")
def read(cls):
"""Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
+ return PipelineModelReader(cls)
@classmethod
def _from_java(cls, java_stage):
@@ -242,3 +319,72 @@ class PipelineModel(Model, MLReadable, MLWritable):
JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
return _java_obj
+
+
+@inherit_doc
+class PipelineSharedReadWrite():
+ """
+ .. note:: DeveloperApi
+
+ Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
+ :py:class:`Pipeline` and :py:class:`PipelineModel`
+
+ .. versionadded:: 2.3.0
+ """
+
+ @staticmethod
+ def checkStagesForJava(stages):
+ return all(isinstance(stage, JavaMLWritable) for stage in stages)
+
+ @staticmethod
+ def validateStages(stages):
+ """
+ Check that all stages are Writable
+ """
+ for stage in stages:
+ if not isinstance(stage, MLWritable):
+ raise ValueError("Pipeline write will fail on this pipeline " +
+ "because stage %s of type %s is not MLWritable",
+ stage.uid, type(stage))
+
+ @staticmethod
+ def saveImpl(instance, stages, sc, path):
+ """
+ Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+ - save metadata to path/metadata
+ - save stages to stages/IDX_UID
+ """
+ stageUids = [stage.uid for stage in stages]
+ jsonParams = {'stageUids': stageUids, 'language': 'Python'}
+ DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
+ stagesDir = os.path.join(path, "stages")
+ for index, stage in enumerate(stages):
+ stage.write().save(PipelineSharedReadWrite
+ .getStagePath(stage.uid, index, len(stages), stagesDir))
+
+ @staticmethod
+ def load(metadata, sc, path):
+ """
+ Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+
+ :return: (UID, list of stages)
+ """
+ stagesDir = os.path.join(path, "stages")
+ stageUids = metadata['paramMap']['stageUids']
+ stages = []
+ for index, stageUid in enumerate(stageUids):
+ stagePath = \
+ PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
+ stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
+ stages.append(stage)
+ return (metadata['uid'], stages)
+
+ @staticmethod
+ def getStagePath(stageUid, stageIdx, numStages, stagesDir):
+ """
+ Get path for saving the given stage.
+ """
+ stageIdxDigits = len(str(numStages))
+ stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
+ stagePath = os.path.join(stagesDir, stageDir)
+ return stagePath
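For reference, getStagePath zero-pads the stage index to the digit count of the stage total, so stage directories sort in pipeline order; e.g. for a hypothetical 12-stage pipeline (illustrative UID, POSIX paths):

    >>> PipelineSharedReadWrite.getStagePath("binarizer_4ab2", 3, 12, "/tmp/pl/stages")
    '/tmp/pl/stages/03_binarizer_4ab2'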
http://git-wip-us.apache.org/repos/asf/spark/blob/35db3b9f/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6aecc7f..0495973 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -123,7 +123,7 @@ class MockTransformer(Transformer, HasFake):
return dataset
-class MockUnaryTransformer(UnaryTransformer):
+class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
"data in a DataFrame",
@@ -150,7 +150,7 @@ class MockUnaryTransformer(UnaryTransformer):
def validateInputType(self, inputType):
if inputType != DoubleType():
raise TypeError("Bad input type: {}. ".format(inputType) +
- "Requires Integer.")
+ "Requires Double.")
class MockEstimator(Estimator, HasFake):
@@ -1063,7 +1063,7 @@ class PersistenceTest(SparkSessionTestCase):
"""
self.assertEqual(m1.uid, m2.uid)
self.assertEqual(type(m1), type(m2))
- if isinstance(m1, JavaParams):
+ if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self._compare_params(m1, m2, p)
@@ -1142,6 +1142,35 @@ class PersistenceTest(SparkSessionTestCase):
except OSError:
pass
+ def test_python_transformer_pipeline_persistence(self):
+ """
+ Pipeline[MockUnaryTransformer, Binarizer]
+ """
+ temp_path = tempfile.mkdtemp()
+
+ try:
+ df = self.spark.range(0, 10).toDF('input')
+ tf = MockUnaryTransformer(shiftVal=2)\
+ .setInputCol("input").setOutputCol("shiftedInput")
+ tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
+ pl = Pipeline(stages=[tf, tf2])
+ model = pl.fit(df)
+
+ pipeline_path = temp_path + "/pipeline"
+ pl.save(pipeline_path)
+ loaded_pipeline = Pipeline.load(pipeline_path)
+ self._compare_pipelines(pl, loaded_pipeline)
+
+ model_path = temp_path + "/pipeline-model"
+ model.save(model_path)
+ loaded_model = PipelineModel.load(model_path)
+ self._compare_pipelines(model, loaded_model)
+ finally:
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
+
def test_onevsrest(self):
temp_path = tempfile.mkdtemp()
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),