Posted to commits@spark.apache.org by we...@apache.org on 2017/07/07 04:24:09 UTC

spark git commit: [SPARK-21326][SPARK-21066][ML] Use TextFileFormat in LibSVMFileFormat and allow multiple input paths for determining numFeatures

Repository: spark
Updated Branches:
  refs/heads/master e5bb26174 -> d451b7f43


[SPARK-21326][SPARK-21066][ML] Use TextFileFormat in LibSVMFileFormat and allow multiple input paths for determining numFeatures

## What changes were proposed in this pull request?

This is related to [SPARK-19918](https://issues.apache.org/jira/browse/SPARK-19918) and [SPARK-18362](https://issues.apache.org/jira/browse/SPARK-18362).

This PR proposes to use `TextFileFormat` and to allow multiple input paths (with a warning) when determining the number of features in the LibSVM data source via an extra scan.
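For context, a minimal sketch of the user-facing call this touches, assuming an active `spark` session and hypothetical part-file paths:

```scala
// With 'numFeatures' given, the extra schema-inference scan is skipped.
val withHint = spark.read.format("libsvm")
  .option("numFeatures", "6")
  .load("data/part-00000", "data/part-00001")

// Without it, the number of features is now inferred by scanning all input
// paths (previously only a single path was supported), with a warning logged.
val inferred = spark.read.format("libsvm")
  .load("data/part-00000", "data/part-00001")
```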

There are three points here:

- The main advantage of this change should be removing the file-listing bottleneck on the driver side.

- Another advantage comes from using `FileScanRDD`. For example, the `spark.sql.files.ignoreCorruptFiles` option should now apply when determining the number of features (see the sketch after this list).

- We can unify the schema inference code path across text-based data sources. This is also a preparation for [SPARK-21289](https://issues.apache.org/jira/browse/SPARK-21289).
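A minimal sketch of the second point, assuming an active `spark` session and a hypothetical input directory; whether the option actually applies to this scan is the expectation stated above, not something verified here:

```scala
// Sketch only: because the feature-count scan now goes through FileScanRDD,
// this existing SQL option is expected to apply to it as well.
spark.conf.set("spark.sql.files.ignoreCorruptFiles", "true")

// Corrupt part files under the (hypothetical) directory should then be
// skipped during the numFeatures-inference scan instead of failing the job.
val df = spark.read.format("libsvm").load("data/")
```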

## How was this patch tested?

Unit tests in `LibSVMRelationSuite`.

Closes #18288

Author: hyukjinkwon <gu...@gmail.com>

Closes #18556 from HyukjinKwon/libsvm-schema.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d451b7f4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d451b7f4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d451b7f4

Branch: refs/heads/master
Commit: d451b7f43d559aa1efd7ac3d1cbec5249f3a7a24
Parents: e5bb261
Author: hyukjinkwon <gu...@gmail.com>
Authored: Fri Jul 7 12:24:03 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Jul 7 12:24:03 2017 +0800

----------------------------------------------------------------------
 .../spark/ml/source/libsvm/LibSVMRelation.scala | 26 ++++++++++----------
 .../org/apache/spark/mllib/util/MLUtils.scala   | 25 +++++++++++++++++--
 .../ml/source/libsvm/LibSVMRelationSuite.scala  | 17 ++++++++++---
 3 files changed, 49 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d451b7f4/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index f68847a..dec1183 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.TaskContext
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
@@ -66,7 +67,10 @@ private[libsvm] class LibSVMOutputWriter(
 
 /** @see [[LibSVMDataSource]] for public documentation. */
 // If this is moved or renamed, please update DataSource's backwardCompatibilityMap.
-private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
+private[libsvm] class LibSVMFileFormat
+  extends TextBasedFileFormat
+  with DataSourceRegister
+  with Logging {
 
   override def shortName(): String = "libsvm"
 
@@ -89,18 +93,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       files: Seq[FileStatus]): Option[StructType] = {
     val libSVMOptions = new LibSVMOptions(options)
     val numFeatures: Int = libSVMOptions.numFeatures.getOrElse {
-      // Infers number of features if the user doesn't specify (a valid) one.
-      val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
-      val path = if (dataFiles.length == 1) {
-        dataFiles.head.getPath.toUri.toString
-      } else if (dataFiles.isEmpty) {
-        throw new IOException("No input path specified for libsvm data")
-      } else {
-        throw new IOException("Multiple input paths are not supported for libsvm data.")
-      }
-
-      val sc = sparkSession.sparkContext
-      val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
+      require(files.nonEmpty, "No input path specified for libsvm data")
+      logWarning(
+        "'numFeatures' option not specified, determining the number of features by going " +
+        "through the input. If you know the number in advance, please specify it via " +
+        "'numFeatures' option to avoid the extra scan.")
+
+      val paths = files.map(_.getPath.toUri.toString)
+      val parsed = MLUtils.parseLibSVMFile(sparkSession, paths)
       MLUtils.computeNumFeatures(parsed)
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d451b7f4/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 4fdad05..14af8b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -28,8 +28,10 @@ import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
-import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.functions._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.random.BernoulliCellSampler
 
@@ -102,6 +104,25 @@ object MLUtils extends Logging {
       .map(parseLibSVMRecord)
   }
 
+  private[spark] def parseLibSVMFile(
+      sparkSession: SparkSession, paths: Seq[String]): RDD[(Double, Array[Int], Array[Double])] = {
+    val lines = sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        className = classOf[TextFileFormat].getName
+      ).resolveRelation(checkFilesExist = false))
+      .select("value")
+
+    import lines.sqlContext.implicits._
+
+    lines.select(trim($"value").as("line"))
+      .filter(not((length($"line") === 0).or($"line".startsWith("#"))))
+      .as[String]
+      .rdd
+      .map(MLUtils.parseLibSVMRecord)
+  }
+
   private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
     val items = line.split(' ')
     val label = items.head.toDouble
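
For readers who want to reproduce the new code path without the internal `DataSource` API (used above so that `checkFilesExist` can be disabled), a rough standalone equivalent of the trim/filter pipeline, with hypothetical input paths:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, length, not, trim}

val spark = SparkSession.builder().appName("libsvm-sketch").getOrCreate()
import spark.implicits._

val paths = Seq("data/part-00000", "data/part-00001")  // hypothetical inputs

// Same cleanup as parseLibSVMFile above: trim each line, then drop blank
// lines and '#' comments before record parsing.
val lines = spark.read.text(paths: _*)
  .select(trim(col("value")).as("line"))
  .filter(not((length(col("line")) === 0).or(col("line").startsWith("#"))))
  .as[String]

// Each record is "<label> <i1>:<v1> <i2>:<v2> ..." with 1-based indices;
// numFeatures is the maximum index seen across all records.
val numFeatures = lines.rdd
  .map(_.split(' ').tail.map(_.split(':')(0).toInt).foldLeft(0)(math.max))
  .reduce(math.max)
```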

http://git-wip-us.apache.org/repos/asf/spark/blob/d451b7f4/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index e164d27..a67e49d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -35,15 +35,22 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val lines =
+    val lines0 =
       """
         |1 1:1.0 3:2.0 5:3.0
         |0
+      """.stripMargin
+    val lines1 =
+      """
         |0 2:4.0 4:5.0 6:6.0
       """.stripMargin
     val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
-    val file = new File(dir, "part-00000")
-    Files.write(lines, file, StandardCharsets.UTF_8)
+    val succ = new File(dir, "_SUCCESS")
+    val file0 = new File(dir, "part-00000")
+    val file1 = new File(dir, "part-00001")
+    Files.write("", succ, StandardCharsets.UTF_8)
+    Files.write(lines0, file0, StandardCharsets.UTF_8)
+    Files.write(lines1, file1, StandardCharsets.UTF_8)
     path = dir.toURI.toString
   }
 
@@ -145,7 +152,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("create libsvmTable table without schema and path") {
     try {
-      val e = intercept[IOException](spark.sql("CREATE TABLE libsvmTable USING libsvm"))
+      val e = intercept[IllegalArgumentException] {
+        spark.sql("CREATE TABLE libsvmTable USING libsvm")
+      }
       assert(e.getMessage.contains("No input path specified for libsvm data"))
     } finally {
       spark.sql("DROP TABLE IF EXISTS libsvmTable")

