You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/11/06 15:43:26 UTC
spark git commit: [SPARK-18210][ML] Pipeline.copy does not create an instance with the same UID
Repository: spark
Updated Branches:
refs/heads/master 340f09d10 -> b89d0556d
[SPARK-18210][ML] Pipeline.copy does not create an instance with the same UID
## What changes were proposed in this pull request?
Motivation:
`org.apache.spark.ml.Pipeline.copy(extra: ParamMap)` does not create an instance with the same UID. This violates the contract of the base-class method `org.apache.spark.ml.param.Params.copy(extra: ParamMap)`, which specifies that the copy must keep the UID of the original instance.
Solution:
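- fixed `Pipeline.copy` so that the copy keeps the original UID (the original `uid` is now passed to the new `Pipeline` constructor)
- added a new test for `org.apache.spark.ml.Pipeline.copy`
- made minor improvements to the test for `org.apache.spark.ml.PipelineModel.copy`: it now uses `assert` instead of `require`, and it also checks that the copy keeps the same UID and parent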
## How was this patch tested?
Added a new unit test: `org.apache.spark.ml.PipelineSuite."Pipeline.copy"`
Improved the existing unit test: `org.apache.spark.ml.PipelineSuite."PipelineModel.copy"`
Author: Wojciech Szymanski <wk...@gmail.com>
Closes #15759 from wojtek-szymanski/SPARK-18210.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b89d0556
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b89d0556
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b89d0556
Branch: refs/heads/master
Commit: b89d0556dff0520ab35882382242fbfa7d9478eb
Parents: 340f09d
Author: Wojciech Szymanski <wk...@gmail.com>
Authored: Sun Nov 6 07:43:13 2016 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Sun Nov 6 07:43:13 2016 -0800
----------------------------------------------------------------------
.../scala/org/apache/spark/ml/Pipeline.scala | 2 +-
.../org/apache/spark/ml/PipelineSuite.scala | 22 ++++++++++++++++++--
2 files changed, 21 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b89d0556/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 195a93e..f406f8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -169,7 +169,7 @@ class Pipeline @Since("1.4.0") (
override def copy(extra: ParamMap): Pipeline = {
val map = extractParamMap(extra)
val newStages = map(stages).map(_.copy(extra))
- new Pipeline().setStages(newStages)
+ new Pipeline(uid).setStages(newStages)
}
@Since("1.2.0")
http://git-wip-us.apache.org/repos/asf/spark/blob/b89d0556/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 6413ca1..dafc6c2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -101,13 +101,31 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}
+ test("Pipeline.copy") {
+ val hashingTF = new HashingTF()
+ .setNumFeatures(100)
+ val pipeline = new Pipeline("pipeline").setStages(Array[Transformer](hashingTF))
+ val copied = pipeline.copy(ParamMap(hashingTF.numFeatures -> 10))
+
+ assert(copied.uid === pipeline.uid,
+ "copy should create an instance with the same UID")
+ assert(copied.getStages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
+ "copy should handle extra stage params")
+ }
+
test("PipelineModel.copy") {
val hashingTF = new HashingTF()
.setNumFeatures(100)
- val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
+ val model = new PipelineModel("pipelineModel", Array[Transformer](hashingTF))
+ .setParent(new Pipeline())
val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
- require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
+
+ assert(copied.uid === model.uid,
+ "copy should create an instance with the same UID")
+ assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
"copy should handle extra stage params")
+ assert(copied.parent === model.parent,
+ "copy should create an instance with the same parent")
}
test("pipeline model constructors") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org