You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/01/16 20:57:03 UTC
spark git commit: [SPARK-23045][ML][SPARKR] Update RFormula to use
OneHotEncoderEstimator.
Repository: spark
Updated Branches:
refs/heads/master 12db365b4 -> 4371466b3
[SPARK-23045][ML][SPARKR] Update RFormula to use OneHotEncoderEstimator.
## What changes were proposed in this pull request?
RFormula should use VectorSizeHint and OneHotEncoderEstimator in its pipeline, both to avoid using the deprecated OneHotEncoder and to ensure that the model it produces can be used in streaming.
## How was this patch tested?
Unit tests.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Bago Amirbekian <ba...@databricks.com>
Closes #20229 from MrBago/rFormula.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4371466b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4371466b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4371466b
Branch: refs/heads/master
Commit: 4371466b3f06ca171b10568e776c9446f7bae6dd
Parents: 12db365
Author: Bago Amirbekian <ba...@databricks.com>
Authored: Tue Jan 16 12:56:57 2018 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Jan 16 12:56:57 2018 -0800
----------------------------------------------------------------------
R/pkg/R/mllib_utils.R | 1 -
.../org/apache/spark/ml/feature/RFormula.scala | 20 ++++++--
.../apache/spark/ml/feature/RFormulaSuite.scala | 53 ++++++++++++--------
3 files changed, 46 insertions(+), 28 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/4371466b/R/pkg/R/mllib_utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R
index 23dda42..a53c92c 100644
--- a/R/pkg/R/mllib_utils.R
+++ b/R/pkg/R/mllib_utils.R
@@ -130,4 +130,3 @@ read.ml <- function(path) {
stop("Unsupported model: ", jobj)
}
}
-
http://git-wip-us.apache.org/repos/asf/spark/blob/4371466b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index f384ffb..1155ea5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -199,6 +199,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
val encoderStages = ArrayBuffer[PipelineStage]()
+ val oneHotEncodeColumns = ArrayBuffer[(String, String)]()
val prefixesToRewrite = mutable.Map[String, String]()
val tempColumns = ArrayBuffer[String]()
@@ -242,16 +243,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
val encodedTerms = resolvedFormula.terms.map {
case Seq(term) if dataset.schema(term).dataType == StringType =>
val encodedCol = tmpColumn("onehot")
- var encoder = new OneHotEncoder()
- .setInputCol(indexed(term))
- .setOutputCol(encodedCol)
// Formula w/o intercept, one of the categories in the first category feature is
// being used as reference category, we will not drop any category for that feature.
if (!hasIntercept && !keepReferenceCategory) {
- encoder = encoder.setDropLast(false)
+ encoderStages += new OneHotEncoderEstimator(uid)
+ .setInputCols(Array(indexed(term)))
+ .setOutputCols(Array(encodedCol))
+ .setDropLast(false)
keepReferenceCategory = true
+ } else {
+ oneHotEncodeColumns += indexed(term) -> encodedCol
}
- encoderStages += encoder
prefixesToRewrite(encodedCol + "_") = term + "_"
encodedCol
case Seq(term) =>
@@ -265,6 +267,14 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
interactionCol
}
+ if (oneHotEncodeColumns.nonEmpty) {
+ val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip
+ encoderStages += new OneHotEncoderEstimator(uid)
+ .setInputCols(inputCols)
+ .setOutputCols(outputCols)
+ .setDropLast(true)
+ }
+
encoderStages += new VectorAssembler(uid)
.setInputCols(encodedTerms.toArray)
.setOutputCol($(featuresCol))
http://git-wip-us.apache.org/repos/asf/spark/blob/4371466b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index f3f4b5a..bfe38d3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -29,6 +29,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
+ def testRFormulaTransform[A: Encoder](
+ dataframe: DataFrame,
+ formulaModel: RFormulaModel,
+ expected: DataFrame): Unit = {
+ val (first +: rest) = expected.schema.fieldNames.toSeq
+ val expectedRows = expected.collect()
+ testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows =>
+ assert(rows === expectedRows)
+ }
+ }
+
test("params") {
ParamsSuite.checkParams(new RFormula())
}
@@ -47,7 +58,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
// TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
assert(result.schema.toString == resultSchema.toString)
assert(resultSchema == expected.schema)
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(Int, Double, Double)](original, model, expected)
}
test("features column already exists") {
@@ -109,7 +120,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(7, 8.0, 9.0, Vectors.dense(8.0, 9.0))
).toDF("id", "a", "b", "features")
assert(result.schema.toString == resultSchema.toString)
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(Int, Double, Double)](original, model, expected)
}
test("encodes string terms") {
@@ -126,7 +137,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
).toDF("id", "a", "b", "features", "label")
assert(result.schema.toString == resultSchema.toString)
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(Int, String, Int)](original, model, expected)
}
test("encodes string terms with string indexer order type") {
@@ -167,7 +178,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
assert(result.schema.toString == resultSchema.toString)
- assert(result.collect() === expected(idx).collect())
+ testRFormulaTransform[(Int, String, Int)](original, model, expected(idx))
idx += 1
}
}
@@ -210,7 +221,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
assert(result.schema.toString == resultSchema.toString)
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(Int, String, Int)](original, model, expected)
}
test("formula w/o intercept, we should output reference category when encoding string terms") {
@@ -253,7 +264,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
).toDF("id", "a", "b", "c", "features", "label")
assert(result1.schema.toString == resultSchema1.toString)
- assert(result1.collect() === expected1.collect())
+ testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1)
val attrs1 = AttributeGroup.fromStructField(result1.schema("features"))
val expectedAttrs1 = new AttributeGroup(
@@ -280,7 +291,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0)
).toDF("id", "a", "b", "c", "features", "label")
assert(result2.schema.toString == resultSchema2.toString)
- assert(result2.collect() === expected2.collect())
+ testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2)
val attrs2 = AttributeGroup.fromStructField(result2.schema("features"))
val expectedAttrs2 = new AttributeGroup(
@@ -302,7 +313,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
.toDF("id", "a", "b")
val model = formula.fit(original)
- val result = model.transform(original)
val expected = Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
@@ -310,7 +320,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)
).toDF("id", "a", "b", "features", "label")
// assert(result.schema.toString == resultSchema.toString)
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(String, String, Int)](original, model, expected)
}
test("force to index label even it is numeric type") {
@@ -319,7 +329,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
- val result = model.transform(original)
val expected = spark.createDataFrame(
Seq(
(1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
@@ -327,7 +336,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
(1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
).toDF("id", "a", "b", "features", "label")
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(Double, String, Int)](original, model, expected)
}
test("attribute generation") {
@@ -391,7 +400,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
(2, 3, 4, 1, Vectors.dense(12.0), 2.0)
).toDF("a", "b", "c", "d", "features", "label")
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
@@ -414,7 +423,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
(4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)
).toDF("id", "a", "b", "features", "label")
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(Int, String, Int)](original, model, expected)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
@@ -436,7 +445,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
(3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)
).toDF("id", "a", "b", "features", "label")
- assert(result.collect() === expected.collect())
+ testRFormulaTransform[(Int, String, String)](original, model, expected)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
@@ -511,8 +520,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
intercept[SparkException] {
formula1.fit(df1).transform(df2).collect()
}
- val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2)
- val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2)
+ val model1 = formula1.setHandleInvalid("skip").fit(df1)
+ val model2 = formula1.setHandleInvalid("keep").fit(df1)
val expected1 = Seq(
(1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0),
@@ -524,16 +533,16 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0)
).toDF("id", "a", "b", "features", "label")
- assert(result1.collect() === expected1.collect())
- assert(result2.collect() === expected2.collect())
+ testRFormulaTransform[(Int, String, String)](df2, model1, expected1)
+ testRFormulaTransform[(Int, String, String)](df2, model2, expected2)
// Handle unseen labels.
val formula2 = new RFormula().setFormula("b ~ a + id")
intercept[SparkException] {
formula2.fit(df1).transform(df2).collect()
}
- val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2)
- val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2)
+ val model3 = formula2.setHandleInvalid("skip").fit(df1)
+ val model4 = formula2.setHandleInvalid("keep").fit(df1)
val expected3 = Seq(
(1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
@@ -545,8 +554,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
).toDF("id", "a", "b", "features", "label")
- assert(result3.collect() === expected3.collect())
- assert(result4.collect() === expected4.collect())
+ testRFormulaTransform[(Int, String, String)](df2, model3, expected3)
+ testRFormulaTransform[(Int, String, String)](df2, model4, expected4)
}
test("Use Vectors as inputs to formula.") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org