You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/01/25 04:56:07 UTC

[spark] branch branch-3.1 updated: [SPARK-34133][AVRO] Respect case sensitivity when performing Catalyst-to-Avro field matching

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 7feb0ea  [SPARK-34133][AVRO] Respect case sensitivity when performing Catalyst-to-Avro field matching
7feb0ea is described below

commit 7feb0ea48b7838559017fb617c5b97aa7750e54c
Author: Erik Krogen <xk...@apache.org>
AuthorDate: Mon Jan 25 04:54:41 2021 +0000

    [SPARK-34133][AVRO] Respect case sensitivity when performing Catalyst-to-Avro field matching
    
    ### What changes were proposed in this pull request?
    Make the field name matching between Avro and Catalyst schemas, on both the reader and writer paths, respect the global SQL settings for case sensitivity (i.e. case-insensitive by default). `AvroSerializer` and `AvroDeserializer` share a common utility in `AvroUtils` to search for an Avro field to match a given Catalyst field.
    
    ### Why are the changes needed?
    Spark SQL is case-insensitive by default, but currently when `AvroSerializer` and `AvroDeserializer` match fields between Catalyst schemas and Avro schemas, the matching is always case-sensitive. For example, the following write fails:
    ```scala
          val avroSchema =
            """
              |{
              |  "type" : "record",
              |  "name" : "test_schema",
              |  "fields" : [
              |    {"name": "foo", "type": "int"},
              |    {"name": "BAR", "type": "int"}
              |  ]
              |}
          """.stripMargin
          val df = Seq((1, 3), (2, 4)).toDF("FOO", "bar")
    
          df.write.option("avroSchema", avroSchema).format("avro").save(savePath)
    ```
    
    The same problem occurs on the read path: assuming `testAvro` was written with the schema above, the following read fails to match the fields:
    ```scala
    spark.read.schema(new StructType().add("FOO", IntegerType).add("bar", IntegerType))
      .format("avro").load(testAvro)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    When reading Avro data, or when writing Avro data with the `avroSchema` option, field-name matching now respects the global SQL case-sensitivity setting (`SQLConf.CASE_SENSITIVE`).
    
    ### How was this patch tested?
    New tests added to `AvroSuite` to validate the case sensitivity logic in an end-to-end manner through the SQL engine.
    
    Closes #31201 from xkrogen/xkrogen-SPARK-34133-avro-serde-casesensitivity-errormessages.
    
    Authored-by: Erik Krogen <xk...@apache.org>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 9371ea8c7bd87b87c4d3dfb4c830c65643e48f54)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../apache/spark/sql/avro/AvroDeserializer.scala   | 40 +++++-----
 .../org/apache/spark/sql/avro/AvroSerializer.scala |  8 +-
 .../org/apache/spark/sql/avro/AvroUtils.scala      | 31 ++++++++
 .../org/apache/spark/sql/avro/AvroSuite.scala      | 89 ++++++++++++++++++++++
 4 files changed, 145 insertions(+), 23 deletions(-)

diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 85416b8..c2c6f38 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -330,27 +330,29 @@ private[sql] class AvroDeserializer(
     var i = 0
     while (i < length) {
       val sqlField = sqlType.fields(i)
-      val avroField = avroType.getField(sqlField.name)
-      if (avroField != null) {
-        validFieldIndexes += avroField.pos()
-
-        val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name)
-        val ordinal = i
-        val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
-          if (value == null) {
-            fieldUpdater.setNullAt(ordinal)
-          } else {
-            baseWriter(fieldUpdater, ordinal, value)
+      AvroUtils.getAvroFieldByName(avroType, sqlField.name) match {
+        case Some(avroField) =>
+          validFieldIndexes += avroField.pos()
+
+          val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name)
+          val ordinal = i
+          val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
+            if (value == null) {
+              fieldUpdater.setNullAt(ordinal)
+            } else {
+              baseWriter(fieldUpdater, ordinal, value)
+            }
           }
-        }
-        fieldWriters += fieldWriter
-      } else if (!sqlField.nullable) {
-        throw new IncompatibleSchemaException(
-          s"""
-             |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema.
-             |Source Avro schema: $rootAvroType.
-             |Target Catalyst type: $rootCatalystType.
+          fieldWriters += fieldWriter
+        case None if !sqlField.nullable =>
+          val fieldStr = s"${path.mkString(".")}.${sqlField.name}"
+          throw new IncompatibleSchemaException(
+            s"""
+               |Cannot find non-nullable field $fieldStr in Avro schema.
+               |Source Avro schema: $rootAvroType.
+               |Target Catalyst type: $rootCatalystType.
            """.stripMargin)
+        case _ => // nothing to do
       }
       i += 1
     }
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 33c6022..d716b10 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -230,10 +230,10 @@ private[sql] class AvroSerializer(
 
     val (avroIndices: Array[Int], fieldConverters: Array[Converter]) =
       catalystStruct.map { catalystField =>
-        val avroField = avroStruct.getField(catalystField.name)
-        if (avroField == null) {
-          throw new IncompatibleSchemaException(
-            s"Cannot convert Catalyst type $catalystStruct to Avro type $avroStruct.")
+        val avroField = AvroUtils.getAvroFieldByName(avroStruct, catalystField.name) match {
+          case Some(f) => f
+          case None => throw new IncompatibleSchemaException(
+            s"Cannot find ${catalystField.name} in Avro schema")
         }
         val converter = newConverter(catalystField.dataType, resolveNullableType(
           avroField.schema(), catalystField.nullable))
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index 51997ac..c48b097 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -18,6 +18,8 @@ package org.apache.spark.sql.avro
 
 import java.io.{FileNotFoundException, IOException}
 
+import scala.collection.JavaConverters._
+
 import org.apache.avro.Schema
 import org.apache.avro.file.{DataFileReader, FileReader}
 import org.apache.avro.file.DataFileConstants.{BZIP2_CODEC, DEFLATE_CODEC, SNAPPY_CODEC, XZ_CODEC}
@@ -201,4 +203,33 @@ private[sql] object AvroUtils extends Logging {
       }
     }
   }
+
+  /**
+   * Extract a single field from `avroSchema` which has the desired field name,
+   * performing the matching with proper case sensitivity according to [[SQLConf.resolver]].
+   *
+   * @param avroSchema The schema in which to search for the field. Must be of type RECORD.
+   * @param name The name of the field to search for.
+   * @return `Some(match)` if a matching Avro field is found, otherwise `None`.
+   * @throws IncompatibleSchemaException if `avroSchema` is not a RECORD or contains multiple
+   *                                     fields matching `name` (i.e., case-insensitive matching
+   *                                     is used and `avroSchema` has two or more fields that have
+   *                                     the same name with different case).
+   */
+  private[avro] def getAvroFieldByName(
+      avroSchema: Schema,
+      name: String): Option[Schema.Field] = {
+    if (avroSchema.getType != Schema.Type.RECORD) {
+      throw new IncompatibleSchemaException(
+        s"Attempting to treat ${avroSchema.getName} as a RECORD, but it was: ${avroSchema.getType}")
+    }
+    avroSchema.getFields.asScala.filter(f => SQLConf.get.resolver(f.name(), name)).toSeq match {
+      case Seq(avroField) => Some(avroField)
+      case Seq() => None // no field matched `name`
+      case matches => throw new IncompatibleSchemaException(
+        s"Searching for '$name' in Avro schema gave ${matches.size} matches. Candidates: " +
+            matches.map(_.name()).mkString("[", ", ", "]")
+      )
+    }
+  }
 }
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index d3bfb71..17bdeda 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -37,6 +37,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
+import org.apache.spark.TestUtils.assertExceptionMsg
 import org.apache.spark.sql._
 import org.apache.spark.sql.TestingUDT.IntervalData
 import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters}
@@ -1261,6 +1262,94 @@ abstract class AvroSuite
     }
   }
 
+  test("SPARK-34133: Reading user provided schema respects case sensitivity for field matching") {
+    val wrongCaseSchema = new StructType()  // case deliberately differs from the file
+        .add("STRING", StringType, nullable = false)
+        .add("UNION_STRING_NULL", StringType, nullable = true)
+    val withSchema = spark.read
+        .schema(wrongCaseSchema)
+        .format("avro").load(testAvro).collect()
+
+    val withOutSchema = spark.read.format("avro").load(testAvro)  // baseline: no user schema
+        .select("STRING", "UNION_STRING_NULL")
+        .collect()
+    assert(withSchema.sameElements(withOutSchema))
+
+    withSQLConf((SQLConf.CASE_SENSITIVE.key, "true")) {  // upper-case names no longer match
+      val  out = spark.read.format("avro").schema(wrongCaseSchema).load(testAvro).collect()
+      assert(out.forall(_.isNullAt(0)))  // unmatched columns are read as null
+      assert(out.forall(_.isNullAt(1)))
+    }
+  }
+
+  test("SPARK-34133: Writing user provided schema respects case sensitivity for field matching") {
+    withTempDir { tempDir =>
+      val avroSchema =
+        """
+          |{
+          |  "type" : "record",
+          |  "name" : "test_schema",
+          |  "fields" : [
+          |    {"name": "foo", "type": "int"},
+          |    {"name": "BAR", "type": "int"}
+          |  ]
+          |}
+      """.stripMargin
+      val df = Seq((1, 3), (2, 4)).toDF("FOO", "bar")  // opposite case to the Avro schema
+
+      val savePath = s"$tempDir/save"
+      df.write.option("avroSchema", avroSchema).format("avro").save(savePath)
+
+      val loaded = spark.read.format("avro").load(savePath)  // Avro names are kept
+      assert(loaded.schema === new StructType().add("foo", IntegerType).add("BAR", IntegerType))
+      assert(loaded.collect().map(_.getInt(0)).toSet === Set(1, 2))
+      assert(loaded.collect().map(_.getInt(1)).toSet === Set(3, 4))
+
+      withSQLConf((SQLConf.CASE_SENSITIVE.key, "true")) {  // FOO no longer matches foo
+        val e = intercept[SparkException] {
+          df.write.option("avroSchema", avroSchema).format("avro").save(s"$tempDir/save2")
+        }
+        assertExceptionMsg(e, "Cannot find FOO in Avro schema")
+      }
+    }
+  }
+
+  test("SPARK-34133: Writing user provided schema with multiple matching Avro fields fails") {
+    withTempDir { tempDir =>
+      val avroSchema =
+        """
+          |{
+          |  "type" : "record",
+          |  "name" : "test_schema",
+          |  "fields" : [
+          |    {"name": "foo", "type": "int"},
+          |    {"name": "FOO", "type": "string"}
+          |  ]
+          |}
+      """.stripMargin
+
+      val errorMsg = "Searching for 'foo' in Avro schema gave 2 matches. Candidates: [foo, FOO]"
+      assertExceptionMsg(intercept[SparkException] {  // 'foo' matches both foo and FOO
+        val fooBarDf = Seq((1, "3"), (2, "4")).toDF("foo", "bar")
+        fooBarDf.write.option("avroSchema", avroSchema).format("avro").save(s"$tempDir/save-fail")
+      }, errorMsg)
+
+      val savePath = s"$tempDir/save"
+      withSQLConf((SQLConf.CASE_SENSITIVE.key, "true")) {  // each column matches one field
+        val fooFooDf = Seq((1, "3"), (2, "4")).toDF("foo", "FOO")
+        fooFooDf.write.option("avroSchema", avroSchema).format("avro").save(savePath)
+
+        val loadedDf = spark.read.format("avro").schema(fooFooDf.schema).load(savePath)
+        assert(loadedDf.collect().toSet === fooFooDf.collect().toSet)
+      }
+
+      assertExceptionMsg(intercept[SparkException] {  // default resolver: ambiguous again
+        val fooSchema = new StructType().add("foo", IntegerType)
+        spark.read.format("avro").schema(fooSchema).load(savePath).collect()
+      }, errorMsg)
+    }
+  }
+
   test("read avro with user defined schema: read partial columns") {
     val partialColumns = StructType(Seq(
       StructField("string", StringType, false),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org