Posted to commits@spark.apache.org by ml...@apache.org on 2017/11/09 14:35:14 UTC

spark git commit: [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns

Repository: spark
Updated Branches:
  refs/heads/master 6793a3dac -> 77f74539e


[SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns

## What changes were proposed in this pull request?

Currently, ML's `Bucketizer` can only bin a single column of continuous features. If a dataset has thousands of continuous columns that need to be binned, this results in thousands of ML stages, which is inefficient for query planning and execution.

We should have a bucketizer that can bin many columns at once. It needs to accept a list of arrays of split points, one per column to bin, and it makes the pipeline more efficient by replacing thousands of stages with a single one.

The approach in this patch is to add a new `MultipleBucketizerInterface` for this purpose; `Bucketizer` now extends this new interface.
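
For orientation, here is a minimal sketch of the new multi-column usage (the full examples are in the diff below; the column names, splits, and data here are illustrative only):

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

// Sketch only: column names, splits, and data are made up for illustration.
val spark = SparkSession.builder.appName("MultiColumnBucketizerSketch").getOrCreate()
import spark.implicits._

val df = Seq((0.1, 12.0), (-3.0, 25.0), (0.7, 8.0)).toDF("col1", "col2")

val bucketizer = new Bucketizer()
  .setInputCols(Array("col1", "col2"))
  .setOutputCols(Array("col1_binned", "col2_binned"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, 0.5, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 10.0, 20.0, Double.PositiveInfinity)))

// A single transformer bins both columns in one pass.
bucketizer.transform(df).show()
```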

### Performance

Benchmarked using the test dataset provided in JIRA SPARK-20392 (blockbuster.csv).

The ML pipeline includes two `StringIndexer`s and either one multi-column `Bucketizer` or 137 single-column `Bucketizer`s to bin 137 input columns with the same splits. We then measure the time to transform the dataset.

Multi-column Bucketizer: 3352 ms
137 single-column Bucketizers: 51512 ms
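
For reference, a rough sketch of how such a comparison could be wired up. This is not the actual benchmark code: the synthetic data, column names (`c0`..`c136`), splits, and row count are assumptions, and the two `StringIndexer` stages are omitted for brevity.

```scala
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.rand

// Synthetic stand-in for the JIRA dataset: 137 random numeric columns.
val spark = SparkSession.builder.appName("BucketizerBenchmarkSketch").getOrCreate()
val inputCols = (0 until 137).map(i => s"c$i").toArray
val df: DataFrame = inputCols.foldLeft(spark.range(100000).toDF("id")) {
  (d, c) => d.withColumn(c, rand())
}

val splits = Array(Double.NegativeInfinity, 0.25, 0.5, 0.75, Double.PositiveInfinity)

// One multi-column stage vs. 137 single-column stages, all using the same splits.
val multiStage = Array(new Bucketizer()
  .setInputCols(inputCols)
  .setOutputCols(inputCols.map(_ + "_bin"))
  .setSplitsArray(Array.fill(inputCols.length)(splits)))
val singleStages = inputCols.map { c =>
  new Bucketizer().setInputCol(c).setOutputCol(c + "_bin").setSplits(splits)
}

def timeIt(stages: Array[_ <: PipelineStage]): Long = {
  val start = System.nanoTime()
  new Pipeline().setStages(stages).fit(df).transform(df).count()
  (System.nanoTime() - start) / 1000000
}

println(s"multi-column Bucketizer: ${timeIt(multiStage)} ms")
println(s"137 Bucketizers:         ${timeIt(singleStages)} ms")
```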

## How was this patch tested?

Jenkins tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #17819 from viirya/SPARK-20542.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/77f74539
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/77f74539
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/77f74539

Branch: refs/heads/master
Commit: 77f74539ec7a445e24736029fb198b48ffd50ea9
Parents: 6793a3d
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Nov 9 16:35:06 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Thu Nov 9 16:35:06 2017 +0200

----------------------------------------------------------------------
 .../examples/ml/JavaBucketizerExample.java      |  41 ++++
 .../spark/examples/ml/BucketizerExample.scala   |  36 ++-
 .../apache/spark/ml/feature/Bucketizer.scala    | 122 ++++++++--
 .../org/apache/spark/ml/param/params.scala      |  39 +++
 .../ml/param/shared/SharedParamsCodeGen.scala   |   1 +
 .../spark/ml/param/shared/sharedParams.scala    |  17 ++
 .../spark/ml/feature/JavaBucketizerSuite.java   |  35 +++
 .../spark/ml/feature/BucketizerSuite.scala      | 239 ++++++++++++++++++-
 .../org/apache/spark/ml/param/ParamsSuite.scala |  38 ++-
 .../scala/org/apache/spark/sql/Dataset.scala    |  21 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   |  28 +++
 11 files changed, 592 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
index f009938..3e49bf0 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
@@ -33,6 +33,13 @@ import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 // $example off$
 
+/**
+ * An example for Bucketizer.
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaBucketizerExample
+ * </pre>
+ */
 public class JavaBucketizerExample {
   public static void main(String[] args) {
     SparkSession spark = SparkSession
@@ -68,6 +75,40 @@ public class JavaBucketizerExample {
     bucketedData.show();
     // $example off$
 
+    // $example on$
+    // Bucketize multiple columns at one pass.
+    double[][] splitsArray = {
+      {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY},
+      {Double.NEGATIVE_INFINITY, -0.3, 0.0, 0.3, Double.POSITIVE_INFINITY}
+    };
+
+    List<Row> data2 = Arrays.asList(
+      RowFactory.create(-999.9, -999.9),
+      RowFactory.create(-0.5, -0.2),
+      RowFactory.create(-0.3, -0.1),
+      RowFactory.create(0.0, 0.0),
+      RowFactory.create(0.2, 0.4),
+      RowFactory.create(999.9, 999.9)
+    );
+    StructType schema2 = new StructType(new StructField[]{
+      new StructField("features1", DataTypes.DoubleType, false, Metadata.empty()),
+      new StructField("features2", DataTypes.DoubleType, false, Metadata.empty())
+    });
+    Dataset<Row> dataFrame2 = spark.createDataFrame(data2, schema2);
+
+    Bucketizer bucketizer2 = new Bucketizer()
+      .setInputCols(new String[] {"features1", "features2"})
+      .setOutputCols(new String[] {"bucketedFeatures1", "bucketedFeatures2"})
+      .setSplitsArray(splitsArray);
+    // Transform original data into its bucket index.
+    Dataset<Row> bucketedData2 = bucketizer2.transform(dataFrame2);
+
+    System.out.println("Bucketizer output with [" +
+      (bucketizer2.getSplitsArray()[0].length-1) + ", " +
+      (bucketizer2.getSplitsArray()[1].length-1) + "] buckets for each input column");
+    bucketedData2.show();
+    // $example off$
+
     spark.stop();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala
index 04e4ecc..7e65f9c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala
@@ -22,7 +22,13 @@ package org.apache.spark.examples.ml
 import org.apache.spark.ml.feature.Bucketizer
 // $example off$
 import org.apache.spark.sql.SparkSession
-
+/**
+ * An example for Bucketizer.
+ * Run with
+ * {{{
+ * bin/run-example ml.BucketizerExample
+ * }}}
+ */
 object BucketizerExample {
   def main(args: Array[String]): Unit = {
     val spark = SparkSession
@@ -48,6 +54,34 @@ object BucketizerExample {
     bucketedData.show()
     // $example off$
 
+    // $example on$
+    val splitsArray = Array(
+      Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
+      Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity))
+
+    val data2 = Array(
+      (-999.9, -999.9),
+      (-0.5, -0.2),
+      (-0.3, -0.1),
+      (0.0, 0.0),
+      (0.2, 0.4),
+      (999.9, 999.9))
+    val dataFrame2 = spark.createDataFrame(data2).toDF("features1", "features2")
+
+    val bucketizer2 = new Bucketizer()
+      .setInputCols(Array("features1", "features2"))
+      .setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2"))
+      .setSplitsArray(splitsArray)
+
+    // Transform original data into its bucket index.
+    val bucketedData2 = bucketizer2.transform(dataFrame2)
+
+    println(s"Bucketizer output with [" +
+      s"${bucketizer2.getSplitsArray(0).length-1}, " +
+      s"${bucketizer2.getSplitsArray(1).length-1}] buckets for each input column")
+    bucketedData2.show()
+    // $example off$
+
     spark.stop()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 6a11a75..e07f2a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -32,12 +32,16 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
- * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+ * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
+ * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
+ * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
+ * only used for single column usage, and `splitsArray` is for multiple columns.
  */
 @Since("1.4.0")
 final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
-    with DefaultParamsWritable {
+    with HasInputCols with HasOutputCols with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("bucketizer"))
@@ -81,7 +85,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   /**
    * Param for how to handle invalid entries. Options are 'skip' (filter out rows with
    * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
-   * additional bucket).
+   * additional bucket). Note that in the multiple column case, the invalid handling is applied
+   * to all columns. That said for 'error' it will throw an error if any invalids are found in
+   * any column, for 'skip' it will skip rows with any invalids in any columns, etc.
    * Default: "error"
    * @group param
    */
@@ -96,9 +102,59 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
   setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
 
+  /**
+   * Parameter for specifying multiple splits parameters. Each element in this array can be used to
+   * map continuous features into buckets.
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
+    "The array of split points for mapping continuous features into buckets for multiple " +
+      "columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " +
+      "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " +
+      "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
+      "explicitly provided to cover all Double values; otherwise, values outside the splits " +
+      "specified will be treated as errors.",
+    Bucketizer.checkSplitsArray)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getSplitsArray: Array[Array[Double]] = $(splitsArray)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  /**
+   * Determines whether this `Bucketizer` is going to map multiple columns. If and only if
+   * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified
+   * by `inputCol`. A warning will be printed if both are set.
+   */
+  private[feature] def isBucketizeMultipleColumns(): Boolean = {
+    if (isSet(inputCols) && isSet(inputCol)) {
+      logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
+        "`Bucketizer` only map one column specified by `inputCol`")
+      false
+    } else if (isSet(inputCols)) {
+      true
+    } else {
+      false
+    }
+  }
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema)
+    val transformedSchema = transformSchema(dataset.schema)
+
     val (filteredDataset, keepInvalid) = {
       if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
         // "skip" NaN option is set, will filter out NaN values in the dataset
@@ -108,26 +164,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       }
     }
 
-    val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
-      Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
-    }.withName("bucketizer")
+    val seqOfSplits = if (isBucketizeMultipleColumns()) {
+      $(splitsArray).toSeq
+    } else {
+      Seq($(splits))
+    }
 
-    val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
-    val newField = prepOutputField(filteredDataset.schema)
-    filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
+    val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) =>
+      udf { (feature: Double) =>
+        Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
+      }.withName(s"bucketizer_$idx")
+    }
+
+    val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
+      ($(inputCols).toSeq, $(outputCols).toSeq)
+    } else {
+      (Seq($(inputCol)), Seq($(outputCol)))
+    }
+    val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
+      bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
+    }
+    val metadata = outputColumns.map { col =>
+      transformedSchema(col).metadata
+    }
+    filteredDataset.withColumns(outputColumns, newCols, metadata)
   }
 
-  private def prepOutputField(schema: StructType): StructField = {
-    val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
-    val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
+  private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
+    val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray
+    val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
       values = Some(buckets))
     attr.toStructField()
   }
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkNumericType(schema, $(inputCol))
-    SchemaUtils.appendColumn(schema, prepOutputField(schema))
+    if (isBucketizeMultipleColumns()) {
+      var transformedSchema = schema
+      $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
+        SchemaUtils.checkNumericType(transformedSchema, inputCol)
+        transformedSchema = SchemaUtils.appendColumn(transformedSchema,
+          prepOutputField($(splitsArray)(idx), outputCol))
+      }
+      transformedSchema
+    } else {
+      SchemaUtils.checkNumericType(schema, $(inputCol))
+      SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
+    }
   }
 
   @Since("1.4.1")
@@ -164,6 +247,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
   }
 
   /**
+   * Check each splits in the splits array.
+   */
+  private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
+    splitsArray.forall(checkSplits(_))
+  }
+
+  /**
    * Binary searching in several buckets to place each data point.
    * @param splits array of split points
    * @param feature data point

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index ac68b82..8985f2a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -492,6 +492,45 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
 
 /**
  * :: DeveloperApi ::
+ * Specialized version of `Param[Array[Array[Double]]]` for Java.
+ */
+@DeveloperApi
+class DoubleArrayArrayParam(
+    parent: Params,
+    name: String,
+    doc: String,
+    isValid: Array[Array[Double]] => Boolean)
+  extends Param[Array[Array[Double]]](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
+
+  /** Creates a param pair with a `java.util.List` of values (for Java and Python). */
+  def w(value: java.util.List[java.util.List[java.lang.Double]]): ParamPair[Array[Array[Double]]] =
+    w(value.asScala.map(_.asScala.map(_.asInstanceOf[Double]).toArray).toArray)
+
+  override def jsonEncode(value: Array[Array[Double]]): String = {
+    import org.json4s.JsonDSL._
+    compact(render(value.toSeq.map(_.toSeq.map(DoubleParam.jValueEncode))))
+  }
+
+  override def jsonDecode(json: String): Array[Array[Double]] = {
+    parse(json) match {
+      case JArray(values) =>
+        values.map {
+          case JArray(values) =>
+            values.map(DoubleParam.jValueDecode).toArray
+          case _ =>
+            throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
+        }.toArray
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
  * Specialized version of `Param[Array[Int]]` for Java.
  */
 @DeveloperApi

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index a932d28..20a1db8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
+      ParamDesc[Array[String]]("outputCols", "output column names"),
       ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " +
         "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " +
         "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index e6bdf52..0d5fb28 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -258,6 +258,23 @@ trait HasOutputCol extends Params {
 }
 
 /**
+ * Trait for shared param outputCols. This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasOutputCols extends Params {
+
+  /**
+   * Param for output column names.
+   * @group param
+   */
+  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names")
+
+  /** @group getParam */
+  final def getOutputCols: Array[String] = $(outputCols)
+}
+
+/**
  * Trait for shared param checkpointInterval. This trait may be changed or
  * removed between minor versions.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index 8763938..e65265b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -61,4 +61,39 @@ public class JavaBucketizerSuite extends SharedSparkSession {
       Assert.assertTrue((index >= 0) && (index <= 1));
     }
   }
+
+  @Test
+  public void bucketizerMultipleColumnsTest() {
+    double[][] splitsArray = {
+      {-0.5, 0.0, 0.5},
+      {-0.5, 0.0, 0.2, 0.5}
+    };
+
+    StructType schema = new StructType(new StructField[]{
+      new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()),
+      new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()),
+    });
+    Dataset<Row> dataset = spark.createDataFrame(
+      Arrays.asList(
+        RowFactory.create(-0.5, -0.5),
+        RowFactory.create(-0.3, -0.3),
+        RowFactory.create(0.0, 0.0),
+        RowFactory.create(0.2, 0.3)),
+      schema);
+
+    Bucketizer bucketizer = new Bucketizer()
+      .setInputCols(new String[] {"feature1", "feature2"})
+      .setOutputCols(new String[] {"result1", "result2"})
+      .setSplitsArray(splitsArray);
+
+    List<Row> result = bucketizer.transform(dataset).select("result1", "result2").collectAsList();
+
+    for (Row r : result) {
+      double index1 = r.getDouble(0);
+      Assert.assertTrue((index1 >= 0) && (index1 <= 1));
+
+      double index2 = r.getDouble(1);
+      Assert.assertTrue((index2 >= 0) && (index2 <= 2));
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 420fb17..748dbd1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.ml.feature
 import scala.util.Random
 
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
@@ -187,6 +188,220 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       }
     }
   }
+
+  test("multiple columns: Bucket continuous features, without -inf,inf") {
+    // Check a set of valid feature values.
+    val splits = Array(Array(-0.5, 0.0, 0.5), Array(-0.1, 0.3, 0.5))
+    val validData1 = Array(-0.5, -0.3, 0.0, 0.2)
+    val validData2 = Array(0.5, 0.3, 0.0, -0.1)
+    val expectedBuckets1 = Array(0.0, 0.0, 1.0, 1.0)
+    val expectedBuckets2 = Array(1.0, 1.0, 0.0, 0.0)
+
+    val data = (0 until validData1.length).map { idx =>
+      (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
+    }
+    val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
+
+    val bucketizer1: Bucketizer = new Bucketizer()
+      .setInputCols(Array("feature1", "feature2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setSplitsArray(splits)
+
+    assert(bucketizer1.isBucketizeMultipleColumns())
+
+    bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
+    BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame),
+      Seq("result1", "result2"),
+      Seq("expected1", "expected2"))
+
+    // Check for exceptions when using a set of invalid feature values.
+    val invalidData1 = Array(-0.9) ++ validData1
+    val invalidData2 = Array(0.51) ++ validData1
+    val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx")
+
+    val bucketizer2: Bucketizer = new Bucketizer()
+      .setInputCols(Array("feature"))
+      .setOutputCols(Array("result"))
+      .setSplitsArray(Array(splits(0)))
+
+    assert(bucketizer2.isBucketizeMultipleColumns())
+
+    withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer2.transform(badDF1).collect()
+      }
+    }
+    val badDF2 = invalidData2.zipWithIndex.toSeq.toDF("feature", "idx")
+    withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer2.transform(badDF2).collect()
+      }
+    }
+  }
+
+  test("multiple columns: Bucket continuous features, with -inf,inf") {
+    val splits = Array(
+      Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
+      Array(Double.NegativeInfinity, -0.3, 0.2, 0.5, Double.PositiveInfinity))
+
+    val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
+    val validData2 = Array(-0.1, -0.5, -0.2, 0.0, 0.1, 0.3, 0.5)
+    val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
+    val expectedBuckets2 = Array(1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 3.0)
+
+    val data = (0 until validData1.length).map { idx =>
+      (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
+    }
+    val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCols(Array("feature1", "feature2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setSplitsArray(splits)
+
+    assert(bucketizer.isBucketizeMultipleColumns())
+
+    BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
+      Seq("result1", "result2"),
+      Seq("expected1", "expected2"))
+  }
+
+  test("multiple columns: Bucket continuous features, with NaN data but non-NaN splits") {
+    val splits = Array(
+      Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
+      Array(Double.NegativeInfinity, -0.1, 0.2, 0.6, Double.PositiveInfinity))
+
+    val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN)
+    val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, 0.8, Double.NaN, Double.NaN)
+    val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0)
+    val expectedBuckets2 = Array(2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 4.0)
+
+    val data = (0 until validData1.length).map { idx =>
+      (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
+    }
+    val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCols(Array("feature1", "feature2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setSplitsArray(splits)
+
+    assert(bucketizer.isBucketizeMultipleColumns())
+
+    bucketizer.setHandleInvalid("keep")
+    BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
+      Seq("result1", "result2"),
+      Seq("expected1", "expected2"))
+
+    bucketizer.setHandleInvalid("skip")
+    val skipResults1: Array[Double] = bucketizer.transform(dataFrame)
+      .select("result1").as[Double].collect()
+    assert(skipResults1.length === 7)
+    assert(skipResults1.forall(_ !== 4.0))
+
+    val skipResults2: Array[Double] = bucketizer.transform(dataFrame)
+      .select("result2").as[Double].collect()
+    assert(skipResults2.length === 7)
+    assert(skipResults2.forall(_ !== 4.0))
+
+    bucketizer.setHandleInvalid("error")
+    withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") {
+      intercept[SparkException] {
+        bucketizer.transform(dataFrame).collect()
+      }
+    }
+  }
+
+  test("multiple columns: Bucket continuous features, with NaN splits") {
+    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN)
+    withClue("Invalid NaN split was not caught during Bucketizer initialization") {
+      intercept[IllegalArgumentException] {
+        new Bucketizer().setSplitsArray(Array(splits))
+      }
+    }
+  }
+
+  test("multiple columns: read/write") {
+    val t = new Bucketizer()
+      .setInputCols(Array("myInputCol"))
+      .setOutputCols(Array("myOutputCol"))
+      .setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
+    assert(t.isBucketizeMultipleColumns())
+    testDefaultReadWrite(t)
+  }
+
+  test("Bucketizer in a pipeline") {
+    val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0))
+      .toDF("feature1", "feature2", "expected1", "expected2")
+
+    val bucket = new Bucketizer()
+      .setInputCols(Array("feature1", "feature2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
+
+    assert(bucket.isBucketizeMultipleColumns())
+
+    val pl = new Pipeline()
+      .setStages(Array(bucket))
+      .fit(df)
+    pl.transform(df).select("result1", "expected1", "result2", "expected2")
+
+    BucketizerSuite.checkBucketResults(pl.transform(df),
+      Seq("result1", "result2"), Seq("expected1", "expected2"))
+  }
+
+  test("Compare single/multiple column(s) Bucketizer in pipeline") {
+    val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0))
+      .toDF("feature1", "feature2", "expected1", "expected2")
+
+    val multiColsBucket = new Bucketizer()
+      .setInputCols(Array("feature1", "feature2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
+
+    val plForMultiCols = new Pipeline()
+      .setStages(Array(multiColsBucket))
+      .fit(df)
+
+    val bucketForCol1 = new Bucketizer()
+      .setInputCol("feature1")
+      .setOutputCol("result1")
+      .setSplits(Array(-0.5, 0.0, 0.5))
+    val bucketForCol2 = new Bucketizer()
+      .setInputCol("feature2")
+      .setOutputCol("result2")
+      .setSplits(Array(-0.5, 0.0, 0.5))
+
+    val plForSingleCol = new Pipeline()
+      .setStages(Array(bucketForCol1, bucketForCol2))
+      .fit(df)
+
+    val resultForSingleCol = plForSingleCol.transform(df)
+      .select("result1", "expected1", "result2", "expected2")
+      .collect()
+    val resultForMultiCols = plForMultiCols.transform(df)
+      .select("result1", "expected1", "result2", "expected2")
+      .collect()
+
+    resultForSingleCol.zip(resultForMultiCols).foreach {
+        case (rowForSingle, rowForMultiCols) =>
+          assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
+            rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
+            rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2) &&
+            rowForSingle.getDouble(3) == rowForMultiCols.getDouble(3))
+    }
+  }
+
+  test("Both inputCol and inputCols are set") {
+    val bucket = new Bucketizer()
+      .setInputCol("feature1")
+      .setOutputCol("result")
+      .setSplits(Array(-0.5, 0.0, 0.5))
+      .setInputCols(Array("feature1", "feature2"))
+
+    // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`.
+    assert(bucket.isBucketizeMultipleColumns() == false)
+  }
 }
 
 private object BucketizerSuite extends SparkFunSuite {
@@ -220,4 +435,26 @@ private object BucketizerSuite extends SparkFunSuite {
       i += 1
     }
   }
+
+  /** Checks if bucketized results match expected ones. */
+  def checkBucketResults(
+      bucketResult: DataFrame,
+      resultColumns: Seq[String],
+      expectedColumns: Seq[String]): Unit = {
+    assert(resultColumns.length == expectedColumns.length,
+      s"Given ${resultColumns.length} result columns doesn't match " +
+        s"${expectedColumns.length} expected columns.")
+    assert(resultColumns.length > 0, "At least one result and expected columns are needed.")
+
+    val allColumns = resultColumns ++ expectedColumns
+    bucketResult.select(allColumns.head, allColumns.tail: _*).collect().foreach {
+      case row =>
+        for (idx <- 0 until row.length / 2) {
+          val result = row.getDouble(idx)
+          val expected = row.getDouble(idx + row.length / 2)
+          assert(result === expected, "The feature value is not correct after bucketing. " +
+            s"Expected $expected but found $result.")
+        }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 78a33e0..85198ad 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -121,10 +121,10 @@ class ParamsSuite extends SparkFunSuite {
     { // DoubleArrayParam
       val param = new DoubleArrayParam(dummy, "name", "doc")
       val values: Seq[Array[Double]] = Seq(
-         Array(),
-         Array(1.0),
-         Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
-           Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
+        Array(),
+        Array(1.0),
+        Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+          Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
       for (value <- values) {
         val json = param.jsonEncode(value)
         val decoded = param.jsonDecode(json)
@@ -139,6 +139,36 @@ class ParamsSuite extends SparkFunSuite {
       }
     }
 
+    { // DoubleArrayArrayParam
+      val param = new DoubleArrayArrayParam(dummy, "name", "doc")
+      val values: Seq[Array[Array[Double]]] = Seq(
+        Array(Array()),
+        Array(Array(1.0)),
+        Array(Array(1.0), Array(2.0)),
+        Array(
+          Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+            Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity),
+          Array(Double.MaxValue, Double.PositiveInfinity, Double.MinPositiveValue, 1.0,
+            Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0)
+        ))
+
+      for (value <- values) {
+        val json = param.jsonEncode(value)
+        val decoded = param.jsonDecode(json)
+        assert(decoded.length === value.length)
+        decoded.zip(value).foreach { case (actualArray, expectedArray) =>
+          assert(actualArray.length === expectedArray.length)
+          actualArray.zip(expectedArray).foreach { case (actual, expected) =>
+            if (expected.isNaN) {
+              assert(actual.isNaN)
+            } else {
+              assert(actual === expected)
+            }
+          }
+        }
+      }
+    }
+
     { // StringArrayParam
       val param = new StringArrayParam(dummy, "name", "doc")
       val values: Seq[Array[String]] = Seq(

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bd99ec5..5eb2aff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2135,13 +2135,28 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Returns a new Dataset by adding a column with metadata.
+   * Returns a new Dataset by adding columns with metadata.
    */
-  private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
-    withColumn(colName, col.as(colName, metadata))
+  private[spark] def withColumns(
+      colNames: Seq[String],
+      cols: Seq[Column],
+      metadata: Seq[Metadata]): DataFrame = {
+    require(colNames.size == metadata.size,
+      s"The size of column names: ${colNames.size} isn't equal to " +
+        s"the size of metadata elements: ${metadata.size}")
+    val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) =>
+      col.as(colName, metadata)
+    }
+    withColumns(colNames, newCols)
   }
 
   /**
+   * Returns a new Dataset by adding a column with metadata.
+   */
+  private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame =
+    withColumns(Seq(colName), Seq(col), Seq(metadata))
+
+  /**
    * Returns a new Dataset with a column renamed.
    * This is a no-op if schema doesn't contain existingName.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/77f74539/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 17c88b0..31bfa77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -686,6 +686,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("withColumns: given metadata") {
+    def buildMetadata(num: Int): Seq[Metadata] = {
+      (0 until num).map { n =>
+        val builder = new MetadataBuilder
+        builder.putLong("key", n.toLong)
+        builder.build()
+      }
+    }
+
+    val df = testData.toDF().withColumns(
+      Seq("newCol1", "newCol2"),
+      Seq(col("key") + 1, col("key") + 2),
+      buildMetadata(2))
+
+    df.select("newCol1", "newCol2").schema.zipWithIndex.foreach { case (col, idx) =>
+      assert(col.metadata.getLong("key").toInt === idx)
+    }
+
+    val err = intercept[IllegalArgumentException] {
+      testData.toDF().withColumns(
+        Seq("newCol1", "newCol2"),
+        Seq(col("key") + 1, col("key") + 2),
+        buildMetadata(1))
+    }
+    assert(err.getMessage.contains(
+      "The size of column names: 2 isn't equal to the size of metadata elements: 1"))
+  }
+
   test("replace column using withColumn") {
     val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
     val df3 = df2.withColumn("x", df2("x") + 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org