You are viewing a plain-text version of this content; the link to the canonical version ("here") was not preserved in this copy.
Posted to commits@spark.apache.org by do...@apache.org on 2024/03/24 07:13:34 UTC

(spark) branch master updated: [SPARK-47497][SQL] Make `to_csv` support the output of `array/struct/map/binary` as pretty strings

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 18bb6a3aea82 [SPARK-47497][SQL] Make `to_csv` support the output of `array/struct/map/binary` as pretty strings
18bb6a3aea82 is described below

commit 18bb6a3aea826c2e279457ab72ce6656646cda69
Author: panbingkun <pa...@baidu.com>
AuthorDate: Sun Mar 24 00:13:24 2024 -0700

    [SPARK-47497][SQL] Make `to_csv` support the output of `array/struct/map/binary` as pretty strings
    
    ### What changes were proposed in this pull request?
    This PR makes `to_csv`:
    - output values of type `array/struct/map/binary` as pretty strings.
    - reject input that contains `variant`, which is not supported.
    
    ### Why are the changes needed?
    This PR follows up on the review suggestion in https://github.com/apache/spark/pull/44665#issuecomment-2011239475:
    <img width="909" alt="image" src="https://github.com/apache/spark/assets/15246973/04dd1497-da42-4b03-b21d-b041ead86f87">
    
    ### Does this PR introduce _any_ user-facing change?
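    Yes. `to_csv` now renders `array/struct/map/binary` fields as pretty strings (e.g. `"[100, 200, 300]"`), and input that is not a struct, or that contains `variant`, is rejected with `DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE`.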
    
    ### How was this patch tested?
    - Updated the existing unit tests and added new ones.
    - Passed GitHub Actions (GA).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45657 from panbingkun/SPARK-47497.
    
    Authored-by: panbingkun <pa...@baidu.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/functions/builtin.py            |  12 +-
 .../sql/catalyst/csv/UnivocityGenerator.scala      | 126 ++++++++++++++--
 .../sql/catalyst/expressions/csvExpressions.scala  |  35 ++++-
 .../org/apache/spark/sql/CsvFunctionsSuite.scala   | 165 +++++++++++++++++++++
 4 files changed, 308 insertions(+), 30 deletions(-)

diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index a31465a77873..99a2375965c2 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -15591,12 +15591,12 @@ def to_csv(col: "ColumnOrName", options: Optional[Dict[str, str]] = None) -> Col
     >>> from pyspark.sql import Row, functions as sf
     >>> data = [(1, Row(age=2, name='Alice', scores=[100, 200, 300]))]
     >>> df = spark.createDataFrame(data, ("key", "value"))
-    >>> df.select(sf.to_csv(df.value)).show(truncate=False) # doctest: +SKIP
-    +-----------------------+
-    |to_csv(value)          |
-    +-----------------------+
-    |2,Alice,"[100,200,300]"|
-    +-----------------------+
+    >>> df.select(sf.to_csv(df.value)).show(truncate=False)
+    +-------------------------+
+    |to_csv(value)            |
+    +-------------------------+
+    |2,Alice,"[100, 200, 300]"|
+    +-------------------------+
 
     Example 3: Converting a StructType with null values to a CSV string
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala
index b61652f4b523..f10a53bde5dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala
@@ -22,7 +22,8 @@ import java.io.Writer
 import com.univocity.parsers.csv.CsvWriter
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, IntervalStringStyles, IntervalUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, IntervalStringStyles, IntervalUtils, SparkStringUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -36,9 +37,9 @@ class UnivocityGenerator(
   writerSettings.setHeaders(schema.fieldNames: _*)
   private val gen = new CsvWriter(writer, writerSettings)
 
-  // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
-  // When the value is null, this converter should not be called.
-  private type ValueConverter = (InternalRow, Int) => String
+  // A `ValueConverter` is responsible for converting a value of an `SpecializedGetters`
+  // to `String`. When the value is null, this converter should not be called.
+  private type ValueConverter = (SpecializedGetters, Int) => String
 
   // `ValueConverter`s for all values in the fields of the schema
   private val valueConverters: Array[ValueConverter] =
@@ -64,33 +65,126 @@ class UnivocityGenerator(
   private val nullAsQuotedEmptyString =
     SQLConf.get.getConf(SQLConf.LEGACY_NULL_VALUE_WRITTEN_AS_QUOTED_EMPTY_STRING_CSV)
 
-  @scala.annotation.tailrec
   private def makeConverter(dataType: DataType): ValueConverter = dataType match {
+    case BinaryType =>
+      (getter, ordinal) => SparkStringUtils.getHexString(getter.getBinary(ordinal))
+
     case DateType =>
-      (row: InternalRow, ordinal: Int) => dateFormatter.format(row.getInt(ordinal))
+      (getter, ordinal) => dateFormatter.format(getter.getInt(ordinal))
 
     case TimestampType =>
-      (row: InternalRow, ordinal: Int) => timestampFormatter.format(row.getLong(ordinal))
+      (getter, ordinal) => timestampFormatter.format(getter.getLong(ordinal))
 
     case TimestampNTZType =>
-      (row: InternalRow, ordinal: Int) =>
-        timestampNTZFormatter.format(DateTimeUtils.microsToLocalDateTime(row.getLong(ordinal)))
+      (getter, ordinal) =>
+        timestampNTZFormatter.format(DateTimeUtils.microsToLocalDateTime(getter.getLong(ordinal)))
 
     case YearMonthIntervalType(start, end) =>
-      (row: InternalRow, ordinal: Int) =>
+      (getter, ordinal) =>
         IntervalUtils.toYearMonthIntervalString(
-          row.getInt(ordinal), IntervalStringStyles.ANSI_STYLE, start, end)
+          getter.getInt(ordinal), IntervalStringStyles.ANSI_STYLE, start, end)
 
     case DayTimeIntervalType(start, end) =>
-      (row: InternalRow, ordinal: Int) =>
-      IntervalUtils.toDayTimeIntervalString(
-        row.getLong(ordinal), IntervalStringStyles.ANSI_STYLE, start, end)
+      (getter, ordinal) =>
+        IntervalUtils.toDayTimeIntervalString(
+          getter.getLong(ordinal), IntervalStringStyles.ANSI_STYLE, start, end)
 
     case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
 
+    case ArrayType(et, _) =>
+      val elementConverter = makeConverter(et)
+      (getter, ordinal) =>
+        val array = getter.getArray(ordinal)
+        val builder = new StringBuilder
+        builder.append("[")
+        if (array.numElements() > 0) {
+          if (array.isNullAt(0)) {
+            appendNull(builder)
+          } else {
+            builder.append(elementConverter(array, 0))
+          }
+          var i = 1
+          while (i < array.numElements()) {
+            builder.append(",")
+            if (array.isNullAt(i)) {
+              appendNull(builder)
+            } else {
+              builder.append(" " + elementConverter(array, i))
+            }
+            i += 1
+          }
+        }
+        builder.append("]")
+        builder.toString()
+
+    case MapType(kt, vt, _) =>
+      val keyConverter = makeConverter(kt)
+      val valueConverter = makeConverter(vt)
+      (getter, ordinal) =>
+        val map = getter.getMap(ordinal)
+        val builder = new StringBuilder
+        builder.append("{")
+        if (map.numElements() > 0) {
+          val keyArray = map.keyArray()
+          val valueArray = map.valueArray()
+          builder.append(keyConverter(keyArray, 0))
+          builder.append(" ->")
+          if (valueArray.isNullAt(0)) {
+            appendNull(builder)
+          } else {
+            builder.append(" " + valueConverter(valueArray, 0))
+          }
+          var i = 1
+          while (i < map.numElements()) {
+            builder.append(", ")
+            builder.append(keyConverter(keyArray, i))
+            builder.append(" ->")
+            if (valueArray.isNullAt(i)) {
+              appendNull(builder)
+            } else {
+              builder.append(" " + valueConverter(valueArray, i))
+            }
+            i += 1
+          }
+        }
+        builder.append("}")
+        builder.toString()
+
+    case StructType(fields) =>
+      val converters = fields.map(_.dataType).map(makeConverter)
+      (getter, ordinal) =>
+        val row = getter.getStruct(ordinal, fields.length)
+        val builder = new StringBuilder
+        builder.append("{")
+        if (row.numFields > 0) {
+          if (row.isNullAt(0)) {
+            appendNull(builder)
+          } else {
+            builder.append(converters(0)(row, 0))
+          }
+          var i = 1
+          while (i < row.numFields) {
+            builder.append(",")
+            if (row.isNullAt(i)) {
+              appendNull(builder)
+            } else {
+              builder.append(" " + converters(i)(row, i))
+            }
+            i += 1
+          }
+        }
+        builder.append("}")
+        builder.toString()
+
     case dt: DataType =>
-      (row: InternalRow, ordinal: Int) =>
-        row.get(ordinal, dt).toString
+      (getter, ordinal) => getter.get(ordinal, dt).toString
+  }
+
+  private def appendNull(builder: StringBuilder): Unit = {
+    options.parameters.get(CSVOptions.NULL_VALUE) match {
+      case Some(_) => builder.append(" " + options.nullValue)
+      case None =>
+    }
   }
 
   private def convertRow(row: InternalRow): Seq[String] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index a2d17617a10f..4714fc1ded9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -21,13 +21,14 @@ import java.io.CharArrayWriter
 
 import com.univocity.parsers.csv.CsvParser
 
-import org.apache.spark.{SparkException, SparkIllegalArgumentException}
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.csv._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.util.TypeUtils._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -260,16 +261,34 @@ case class StructsToCsv(
       child = child,
       timeZoneId = None)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    child.dataType match {
+      case schema: StructType if schema.map(_.dataType).forall(
+        dt => isSupportedDataType(dt)) => TypeCheckSuccess
+      case _ => DataTypeMismatch(
+        errorSubClass = "UNSUPPORTED_INPUT_TYPE",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "dataType" -> toSQLType(child.dataType)
+        )
+      )
+    }
+  }
+
+  private def isSupportedDataType(dataType: DataType): Boolean = dataType match {
+    case _: VariantType => false
+    case array: ArrayType => isSupportedDataType(array.elementType)
+    case map: MapType => isSupportedDataType(map.keyType) && isSupportedDataType(map.valueType)
+    case st: StructType => st.map(_.dataType).forall(dt => isSupportedDataType(dt))
+    case udt: UserDefinedType[_] => isSupportedDataType(udt.sqlType)
+    case _ => true
+  }
+
   @transient
   lazy val writer = new CharArrayWriter()
 
   @transient
-  lazy val inputSchema: StructType = child.dataType match {
-    case st: StructType => st
-    case other => throw new SparkIllegalArgumentException(
-      errorClass = "_LEGACY_ERROR_TEMP_3234",
-      messageParameters = Map("other" -> other.catalogString))
-  }
+  lazy val inputSchema: StructType = child.dataType.asInstanceOf[StructType]
 
   @transient
   lazy val gen = new UnivocityGenerator(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
index 5c6891f6a7f5..196a1fd38837 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import java.nio.charset.StandardCharsets
 import java.text.SimpleDateFormat
 import java.time.{Duration, LocalDateTime, Period}
 import java.util.Locale
@@ -31,6 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
 import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
+import org.apache.spark.unsafe.types._
 
 class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -603,4 +605,167 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
       $"csv", schema_of_csv("1,2\n2"), Map.empty[String, String].asJava))
     checkAnswer(actual, Row(Row(1, "2\n2")))
   }
+
+  test("SPARK-47497: null value display when w or w/o options (nullValue)") {
+    val rows = new java.util.ArrayList[Row]()
+    rows.add(Row(1L, Row(2L, "Alice", null, "y")))
+
+    val valueSchema = StructType(Seq(
+      StructField("age", LongType),
+      StructField("name", StringType),
+      StructField("x", StringType),
+      StructField("y", StringType)))
+    val schema = StructType(Seq(
+      StructField("key", LongType),
+      StructField("value", valueSchema)))
+
+    val df = spark.createDataFrame(rows, schema)
+    val actual1 = df.select(to_csv($"value"))
+    checkAnswer(actual1, Row("2,Alice,,y"))
+
+    val options = Map("nullValue" -> "-")
+    val actual2 = df.select(to_csv($"value", options.asJava))
+    checkAnswer(actual2, Row("2,Alice,-,y"))
+  }
+
+  test("SPARK-47497: to_csv support the data of ArrayType as pretty strings") {
+    val rows = new java.util.ArrayList[Row]()
+    rows.add(Row(1L, Row(2L, "Alice", Array(100L, 200L, null, 300L))))
+
+    val valueSchema = StructType(Seq(
+      StructField("age", LongType),
+      StructField("name", StringType),
+      StructField("scores", ArrayType(LongType))))
+    val schema = StructType(Seq(
+      StructField("key", LongType),
+      StructField("value", valueSchema)))
+
+    val df = spark.createDataFrame(rows, schema)
+    val actual1 = df.select(to_csv($"value"))
+    checkAnswer(actual1, Row("2,Alice,\"[100, 200,, 300]\""))
+
+    val options = Map("nullValue" -> "-")
+    val actual2 = df.select(to_csv($"value", options.asJava))
+    checkAnswer(actual2, Row("2,Alice,\"[100, 200, -, 300]\""))
+  }
+
+  test("SPARK-47497: to_csv support the data of MapType as pretty strings") {
+    val rows = new java.util.ArrayList[Row]()
+    rows.add(Row(1L, Row(2L, "Alice",
+      Map("math" -> 100L, "english" -> 200L, "science" -> null))))
+
+    val valueSchema = StructType(Seq(
+      StructField("age", LongType),
+      StructField("name", StringType),
+      StructField("scores", MapType(StringType, LongType))))
+    val schema = StructType(Seq(
+      StructField("key", LongType),
+      StructField("value", valueSchema)))
+
+    val df = spark.createDataFrame(rows, schema)
+    val actual1 = df.select(to_csv($"value"))
+    checkAnswer(actual1, Row("2,Alice,\"{math -> 100, english -> 200, science ->}\""))
+
+    val options = Map("nullValue" -> "-")
+    val actual2 = df.select(to_csv($"value", options.asJava))
+    checkAnswer(actual2, Row("2,Alice,\"{math -> 100, english -> 200, science -> -}\""))
+  }
+
+  test("SPARK-47497: to_csv support the data of StructType as pretty strings") {
+    val rows = new java.util.ArrayList[Row]()
+    rows.add(Row(1L, Row(2L, "Alice", Row(100L, 200L, null))))
+
+    val valueSchema = StructType(Seq(
+      StructField("age", LongType),
+      StructField("name", StringType),
+      StructField("scores", StructType(Seq(
+        StructField("id1", LongType),
+        StructField("id2", LongType),
+        StructField("id3", LongType))))))
+    val schema = StructType(Seq(
+      StructField("key", LongType),
+      StructField("value", valueSchema)))
+
+    val df = spark.createDataFrame(rows, schema)
+    val actual1 = df.select(to_csv($"value"))
+    checkAnswer(actual1, Row("2,Alice,\"{100, 200,}\""))
+
+    val options = Map("nullValue" -> "-")
+    val actual2 = df.select(to_csv($"value", options.asJava))
+    checkAnswer(actual2, Row("2,Alice,\"{100, 200, -}\""))
+  }
+
+  test("SPARK-47497: to_csv support the data of BinaryType as pretty strings") {
+    val rows = new java.util.ArrayList[Row]()
+    rows.add(Row(1L, Row(2L, "Alice", "a".getBytes(StandardCharsets.UTF_8))))
+
+    val valueSchema = StructType(Seq(
+      StructField("age", LongType),
+      StructField("name", StringType),
+      StructField("a", BinaryType)))
+    val schema = StructType(Seq(
+      StructField("key", LongType),
+      StructField("value", valueSchema)))
+
+    val df = spark.createDataFrame(rows, schema)
+    val actual = df.select(to_csv($"value"))
+    checkAnswer(actual, Row("2,Alice,[61]"))
+  }
+
+  test("SPARK-47497: to_csv can display NullType data") {
+    val df = Seq(Tuple1(Tuple1(null))).toDF("value")
+    val options = Map("nullValue" -> "-")
+    val actual = df.select(to_csv($"value", options.asJava))
+    checkAnswer(actual, Row("-"))
+  }
+
+  test("SPARK-47497: from_csv/to_csv does not support VariantType data") {
+    val rows = new java.util.ArrayList[Row]()
+    rows.add(Row(1L, Row(2L, "Alice", new VariantVal(Array[Byte](1, 2, 3), Array[Byte](4, 5)))))
+
+    val valueSchema = StructType(Seq(
+      StructField("age", LongType),
+      StructField("name", StringType),
+      StructField("v", VariantType)))
+    val schema = StructType(Seq(
+      StructField("key", LongType),
+      StructField("value", valueSchema)))
+
+    val df = spark.createDataFrame(rows, schema)
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(to_csv($"value")).collect()
+      },
+      errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE",
+      parameters = Map(
+        "functionName" -> "`to_csv`",
+        "dataType" -> "\"STRUCT<age: BIGINT, name: STRING, v: VARIANT>\"",
+        "sqlExpr" -> "\"to_csv(value)\""),
+      context = ExpectedContext(fragment = "to_csv", getCurrentClassCallSitePattern)
+    )
+
+    checkError(
+      exception = intercept[SparkUnsupportedOperationException] {
+        df.select(from_csv(lit("data"), valueSchema, Map.empty[String, String])).collect()
+      },
+      errorClass = "UNSUPPORTED_DATATYPE",
+      parameters = Map("typeName" -> "\"VARIANT\"")
+    )
+  }
+
+  test("SPARK-47497: the input of to_csv must be StructType") {
+    val df = Seq(1, 2).toDF("value")
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(to_csv($"value")).collect()
+      },
+      errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE",
+      parameters = Map(
+        "functionName" -> "`to_csv`",
+        "dataType" -> "\"INT\"",
+        "sqlExpr" -> "\"to_csv(value)\""),
+      context = ExpectedContext(fragment = "to_csv", getCurrentClassCallSitePattern)
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org