You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") that is not preserved in this plain-text rendering.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/03/12 04:52:40 UTC

[GitHub] [spark] zhengruifeng commented on a change in pull request #27882: [SPARK-31127][ML] Add abstract Selector

zhengruifeng commented on a change in pull request #27882: [SPARK-31127][ML] Add abstract Selector
URL: https://github.com/apache/spark/pull/27882#discussion_r391395961
 
 

 ##########
 File path: mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
 ##########
 @@ -153,102 +48,75 @@ private[feature] trait ChiSqSelectorParams extends Params
  */
 @Since("1.6.0")
 final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
-  extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable {
+  extends Selector[ChiSqSelectorModel] {
 
   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("chiSqSelector"))
 
   /** @group setParam */
   @Since("1.6.0")
-  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
+  override def setNumTopFeatures(value: Int): this.type = super.setNumTopFeatures(value)
 
   /** @group setParam */
   @Since("2.1.0")
-  def setPercentile(value: Double): this.type = set(percentile, value)
+  override def setPercentile(value: Double): this.type = super.setPercentile(value)
 
   /** @group setParam */
   @Since("2.1.0")
-  def setFpr(value: Double): this.type = set(fpr, value)
+  override def setFpr(value: Double): this.type = super.setFpr(value)
 
   /** @group setParam */
   @Since("2.2.0")
-  def setFdr(value: Double): this.type = set(fdr, value)
+  override def setFdr(value: Double): this.type = super.setFdr(value)
 
   /** @group setParam */
   @Since("2.2.0")
-  def setFwe(value: Double): this.type = set(fwe, value)
+  override def setFwe(value: Double): this.type = super.setFwe(value)
 
   /** @group setParam */
   @Since("2.1.0")
-  def setSelectorType(value: String): this.type = set(selectorType, value)
+  override def setSelectorType(value: String): this.type = super.setSelectorType(value)
 
   /** @group setParam */
   @Since("1.6.0")
-  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+  override def setFeaturesCol(value: String): this.type = super.setFeaturesCol(value)
 
   /** @group setParam */
   @Since("1.6.0")
-  def setOutputCol(value: String): this.type = set(outputCol, value)
+  override def setOutputCol(value: String): this.type = super.setOutputCol(value)
 
   /** @group setParam */
   @Since("1.6.0")
-  def setLabelCol(value: String): this.type = set(labelCol, value)
-
-  @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
-    transformSchema(dataset.schema, logging = true)
-    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
-      case Row(label: Double, features: Vector) =>
-        LabeledPoint(label, features)
-    }
+  override def setLabelCol(value: String): this.type = super.setLabelCol(value)
 
-    val testResult = ChiSquareTest.testChiSquare(dataset, getFeaturesCol, getLabelCol)
-      .zipWithIndex
-    val features = $(selectorType) match {
-      case "numTopFeatures" =>
-        testResult
-          .sortBy { case (res, _) => res.pValue }
-          .take(getNumTopFeatures)
-      case "percentile" =>
-        testResult
-          .sortBy { case (res, _) => res.pValue }
-          .take((testResult.length * getPercentile).toInt)
-      case "fpr" =>
-        testResult
-          .filter { case (res, _) => res.pValue < getFpr }
-      case "fdr" =>
-        // This uses the Benjamini-Hochberg procedure.
-        // https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure
-        val tempRes = testResult
-          .sortBy { case (res, _) => res.pValue }
-        val selected = tempRes
-          .zipWithIndex
-          .filter { case ((res, _), index) =>
-            res.pValue <= getFdr * (index + 1) / testResult.length }
-        if (selected.isEmpty) {
-          Array.empty[(SelectionTestResult, Int)]
-        } else {
-          val maxIndex = selected.map(_._2).max
-          tempRes.take(maxIndex + 1)
-        }
-      case "fwe" =>
-        testResult
-          .filter { case (res, _) => res.pValue < getFwe / testResult.length }
-      case errorType =>
-        throw new IllegalStateException(s"Unknown Selector Type: $errorType")
-    }
-    val indices = features.map { case (_, index) => index }
-    copyValues(new ChiSqSelectorModel(uid, indices.sorted)
-      .setParent(this))
+  /**
+   * get the SelectionTestResult for every feature against the label
+   */
+  @Since("3.1.0")
+  protected[this] override def getSelectionTestResult(dataset: Dataset[_]):
+  Array[SelectionTestResult] = {
+    SelectionTest.chiSquareTest(dataset, getFeaturesCol, getLabelCol)
   }
 
-  @Since("1.6.0")
-  override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
-    SchemaUtils.checkNumericType(schema, $(labelCol))
-    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+  /**
+   * Create a new instance of concrete SelectorModel.
+   * @param indices The indices of the selected features
+   * @param pValues The pValues of the selected features
+   * @param statistics The chi square statistic of the selected features
+   * @return A new SelectorModel instance
+   */
+  @Since("3.1.0")
+  protected[this] def createSelectorModel(
 
 Review comment:
   ditto

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org