Posted to commits@spark.apache.org by ge...@apache.org on 2023/02/14 23:12:38 UTC

[spark] branch master updated: [SPARK-42406][PROTOBUF] Fix recursive depth setting for Protobuf functions

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 96516c8e312 [SPARK-42406][PROTOBUF] Fix recursive depth setting for Protobuf functions
96516c8e312 is described below

commit 96516c8e3129929518ae4c9877983868e75dc4a4
Author: Raghu Angadi <ra...@databricks.com>
AuthorDate: Tue Feb 14 15:12:19 2023 -0800

    [SPARK-42406][PROTOBUF] Fix recursive depth setting for Protobuf functions
    
    ### What changes were proposed in this pull request?
    
    This fixes how the setting that limits recursive depth in Protobuf functions (`recursive.fields.max.depth`) is handled.
    The original documentation says that setting it to '0' removes the recursive field, but the implementation never did that; it always allowed at least one level of recursion. E.g. the schema for the recursive message 'EventPerson' does not change between the settings '0' and '1'.
    
    This PR fixes that by requiring the max depth to be at least 1. It also fixes how the recursion limit is enforced.
    The tests are updated, and an extra test is added with the new protobuf message 'EventPersonWrapper'.
    
    I will annotate the diff inline, pointing to the main fixes.
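
    As a usage illustration (a hedged sketch, not part of this patch: the DataFrame `df`, the
    column name "value", and `descFilePath` are placeholders, and `spark.implicits._` is assumed
    to be in scope), the option is passed to `from_protobuf` as shown below. After this change a
    value of '0' or '-1' fails with the recursion error instead of silently allowing one level:

    ```scala
    import scala.collection.JavaConverters._
    import org.apache.spark.sql.protobuf.functions

    // Allow the recursive message type to appear at most twice on any path,
    // i.e. one level of recursion.
    val options = Map("recursive.fields.max.depth" -> "2").asJava
    val parsed = df.select(
      functions.from_protobuf($"value", "EventPerson", descFilePath, options).as("person"))

    // With "0" or "-1", analysis now fails with:
    // "Found recursive reference in Protobuf schema, which can not be processed by Spark"
    ```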
    
    ### Why are the changes needed?
    
    This fixes a bug in how `recursive.fields.max.depth` is enforced and clarifies the documentation.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    - Unit tests
    
    Closes #40011 from rangadi/recursive-depth.
    
    Authored-by: Raghu Angadi <ra...@databricks.com>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../spark/sql/protobuf/utils/ProtobufOptions.scala |  10 +-
 .../sql/protobuf/utils/SchemaConverters.scala      |  10 +-
 .../test/resources/protobuf/functions_suite.desc   | Bin 8739 -> 8836 bytes
 .../test/resources/protobuf/functions_suite.proto  |   6 +-
 .../sql/protobuf/ProtobufFunctionsSuite.scala      | 166 ++++++++++++---------
 5 files changed, 107 insertions(+), 85 deletions(-)

diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
index 52f9f74bd43..53036668ebf 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -39,12 +39,12 @@ private[sql] class ProtobufOptions(
   val parseMode: ParseMode =
     parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
 
-  // Setting the `recursive.fields.max.depth` to 0 drops all recursive fields,
-  // 1 allows it to be recurse once, and 2 allows it to be recursed twice and so on.
-  // A value of `recursive.fields.max.depth` greater than 10 is not permitted. If it is not
-  // specified, the default value is -1; recursive fields are not permitted. If a protobuf
+  // Setting `recursive.fields.max.depth` to 1 allows the field to recurse once,
+  // 2 allows it to recurse twice, and so on. A value of `recursive.fields.max.depth`
+  // greater than 10 is not permitted. If it is not specified, the default value is -1;
+  // a value of 0 or below disallows any recursive fields. If a protobuf
   // record has more depth than the allowed value for recursive fields, it will be truncated
-  // and some fields may be discarded.
+  // and corresponding fields are ignored (dropped).
   val recursiveFieldMaxDepth: Int = parameters.getOrElse("recursive.fields.max.depth", "-1").toInt
 }
 
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index bb4aa492f5c..9666e34bab4 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -118,17 +118,17 @@ object SchemaConverters {
         // discarded.
         // SQL Schema for the protobuf message `message Person { string name = 1; Person bff = 2}`
         // will vary based on the value of "recursive.fields.max.depth".
-        // 0: struct<name: string, bff: null>
-        // 1: struct<name string, bff: <name: string, bff: null>>
-        // 2: struct<name string, bff: <name: string, bff: struct<name: string, bff: null>>> ...
+        // 1: struct<name: string, bff: null>
+        // 2: struct<name: string, bff: struct<name: string, bff: null>>
+        // 3: struct<name: string, bff: struct<name: string, bff: struct<name: string, bff: null>>> ...
         val recordName = fd.getMessageType.getFullName
         val recursiveDepth = existingRecordNames.getOrElse(recordName, 0)
         val recursiveFieldMaxDepth = protobufOptions.recursiveFieldMaxDepth
-        if (existingRecordNames.contains(recordName) && (recursiveFieldMaxDepth < 0 ||
+        if (existingRecordNames.contains(recordName) && (recursiveFieldMaxDepth <= 0 ||
           recursiveFieldMaxDepth > 10)) {
           throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
         } else if (existingRecordNames.contains(recordName) &&
-          recursiveDepth > recursiveFieldMaxDepth) {
+          recursiveDepth >= recursiveFieldMaxDepth) {
           Some(NullType)
         } else {
           val newRecordNames = existingRecordNames + (recordName -> (recursiveDepth + 1))
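
For readers skimming the hunk above, here is a minimal self-contained sketch (a hypothetical helper, not Spark code) of the rule the changed conditions implement: a message type may appear at most `recursive.fields.max.depth` times on the path from the root, and the occurrence that would exceed that limit is emitted as a null field.

```scala
// Sketch only: `path` maps a fully-qualified message name to how many times it
// already appears on the current path from the root of the schema.
sealed trait FieldDecision
case object RejectRecursion extends FieldDecision  // recursion found but not permitted: analysis error
case object TruncateToNull extends FieldDecision   // depth exhausted: emit NullType, dropping deeper fields
case class Descend(newPath: Map[String, Int]) extends FieldDecision  // keep converting the nested message

def decide(path: Map[String, Int], messageName: String, maxDepth: Int): FieldDecision = {
  val depth = path.getOrElse(messageName, 0)
  if (path.contains(messageName) && (maxDepth <= 0 || maxDepth > 10)) RejectRecursion
  else if (path.contains(messageName) && depth >= maxDepth) TruncateToNull
  else Descend(path + (messageName -> (depth + 1)))
}

// With maxDepth = 2 and `message Person { string name = 1; Person bff = 2 }`,
// 'Person' is converted twice and the third occurrence is truncated:
// struct<name: string, bff: struct<name: string, bff: null>>
```
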
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc
index 135d489f520..d16f8935080 100644
Binary files a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc and b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
index 449f1b68bb8..a0698ee3979 100644
--- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -224,11 +224,15 @@ message EM2 {
   Employee em2Manager = 2;
 }
 
-message EventPerson {
+message EventPerson { // Used for simple recursive field testing.
   string name = 1;
   EventPerson bff = 2;
 }
 
+message EventPersonWrapper {
+  EventPerson person = 1;
+}
+
 message OneOfEventWithRecursion {
   string key = 1;
   oneof payload {
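
The new 'EventPersonWrapper' message checks that the depth limit is counted per recursive message type rather than from the top of the schema: the wrapper itself never recurses, so with `recursive.fields.max.depth` set to 2 the nested 'EventPerson' still gets two levels before 'bff' is truncated. A hedged sketch of the expected shape (the DataFrame `wrapperDf`, its column "value", and `descFilePath` are placeholders, with `spark.implicits._` assumed in scope):

```scala
import scala.collection.JavaConverters._
import org.apache.spark.sql.protobuf.functions

val opts = Map("recursive.fields.max.depth" -> "2").asJava
wrapperDf
  .select(functions.from_protobuf($"value", "EventPersonWrapper", descFilePath, opts).as("sample"))
  .printSchema()
// Expected shape (mirroring the new test added below):
// sample: struct<person: struct<name: string, bff: struct<name: string, bff: void>>>
```
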
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 30b38eafd78..60e13644fc6 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -25,7 +25,7 @@ import com.google.protobuf.{ByteString, DynamicMessage}
 
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.functions.{lit, struct}
-import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{messageA, messageB, messageC, EM, EM2, Employee, EventPerson, EventRecursiveA, EventRecursiveB, EventWithRecursion, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
+import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EM, EM2, Employee, EventPerson, EventPersonWrapper, EventRecursiveA, EventRecursiveB, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
 import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
@@ -39,6 +39,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
   val testFileDesc = testFile("functions_suite.desc", "protobuf/functions_suite.desc")
   private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"
 
+  private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
+
   /**
    * Runs the given closure twice. Once with descriptor file and second time with Java class name.
    */
@@ -385,34 +387,12 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
-  test("Handle recursive fields in Protobuf schema, A->B->A") {
-    val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA")
-    val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB")
-
-    val messageBForA = DynamicMessage
-      .newBuilder(schemaB)
-      .setField(schemaB.findFieldByName("keyB"), "key")
-      .build()
-
-    val messageA = DynamicMessage
-      .newBuilder(schemaA)
-      .setField(schemaA.findFieldByName("keyA"), "key")
-      .setField(schemaA.findFieldByName("messageB"), messageBForA)
-      .build()
-
-    val messageB = DynamicMessage
-      .newBuilder(schemaB)
-      .setField(schemaB.findFieldByName("keyB"), "key")
-      .setField(schemaB.findFieldByName("messageA"), messageA)
-      .build()
-
-    val df = Seq(messageB.toByteArray).toDF("messageB")
-
+  test("Recursive fields in Protobuf should result in an error (B -> A -> B)") {
     checkWithFileAndClassName("recursiveB") {
       case (name, descFilePathOpt) =>
         val e = intercept[AnalysisException] {
-          df.select(
-            from_protobuf_wrapper($"messageB", name, descFilePathOpt).as("messageFromProto"))
+          emptyBinaryDF.select(
+            from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
             .show()
         }
         assert(e.getMessage.contains(
@@ -421,34 +401,12 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
-  test("Handle recursive fields in Protobuf schema, C->D->Array(C)") {
-    val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC")
-    val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD")
-
-    val messageDForC = DynamicMessage
-      .newBuilder(schemaD)
-      .setField(schemaD.findFieldByName("keyD"), "key")
-      .build()
-
-    val messageC = DynamicMessage
-      .newBuilder(schemaC)
-      .setField(schemaC.findFieldByName("keyC"), "key")
-      .setField(schemaC.findFieldByName("messageD"), messageDForC)
-      .build()
-
-    val messageD = DynamicMessage
-      .newBuilder(schemaD)
-      .setField(schemaD.findFieldByName("keyD"), "key")
-      .addRepeatedField(schemaD.findFieldByName("messageC"), messageC)
-      .build()
-
-    val df = Seq(messageD.toByteArray).toDF("messageD")
-
+  test("Recursive fields in Protobuf should result in an error, C->D->Array(C)") {
     checkWithFileAndClassName("recursiveD") {
       case (name, descFilePathOpt) =>
         val e = intercept[AnalysisException] {
-          df.select(
-            from_protobuf_wrapper($"messageD", name, descFilePathOpt).as("messageFromProto"))
+          emptyBinaryDF.select(
+            from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
             .show()
         }
         assert(e.getMessage.contains(
@@ -457,6 +415,22 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
+  test("Setting depth to 0 or -1 should trigger error on recursive fields (B -> A -> B)") {
+    for (depth <- Seq("0", "-1")) {
+      val e = intercept[AnalysisException] {
+        emptyBinaryDF.select(
+          functions.from_protobuf(
+            $"binary", "recursiveB", testFileDesc,
+            Map("recursive.fields.max.depth" -> depth).asJava
+          ).as("messageFromProto")
+        ).show()
+      }
+      assert(e.getMessage.contains(
+        "Found recursive reference in Protobuf schema, which can not be processed by Spark"
+      ))
+    }
+  }
+
   test("Handle extra fields : oldProducer -> newConsumer") {
     val testFileDesc = testFile("catalyst_types.desc", "protobuf/catalyst_types.desc")
     val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer")
@@ -818,21 +792,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
   }
 
   test("Fail for recursion field with complex schema without recursive.fields.max.depth") {
-    val aEventWithRecursion = EventWithRecursion.newBuilder().setKey(2).build()
-    val aaEventWithRecursion = EventWithRecursion.newBuilder().setKey(3).build()
-    val aaaEventWithRecursion = EventWithRecursion.newBuilder().setKey(4).build()
-    val c = messageC.newBuilder().setAaa(aaaEventWithRecursion).setKey(12092)
-    val b = messageB.newBuilder().setAa(aaEventWithRecursion).setC(c)
-    val a = messageA.newBuilder().setA(aEventWithRecursion).setB(b).build()
-    val eventWithRecursion = EventWithRecursion.newBuilder().setKey(1).setA(a).build()
-
-    val df = Seq(eventWithRecursion.toByteArray).toDF("protoEvent")
-
     checkWithFileAndClassName("EventWithRecursion") {
       case (name, descFilePathOpt) =>
         val e = intercept[AnalysisException] {
-          df.select(
-            from_protobuf_wrapper($"protoEvent", name, descFilePathOpt).as("messageFromProto"))
+          emptyBinaryDF.select(
+            from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
             .show()
         }
         assert(e.getMessage.contains(
@@ -853,7 +817,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
 
     val df = Seq(employee.toByteArray).toDF("protoEvent")
     val options = new java.util.HashMap[String, String]()
-    options.put("recursive.fields.max.depth", "1")
+    options.put("recursive.fields.max.depth", "2")
 
     val fromProtoDf = df.select(
       functions.from_protobuf($"protoEvent", "Employee", testFileDesc, options) as 'sample)
@@ -908,7 +872,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value")
 
     val options = new java.util.HashMap[String, String]()
-    options.put("recursive.fields.max.depth", "1")
+    options.put("recursive.fields.max.depth", "2")
 
     val fromProtoDf = df.select(
       functions.from_protobuf($"value",
@@ -1128,7 +1092,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     })
   }
 
-  test("Verify recursive.fields.max.depth Levels 0,1, and 2 with Simple Schema") {
+  test("Verify recursive.fields.max.depth Levels 1,2, and 3 with Simple Schema") {
     val eventPerson3 = EventPerson.newBuilder().setName("person3").build()
     val eventPerson2 = EventPerson.newBuilder().setName("person2").setBff(eventPerson3).build()
     val eventPerson1 = EventPerson.newBuilder().setName("person1").setBff(eventPerson2).build()
@@ -1136,7 +1100,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val df = Seq(eventPerson0.toByteArray).toDF("value")
 
     val optionsZero = new java.util.HashMap[String, String]()
-    optionsZero.put("recursive.fields.max.depth", "0")
+    optionsZero.put("recursive.fields.max.depth", "1")
     val schemaZero = DataType.fromJson(
         s"""{
            |  "type" : "struct",
@@ -1160,10 +1124,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val expectedDfZero = spark.createDataFrame(
       spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), schemaZero)
 
-    testFromProtobufWithOptions(df, expectedDfZero, optionsZero)
+    testFromProtobufWithOptions(df, expectedDfZero, optionsZero, "EventPerson")
 
     val optionsOne = new java.util.HashMap[String, String]()
-    optionsOne.put("recursive.fields.max.depth", "1")
+    optionsOne.put("recursive.fields.max.depth", "2")
     val schemaOne = DataType.fromJson(
       s"""{
          |  "type" : "struct",
@@ -1197,10 +1161,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
          |}""".stripMargin).asInstanceOf[StructType]
     val expectedDfOne = spark.createDataFrame(
       spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", null))))), schemaOne)
-    testFromProtobufWithOptions(df, expectedDfOne, optionsOne)
+    testFromProtobufWithOptions(df, expectedDfOne, optionsOne, "EventPerson")
 
     val optionsTwo = new java.util.HashMap[String, String]()
-    optionsTwo.put("recursive.fields.max.depth", "2")
+    optionsTwo.put("recursive.fields.max.depth", "3")
     val schemaTwo = DataType.fromJson(
       s"""{
          |  "type" : "struct",
@@ -1245,7 +1209,60 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
          |}""".stripMargin).asInstanceOf[StructType]
     val expectedDfTwo = spark.createDataFrame(spark.sparkContext.parallelize(
       Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), schemaTwo)
-    testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo)
+    testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo, "EventPerson")
+
+    // Test recursion level 1 with EventPersonWrapper. Unlike 'EventPerson', the top-level
+    // struct 'EventPersonWrapper' itself does not recurse.
+    // "bff" appears twice: once as an allowed recursion, and then as the terminated "null" type.
+    val wrapperSchemaOne = DataType.fromJson(
+      """
+        |{
+        |  "type" : "struct",
+        |  "fields" : [ {
+        |    "name" : "sample",
+        |    "type" : {
+        |      "type" : "struct",
+        |      "fields" : [ {
+        |        "name" : "person",
+        |        "type" : {
+        |          "type" : "struct",
+        |          "fields" : [ {
+        |            "name" : "name",
+        |            "type" : "string",
+        |            "nullable" : true
+        |          }, {
+        |            "name" : "bff",
+        |            "type" : {
+        |              "type" : "struct",
+        |              "fields" : [ {
+        |                "name" : "name",
+        |                "type" : "string",
+        |                "nullable" : true
+        |              }, {
+        |                "name" : "bff",
+        |                "type" : "void",
+        |                "nullable" : true
+        |              } ]
+        |            },
+        |            "nullable" : true
+        |          } ]
+        |        },
+        |        "nullable" : true
+        |      } ]
+        |    },
+        |    "nullable" : true
+        |  } ]
+        |}
+        |""".stripMargin).asInstanceOf[StructType]
+    val expectedWrapperDfOne = spark.createDataFrame(
+      spark.sparkContext.parallelize(Seq(Row(Row(Row("person0", Row("person1", null)))))),
+      wrapperSchemaOne)
+    testFromProtobufWithOptions(
+      Seq(EventPersonWrapper.newBuilder().setPerson(eventPerson0).build().toByteArray).toDF(),
+      expectedWrapperDfOne,
+      optionsOne,
+      "EventPersonWrapper"
+    )
   }
 
   test("Verify exceptions are correctly propagated with errors") {
@@ -1273,9 +1290,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
   def testFromProtobufWithOptions(
     df: DataFrame,
     expectedDf: DataFrame,
-    options: java.util.HashMap[String, String]): Unit = {
+    options: java.util.HashMap[String, String],
+    messageName: String): Unit = {
     val fromProtoDf = df.select(
-      functions.from_protobuf($"value", "EventPerson", testFileDesc, options) as 'sample)
+      functions.from_protobuf($"value", messageName, testFileDesc, options) as 'sample)
     assert(expectedDf.schema === fromProtoDf.schema)
     checkAnswer(fromProtoDf, expectedDf)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org