Posted to reviews@spark.apache.org by WeichenXu123 <gi...@git.apache.org> on 2018/01/12 01:20:49 UTC
[GitHub] spark pull request #20229: [SPARK-23045][ML][SparkR] Update RFormula to use ...
Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20229#discussion_r161120354
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala ---
@@ -230,16 +231,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
val encodedTerms = resolvedFormula.terms.map {
case Seq(term) if dataset.schema(term).dataType == StringType =>
val encodedCol = tmpColumn("onehot")
- var encoder = new OneHotEncoder()
- .setInputCol(indexed(term))
- .setOutputCol(encodedCol)
// Formula w/o intercept, one of the categories in the first category feature is
// being used as reference category, we will not drop any category for that feature.
if (!hasIntercept && !keepReferenceCategory) {
- encoder = encoder.setDropLast(false)
+ encoderStages += new OneHotEncoderEstimator(uid)
+ .setInputCols(Array(indexed(term)))
+ .setOutputCols(Array(encodedCol))
+ .setDropLast(false)
--- End diff --
This can be optimized. You can merge these (possibly multiple) OHEs into one per `dropLast` setting, like:
define:
```
val oneHotEncodeColumns = ArrayBuffer[(String, String)]()
val oneHotEncodeColumnsNotDropLast = ArrayBuffer[(String, String)]()
```
and:
```
if (!hasIntercept && !keepReferenceCategory) {
  oneHotEncodeColumnsNotDropLast += indexed(term) -> encodedCol
} else {
  oneHotEncodeColumns += indexed(term) -> encodedCol
}
```
```
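After the loop, the collected pairs could be turned into at most two estimators, roughly as in the sketch below. This is only an illustration of the idea, not code from the PR: `buildEncoderStages` is a hypothetical helper name, and it assumes the two buffers hold `(inputCol, outputCol)` pairs as in the snippets above.

```scala
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.feature.OneHotEncoderEstimator

// Hypothetical helper (not part of the PR): builds one estimator per
// dropLast setting from the collected (inputCol, outputCol) pairs.
def buildEncoderStages(
    oneHotEncodeColumns: Seq[(String, String)],
    oneHotEncodeColumnsNotDropLast: Seq[(String, String)]): Seq[PipelineStage] = {
  val stages = ArrayBuffer[PipelineStage]()

  // Columns that keep the default behavior (dropLast = true).
  if (oneHotEncodeColumns.nonEmpty) {
    val (inputCols, outputCols) = oneHotEncodeColumns.unzip
    stages += new OneHotEncoderEstimator()
      .setInputCols(inputCols.toArray)
      .setOutputCols(outputCols.toArray)
  }

  // Columns for which no category should be dropped.
  if (oneHotEncodeColumnsNotDropLast.nonEmpty) {
    val (inputCols, outputCols) = oneHotEncodeColumnsNotDropLast.unzip
    stages += new OneHotEncoderEstimator()
      .setInputCols(inputCols.toArray)
      .setOutputCols(outputCols.toArray)
      .setDropLast(false)
  }

  stages.toSeq
}
```

The returned stages could then be appended to `encoderStages`, so the pipeline fits at most two encoders instead of one per string term.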