You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/02/08 01:33:27 UTC

spark git commit: [SPARK-19397][SQL] Make option names of LIBSVM and TEXT case insensitive

Repository: spark
Updated Branches:
  refs/heads/master 8df444403 -> e33aaa2ac


[SPARK-19397][SQL] Make option names of LIBSVM and TEXT case insensitive

### What changes were proposed in this pull request?
Prior to Spark 2.1, option names were case sensitive for all formats. Since Spark 2.1, option key names have been case insensitive for every format except `Text` and `LibSVM`. This PR makes the option names of those two formats case insensitive as well.

Also, add a check that the value of the `vectorType` option is legal for `LibSVM` (`sparse` or `dense`); any other value now raises an `IllegalArgumentException`.

### How was this patch tested?
Added test cases to `LibSVMRelationSuite` and `TextSuite`.

Author: gatorsmile <ga...@gmail.com>

Closes #16737 from gatorsmile/libSVMTextOptions.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e33aaa2a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e33aaa2a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e33aaa2a

Branch: refs/heads/master
Commit: e33aaa2ac53a6e17e160e4e63821450b3609033b
Parents: 8df4444
Author: gatorsmile <ga...@gmail.com>
Authored: Wed Feb 8 09:33:18 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Feb 8 09:33:18 2017 +0800

----------------------------------------------------------------------
 .../spark/ml/source/libsvm/LibSVMOptions.scala  | 51 ++++++++++++++++++++
 .../spark/ml/source/libsvm/LibSVMRelation.scala | 14 +++---
 .../ml/source/libsvm/LibSVMRelationSuite.scala  | 14 ++++++
 .../datasources/text/TextFileFormat.scala       |  5 +-
 .../datasources/text/TextOptions.scala          | 40 +++++++++++++++
 .../execution/datasources/text/TextSuite.scala  | 22 ++++++++-
 6 files changed, 136 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e33aaa2a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala
new file mode 100644
index 0000000..e3c5b4d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMOptions.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.source.libsvm
+
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+
+/**
+ * Options for the LibSVM data source.
+ */
+private[libsvm] class LibSVMOptions(@transient private val parameters: CaseInsensitiveMap)
+  extends Serializable {
+
+  import LibSVMOptions._
+
+  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
+
+  /**
+   * Number of features. If unspecified or nonpositive, the number of features will be determined
+   * automatically at the cost of one additional pass.
+   */
+  val numFeatures = parameters.get(NUM_FEATURES).map(_.toInt).filter(_ > 0)
+
+  val isSparse = parameters.getOrElse(VECTOR_TYPE, SPARSE_VECTOR_TYPE) match {
+    case SPARSE_VECTOR_TYPE => true
+    case DENSE_VECTOR_TYPE => false
+    case o => throw new IllegalArgumentException(s"Invalid value `$o` for parameter " +
+      s"`$VECTOR_TYPE`. Expected types are `sparse` and `dense`.")
+  }
+}
+
+private[libsvm] object LibSVMOptions {
+  val NUM_FEATURES = "numFeatures"
+  val VECTOR_TYPE = "vectorType"
+  val DENSE_VECTOR_TYPE = "dense"
+  val SPARSE_VECTOR_TYPE = "sparse"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e33aaa2a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 100b4bb..f68847a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -77,7 +77,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       dataSchema.size != 2 ||
         !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
         !dataSchema(1).dataType.sameType(new VectorUDT()) ||
-        !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0)
+        !(dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
     ) {
       throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
@@ -87,7 +87,8 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse {
+    val libSVMOptions = new LibSVMOptions(options)
+    val numFeatures: Int = libSVMOptions.numFeatures.getOrElse {
       // Infers number of features if the user doesn't specify (a valid) one.
       val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
       val path = if (dataFiles.length == 1) {
@@ -104,7 +105,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
     }
 
     val featuresMetadata = new MetadataBuilder()
-      .putLong("numFeatures", numFeatures)
+      .putLong(LibSVMOptions.NUM_FEATURES, numFeatures)
       .build()
 
     Some(
@@ -142,10 +143,11 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
     verifySchema(dataSchema)
-    val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt
+    val numFeatures = dataSchema("features").metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt
     assert(numFeatures > 0)
 
-    val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
+    val libSVMOptions = new LibSVMOptions(options)
+    val isSparse = libSVMOptions.isSparse
 
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
@@ -173,7 +175,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
 
       points.map { pt =>
-        val features = if (sparse) pt.features.toSparse else pt.features.toDense
+        val features = if (isSparse) pt.features.toSparse else pt.features.toDense
         requiredColumns(converter.toRow(Row(pt.label, features)))
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e33aaa2a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 478a83f..e164d27 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -77,6 +77,14 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
   }
 
+  test("illegal vector types") {
+    val e = intercept[IllegalArgumentException] {
+      spark.read.format("libsvm").options(Map("VectorType" -> "sparser")).load(path)
+    }.getMessage
+    assert(e.contains("Invalid value `sparser` for parameter `vectorType`. Expected " +
+      "types are `sparse` and `dense`."))
+  }
+
   test("select a vector with specifying the longer dimension") {
     val df = spark.read.option("numFeatures", "100").format("libsvm")
       .load(path)
@@ -85,6 +93,12 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }
 
+  test("case insensitive option") {
+    val df = spark.read.option("NuMfEaTuReS", "100").format("libsvm").load(path)
+    assert(df.first().getAs[SparseVector](1) ==
+      Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+  }
+
   test("write libsvm data and read it again") {
     val df = spark.read.format("libsvm").load(path)
     val tempDir2 = new File(tempDir, "read_write_test")

http://git-wip-us.apache.org/repos/asf/spark/blob/e33aaa2a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 6f6e301..d069044 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -65,9 +65,10 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
       dataSchema: StructType): OutputWriterFactory = {
     verifySchema(dataSchema)
 
+    val textOptions = new TextOptions(options)
     val conf = job.getConfiguration
-    val compressionCodec = options.get("compression").map(CompressionCodecs.getCodecClassName)
-    compressionCodec.foreach { codec =>
+
+    textOptions.compressionCodec.foreach { codec =>
       CompressionCodecs.setCodecConfiguration(conf, codec)
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e33aaa2a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
new file mode 100644
index 0000000..8cad984
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.text
+
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs}
+
+/**
+ * Options for the Text data source.
+ */
+private[text] class TextOptions(@transient private val parameters: CaseInsensitiveMap)
+  extends Serializable {
+
+  import TextOptions._
+
+  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
+
+  /**
+   * Compression codec to use.
+   */
+  val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
+}
+
+private[text] object TextOptions {
+  val COMPRESSION = "compression"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e33aaa2a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index d11c2ac..cb7393c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -115,8 +115,7 @@ class TextSuite extends QueryTest with SharedSQLContext {
     )
     withTempDir { dir =>
       val testDf = spark.read.text(testFile)
-      val tempDir = Utils.createTempDir()
-      val tempDirPath = tempDir.getAbsolutePath
+      val tempDirPath = dir.getAbsolutePath
       testDf.write.option("compression", "none")
         .options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath)
       val compressedFiles = new File(tempDirPath).listFiles()
@@ -125,6 +124,25 @@ class TextSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("case insensitive option") {
+    val extraOptions = Map[String, String](
+      "mApReDuCe.output.fileoutputformat.compress" -> "true",
+      "mApReDuCe.output.fileoutputformat.compress.type" -> CompressionType.BLOCK.toString,
+      "mApReDuCe.map.output.compress" -> "true",
+      "mApReDuCe.output.fileoutputformat.compress.codec" -> classOf[GzipCodec].getName,
+      "mApReDuCe.map.output.compress.codec" -> classOf[GzipCodec].getName
+    )
+    withTempDir { dir =>
+      val testDf = spark.read.text(testFile)
+      val tempDirPath = dir.getAbsolutePath
+      testDf.write.option("CoMpReSsIoN", "none")
+        .options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath)
+      val compressedFiles = new File(tempDirPath).listFiles()
+      assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz")))
+      verifyFrame(spark.read.options(extraOptions).text(tempDirPath))
+    }
+  }
+
   test("SPARK-14343: select partitioning column") {
     withTempPath { dir =>
       val path = dir.getCanonicalPath


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org