Posted to commits@spark.apache.org by rx...@apache.org on 2016/03/16 07:31:49 UTC

spark git commit: [SPARK-13899][SQL] Produce InternalRow instead of external Row at CSV data source

Repository: spark
Updated Branches:
  refs/heads/master 3c578c594 -> 92024797a


[SPARK-13899][SQL] Produce InternalRow instead of external Row at CSV data source

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13899

This PR makes the CSV data source produce `InternalRow` instead of `Row`.

Basically, this resembles the JSON data source: it uses the same code for casting.
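
For context, a minimal sketch of the difference between the two representations (assuming Spark's Catalyst internals as of this commit; the values are illustrative, not from the patch):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

// External Row: holds plain JVM values (java.lang.String, java.sql.Timestamp, ...).
val external: Row = Row("a", 29)

// InternalRow: holds Catalyst's internal formats, e.g. UTF8String for strings,
// microseconds since the epoch (Long) for timestamps, days since the epoch (Int) for dates.
val internal: InternalRow = InternalRow(UTF8String.fromString("a"), 29)
```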

## How was this patch tested?

Unit tests were run within the IDE, and code style was checked with `./dev/run-tests`.

Author: hyukjinkwon <gu...@gmail.com>

Closes #11717 from HyukjinKwon/SPARK-13899.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92024797
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92024797
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92024797

Branch: refs/heads/master
Commit: 92024797a4fad594b5314f3f3be5c6be2434de8a
Parents: 3c578c5
Author: hyukjinkwon <gu...@gmail.com>
Authored: Tue Mar 15 23:31:46 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Mar 15 23:31:46 2016 -0700

----------------------------------------------------------------------
 .../datasources/csv/CSVInferSchema.scala         | 19 +++++++++++++------
 .../execution/datasources/csv/CSVRelation.scala  | 15 +++++++++++----
 .../datasources/csv/DefaultSource.scala          | 13 +++++++------
 .../datasources/csv/CSVTypeCastSuite.scala       | 17 +++++++++++------
 4 files changed, 42 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/92024797/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index edead9b..797f740 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
 import java.text.NumberFormat
 import java.util.Locale
 
@@ -27,7 +26,9 @@ import scala.util.Try
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 private[csv] object CSVInferSchema {
 
@@ -116,7 +117,7 @@ private[csv] object CSVInferSchema {
   }
 
   def tryParseTimestamp(field: String): DataType = {
-    if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
+    if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
       TimestampType
     } else {
       tryParseBoolean(field)
@@ -191,12 +192,18 @@ private[csv] object CSVTypeCast {
         case _: DoubleType => Try(datum.toDouble)
           .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
         case _: BooleanType => datum.toBoolean
-        case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+        case dt: DecimalType =>
+          val value = new BigDecimal(datum.replaceAll(",", ""))
+          Decimal(value, dt.precision, dt.scale)
         // TODO(hossein): would be good to support other common timestamp formats
-        case _: TimestampType => Timestamp.valueOf(datum)
+        case _: TimestampType =>
+          // This loses the microseconds part.
+          // See https://issues.apache.org/jira/browse/SPARK-10681.
+          DateTimeUtils.stringToTime(datum).getTime * 1000L
         // TODO(hossein): would be good to support other common date formats
-        case _: DateType => Date.valueOf(datum)
-        case _: StringType => datum
+        case _: DateType =>
+          DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+        case _: StringType => UTF8String.fromString(datum)
         case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
       }
     }
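
The effect of the new `castTo` branches above is that every value lands in Catalyst's internal format. A minimal sketch of what each branch now yields (assuming the helpers imported in this diff; the literals and precision/scale are illustrative):

```scala
import java.math.BigDecimal

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String

// StringType: Catalyst's internal string format.
val str: UTF8String = UTF8String.fromString("hello")

// TimestampType: microseconds since the epoch, as a Long.
val micros: Long = DateTimeUtils.stringToTime("2015-01-01 00:00:00").getTime * 1000L

// DateType: days since the epoch, as an Int.
val days: Int = DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)

// DecimalType: Catalyst's Decimal, bounded by the column's precision and scale.
val dec: Decimal = Decimal(new BigDecimal("10.05"), 10, 2)
```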

http://git-wip-us.apache.org/repos/asf/spark/blob/92024797/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index c96a508..eeb56f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -29,6 +29,7 @@ import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 
@@ -54,7 +55,7 @@ object CSVRelation extends Logging {
       requiredColumns: Array[String],
       inputs: Seq[FileStatus],
       sqlContext: SQLContext,
-      params: CSVOptions): RDD[Row] = {
+      params: CSVOptions): RDD[InternalRow] = {
 
     val schemaFields = schema.fields
     val requiredFields = StructType(requiredColumns.map(schema(_))).fields
@@ -71,8 +72,8 @@ object CSVRelation extends Logging {
     }.foreach {
       case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
     }
-    val rowArray = new Array[Any](safeRequiredIndices.length)
     val requiredSize = requiredFields.length
+    val row = new GenericMutableRow(requiredSize)
     tokenizedRDD.flatMap { tokens =>
       if (params.dropMalformed && schemaFields.length != tokens.length) {
         logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
@@ -94,14 +95,20 @@ object CSVRelation extends Logging {
           while (subIndex < safeRequiredIndices.length) {
             index = safeRequiredIndices(subIndex)
             val field = schemaFields(index)
-            rowArray(subIndex) = CSVTypeCast.castTo(
+            // The cast must still be attempted: in `DROPMALFORMED` mode, whether the
+            // row is malformed is decided only after trying the cast, even when the
+            // casted value is not stored in the row.
+            val value = CSVTypeCast.castTo(
               indexSafeTokens(index),
               field.dataType,
               field.nullable,
               params.nullValue)
+            if (subIndex < requiredSize) {
+              row(subIndex) = value
+            }
             subIndex = subIndex + 1
           }
-          Some(Row.fromSeq(rowArray.take(requiredSize)))
+          Some(row)
         } catch {
           case NonFatal(e) if params.dropMalformed =>
             logWarning("Parse exception. " +

http://git-wip-us.apache.org/repos/asf/spark/blob/92024797/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index a5f9426..54e4c1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.execution.datasources.CompressionCodecs
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -113,13 +113,14 @@ class DefaultSource extends FileFormat with DataSourceRegister {
     val pathsString = csvFiles.map(_.getPath.toUri.toString)
     val header = dataSchema.fields.map(_.name)
     val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString)
-    val external = CSVRelation.parseCsv(
+    val rows = CSVRelation.parseCsv(
       tokenizedRdd, dataSchema, requiredColumns, csvFiles, sqlContext, csvOptions)
 
-    // TODO: Generate InternalRow in parseCsv
-    val outputSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
-    val encoder = RowEncoder(outputSchema)
-    external.map(encoder.toRow)
+    val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
+    rows.mapPartitions { iterator =>
+      val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
+      iterator.map(unsafeProjection)
+    }
   }
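
The `mapPartitions` above replaces the per-row `RowEncoder` conversion with an `UnsafeProjection`, turning each `InternalRow` directly into an `UnsafeRow`. The projection is created once per partition because it is code-generated and not thread-safe. A minimal sketch of the pattern (the schema is illustrative, not from the patch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType)))

// Create once per partition: UnsafeProjection is generated code and not thread-safe.
val projection = UnsafeProjection.create(schema)

val row = InternalRow(UTF8String.fromString("a"), 29)
// The result is an UnsafeRow; the projection reuses its output buffer,
// so copy() the result if it must outlive the next call.
val unsafeRow = projection(row)
```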
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/92024797/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index c28a250..5702a1b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -18,11 +18,12 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
 import java.util.Locale
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class CSVTypeCastSuite extends SparkFunSuite {
 
@@ -32,7 +33,9 @@ class CSVTypeCastSuite extends SparkFunSuite {
     val decimalType = new DecimalType()
 
     stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
-      assert(CSVTypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString))
+      val decimalValue = new BigDecimal(decimalVal.toString)
+      assert(CSVTypeCast.castTo(strVal, decimalType) ===
+        Decimal(decimalValue, decimalType.precision, decimalType.scale))
     }
   }
 
@@ -65,8 +68,8 @@ class CSVTypeCastSuite extends SparkFunSuite {
   }
 
   test("String type should always return the same as the input") {
-    assert(CSVTypeCast.castTo("", StringType, nullable = true) == "")
-    assert(CSVTypeCast.castTo("", StringType, nullable = false) == "")
+    assert(CSVTypeCast.castTo("", StringType, nullable = true) == UTF8String.fromString(""))
+    assert(CSVTypeCast.castTo("", StringType, nullable = false) == UTF8String.fromString(""))
   }
 
   test("Throws exception for empty string with non null type") {
@@ -85,8 +88,10 @@ class CSVTypeCastSuite extends SparkFunSuite {
     assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
     assert(CSVTypeCast.castTo("true", BooleanType) == true)
     val timestamp = "2015-01-01 00:00:00"
-    assert(CSVTypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
-    assert(CSVTypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
+    assert(CSVTypeCast.castTo(timestamp, TimestampType) ==
+      DateTimeUtils.stringToTime(timestamp).getTime * 1000L)
+    assert(CSVTypeCast.castTo("2015-01-01", DateType) ==
+      DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
   }
 
   test("Float and Double Types are cast correctly with Locale") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org