Posted to commits@spark.apache.org by jk...@apache.org on 2017/08/04 08:01:41 UTC
spark git commit: [SPARK-21633][ML][PYTHON] UnaryTransformer in Python
Repository: spark
Updated Branches:
refs/heads/master 25826c77d -> 1347b2a69
[SPARK-21633][ML][PYTHON] UnaryTransformer in Python
## What changes were proposed in this pull request?
Implemented UnaryTransformer in Python.
## How was this patch tested?
This patch was tested by creating a MockUnaryTransformer class in the unit tests that extends UnaryTransformer and testing that the transform function produced correct output.
Author: Ajay Saini <aj...@gmail.com>
Closes #18746 from ajaysaini725/AddPythonUnaryTransformer.
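For illustration, a user-defined transformer built on the new class might look like the following minimal sketch. The ScaleByTwo name and column types are hypothetical and not part of this patch; the three overridden methods are the abstract ones the new base class requires.

from pyspark.ml import UnaryTransformer
from pyspark.sql.types import DoubleType

class ScaleByTwo(UnaryTransformer):
    """Hypothetical example: doubles every value of the input column."""

    def createTransformFunc(self):
        # Function applied elementwise to the input column.
        return lambda x: x * 2.0

    def outputDataType(self):
        # Data type of the generated output column.
        return DoubleType()

    def validateInputType(self, inputType):
        # Reject anything other than double-typed input.
        if inputType != DoubleType():
            raise TypeError("Requires DoubleType, got %s." % inputType)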
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1347b2a6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1347b2a6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1347b2a6
Branch: refs/heads/master
Commit: 1347b2a697aa798c04b39fbb352efc735aa42ea3
Parents: 25826c7
Author: Ajay Saini <aj...@gmail.com>
Authored: Fri Aug 4 01:01:32 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Aug 4 01:01:32 2017 -0700
----------------------------------------------------------------------
python/pyspark/ml/__init__.py | 4 +--
python/pyspark/ml/base.py | 56 ++++++++++++++++++++++++++++++++++
python/pyspark/ml/tests.py | 62 +++++++++++++++++++++++++++++++++++++-
3 files changed, 119 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 1d42d49..129d7d6 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -19,7 +19,7 @@
DataFrame-based machine learning APIs to let users quickly assemble and configure practical
machine learning pipelines.
"""
-from pyspark.ml.base import Estimator, Model, Transformer
+from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.pipeline import Pipeline, PipelineModel
-__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
+__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/base.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 339e5d6..a6767ce 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -17,9 +17,14 @@
from abc import ABCMeta, abstractmethod
+import copy
+
from pyspark import since
from pyspark.ml.param import Params
+from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
+from pyspark.sql.functions import udf
+from pyspark.sql.types import StructField, StructType, DoubleType
@inherit_doc
@@ -116,3 +121,54 @@ class Model(Transformer):
"""
__metaclass__ = ABCMeta
+
+
+@inherit_doc
+class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
+ """
+ Abstract class for transformers that take one input column, apply transformation,
+ and output the result as a new column.
+
+ .. versionadded:: 2.3.0
+ """
+
+ @abstractmethod
+ def createTransformFunc(self):
+ """
+ Creates the transform function using the given param map. The input param map already takes
+ account of the embedded param map. So the param values should be determined
+ solely by the input param map.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def outputDataType(self):
+ """
+ Returns the data type of the output column.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def validateInputType(self, inputType):
+ """
+ Validates the input type. Throw an exception if it is invalid.
+ """
+ raise NotImplementedError()
+
+ def transformSchema(self, schema):
+ inputType = schema[self.getInputCol()].dataType
+ self.validateInputType(inputType)
+ if self.getOutputCol() in schema.names:
+ raise ValueError("Output column %s already exists." % self.getOutputCol())
+ outputFields = copy.copy(schema.fields)
+ outputFields.append(StructField(self.getOutputCol(),
+ self.outputDataType(),
+ nullable=False))
+ return StructType(outputFields)
+
+ def _transform(self, dataset):
+ self.transformSchema(dataset.schema)
+ transformUDF = udf(self.createTransformFunc(), self.outputDataType())
+ transformedDataset = dataset.withColumn(self.getOutputCol(),
+ transformUDF(dataset[self.getInputCol()]))
+ return transformedDataset
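As a rough usage sketch (assuming the hypothetical ScaleByTwo subclass shown in the description above and an active SparkSession named spark), transform() wraps createTransformFunc in a UDF to append the output column, while transformSchema rejects a pre-existing output column:

scaler = ScaleByTwo().setInputCol("value").setOutputCol("doubled")

df = spark.range(0, 5).toDF("value")
df = df.withColumn("value", df.value.cast("double"))

out = scaler.transform(df)   # appends the "doubled" column via the UDF
out.show()

# Calling transform(out) again would fail in transformSchema with:
# ValueError: Output column doubled already exists.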
http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7ee2c2f..3bd4d37 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -45,7 +45,7 @@ from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros
import inspect
from pyspark import keyword_only, SparkContext
-from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
+from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer
from pyspark.ml.classification import *
from pyspark.ml.clustering import *
from pyspark.ml.common import _java2py, _py2java
@@ -66,6 +66,7 @@ from pyspark.ml.wrapper import JavaParams, JavaWrapper
from pyspark.serializers import PickleSerializer
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import rand
+from pyspark.sql.types import DoubleType, IntegerType
from pyspark.storagelevel import *
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -121,6 +122,36 @@ class MockTransformer(Transformer, HasFake):
        return dataset
+class MockUnaryTransformer(UnaryTransformer):
+
+    shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
+                  "data in a DataFrame",
+                  typeConverter=TypeConverters.toFloat)
+
+    def __init__(self, shiftVal=1):
+        super(MockUnaryTransformer, self).__init__()
+        self._setDefault(shift=1)
+        self._set(shift=shiftVal)
+
+    def getShift(self):
+        return self.getOrDefault(self.shift)
+
+    def setShift(self, shift):
+        self._set(shift=shift)
+
+    def createTransformFunc(self):
+        shiftVal = self.getShift()
+        return lambda x: x + shiftVal
+
+    def outputDataType(self):
+        return DoubleType()
+
+    def validateInputType(self, inputType):
+        if inputType != DoubleType():
+            raise TypeError("Bad input type: {}. ".format(inputType) +
+                            "Requires Double.")
+
+
class MockEstimator(Estimator, HasFake):
def __init__(self):
@@ -2008,6 +2039,35 @@ class ChiSquareTestTests(SparkSessionTestCase):
        self.assertTrue(all(field in fieldNames for field in expectedFields))
+class UnaryTransformerTests(SparkSessionTestCase):
+
+    def test_unary_transformer_validate_input_type(self):
+        shiftVal = 3
+        transformer = MockUnaryTransformer(shiftVal=shiftVal)\
+            .setInputCol("input").setOutputCol("output")
+
+        # should not raise any errors
+        transformer.validateInputType(DoubleType())
+
+        with self.assertRaises(TypeError):
+            # passing the wrong input type should raise an error
+            transformer.validateInputType(IntegerType())
+
+    def test_unary_transformer_transform(self):
+        shiftVal = 3
+        transformer = MockUnaryTransformer(shiftVal=shiftVal)\
+            .setInputCol("input").setOutputCol("output")
+
+        df = self.spark.range(0, 10).toDF('input')
+        df = df.withColumn("input", df.input.cast(dataType="double"))
+
+        transformed_df = transformer.transform(df)
+        results = transformed_df.select("input", "output").collect()
+
+        for res in results:
+            self.assertEqual(res.input + shiftVal, res.output)
+
+
if __name__ == "__main__":
    from pyspark.ml.tests import *
    if xmlrunner: