Posted to commits@hudi.apache.org by da...@apache.org on 2022/06/30 12:48:58 UTC

[hudi] branch master updated: [HUDI-4285] add ByteBuffer#rewind after ByteBuffer#get in AvroDeseria… (#5907)

This is an automated email from the ASF dual-hosted git repository.

danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 8547899a39 [HUDI-4285] add ByteBuffer#rewind after ByteBuffer#get in AvroDeseria… (#5907)
8547899a39 is described below

commit 8547899a39168c8399d32ee7ee22f35dfe3f7c84
Author: komao <ma...@gmail.com>
AuthorDate: Thu Jun 30 20:48:50 2022 +0800

    [HUDI-4285] add ByteBuffer#rewind after ByteBuffer#get in AvroDeseria… (#5907)
    
    * [HUDI-4285] add ByteBuffer#rewind after ByteBuffer#get in AvroDeserializer
    
    * add a unit test
    
    Co-authored-by: wangzixuan.wzxuan <wa...@bytedance.com>
---
 .../org/apache/hudi/TestAvroConversionUtils.scala  | 57 +++++++++++++++++++++-
 .../apache/spark/sql/avro/AvroDeserializer.scala   |  2 +
 .../apache/spark/sql/avro/AvroDeserializer.scala   |  2 +
 .../apache/spark/sql/avro/AvroDeserializer.scala   |  2 +
 4 files changed, 62 insertions(+), 1 deletion(-)
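
Context for the change, as a minimal standalone sketch (not part of the commit): ByteBuffer#get is a relative read that advances the buffer's position, so converting the same Avro record a second time would find remaining() == 0 and produce an empty byte array. Rewinding after the read restores the position:

    import java.nio.ByteBuffer

    object RewindDemo extends App {
      val buf = ByteBuffer.wrap(Array[Byte](97, 48, 53))

      // First read consumes the buffer: the position advances to the limit.
      val first = new Array[Byte](buf.remaining())
      buf.get(first)                    // first == Array(97, 48, 53)

      // Without a rewind, a second read sees an exhausted buffer.
      val empty = new Array[Byte](buf.remaining())
      buf.get(empty)                    // empty.length == 0

      // rewind() resets the position to 0 so the bytes can be read again.
      buf.rewind()
      val second = new Array[Byte](buf.remaining())
      buf.get(second)                   // second == Array(97, 48, 53)
    }

The new test below provokes exactly this by converting the same record twice.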

diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala
index bacd44753d..16df1f869c 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala
@@ -18,8 +18,13 @@
 
 package org.apache.hudi
 
+import java.nio.ByteBuffer
+import java.util.Objects
 import org.apache.avro.Schema
-import org.apache.spark.sql.types.{DataTypes, StructType, StringType, ArrayType}
+import org.apache.avro.generic.GenericData
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DataTypes, MapType, StringType, StructField, StructType}
 import org.scalatest.{FunSuite, Matchers}
 
 class TestAvroConversionUtils extends FunSuite with Matchers {
@@ -377,4 +382,54 @@ class TestAvroConversionUtils extends FunSuite with Matchers {
 
     assert(avroSchema.equals(expectedAvroSchema))
   }
+
+  test("test converter with binary") {
+    val avroSchema = new Schema.Parser().parse("{\"type\":\"record\",\"name\":\"h0_record\",\"namespace\":\"hoodie.h0\",\"fields\""
+      + ":[{\"name\":\"col9\",\"type\":[\"null\",\"bytes\"],\"default\":null}]}")
+    val sparkSchema = StructType(List(StructField("col9", BinaryType, nullable = true)))
+    // create a test record whose bytes field is backed by a ByteBuffer; converting it twice exercises the rewind
+    val avroRecord = new GenericData.Record(avroSchema)
+    val bb = ByteBuffer.wrap(Array[Byte](97, 48, 53))
+    avroRecord.put("col9", bb)
+    val row1 = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema).apply(avroRecord).get
+    val row2 = AvroConversionUtils.createAvroToInternalRowConverter(avroSchema, sparkSchema).apply(avroRecord).get
+    internalRowCompare(row1, row2, sparkSchema)
+  }
+
+  private def internalRowCompare(expected: Any, actual: Any, schema: DataType): Unit = {
+    schema match {
+      case StructType(fields) =>
+        val expectedRow = expected.asInstanceOf[InternalRow]
+        val actualRow = actual.asInstanceOf[InternalRow]
+        fields.zipWithIndex.foreach { case (field, i) => internalRowCompare(expectedRow.get(i, field.dataType), actualRow.get(i, field.dataType), field.dataType) }
+      case ArrayType(elementType, _) =>
+        val expectedArray = expected.asInstanceOf[ArrayData].toSeq[Any](elementType)
+        val actualArray = actual.asInstanceOf[ArrayData].toSeq[Any](elementType)
+        if (expectedArray.size != actualArray.size) {
+          throw new AssertionError(s"array sizes differ: ${expectedArray.size} vs ${actualArray.size}")
+        } else {
+          expectedArray.zip(actualArray).foreach { case (e1, e2) => internalRowCompare(e1, e2, elementType) }
+        }
+      case MapType(keyType, valueType, _) =>
+        val expectedKeyArray = expected.asInstanceOf[MapData].keyArray()
+        val expectedValueArray = expected.asInstanceOf[MapData].valueArray()
+        val actualKeyArray = actual.asInstanceOf[MapData].keyArray()
+        val actualValueArray = actual.asInstanceOf[MapData].valueArray()
+        internalRowCompare(expectedKeyArray, actualKeyArray, ArrayType(keyType))
+        internalRowCompare(expectedValueArray, actualValueArray, ArrayType(valueType))
+      case StringType => if (checkNull(expected, actual) || !expected.toString.equals(actual.toString)) {
+        throw new AssertionError(String.format("%s is not equals %s", expected.toString, actual.toString))
+      }
+      case BinaryType => if (checkNull(expected, actual) || !expected.asInstanceOf[Array[Byte]].sameElements(actual.asInstanceOf[Array[Byte]])) {
+        throw new AssertionError(String.format("%s is not equals %s", expected.toString, actual.toString))
+      }
+      case _ => if (!Objects.equals(expected, actual)) {
+        throw new AssertionError(String.format("%s is not equals %s", expected.toString, actual.toString))
+      }
+    }
+  }
+
+  private def checkNull(left: Any, right: Any): Boolean = {
+    (left == null && right != null) || (left == null && right != null)
+  }
 }
diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 2e0946f1eb..385577dd30 100644
--- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -146,6 +146,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
           case b: ByteBuffer =>
             val bytes = new Array[Byte](b.remaining)
             b.get(bytes)
+            // ByteBuffer#get advances the position; rewind it so the buffer can be read again
+            b.rewind()
             bytes
           case b: Array[Byte] => b
           case other => throw new RuntimeException(s"$other is not a valid avro binary.")
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 717df0f407..5fb6d907bd 100644
--- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -167,6 +167,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema,
           case b: ByteBuffer =>
             val bytes = new Array[Byte](b.remaining)
             b.get(bytes)
+            // ByteBuffer#get advances the position; rewind it so the buffer can be read again
+            b.rewind()
             bytes
           case b: Array[Byte] => b
           case other => throw new RuntimeException(s"$other is not a valid avro binary.")
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index ef9b590920..0b60933075 100644
--- a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -181,6 +181,8 @@ private[sql] class AvroDeserializer(rootAvroType: Schema,
           case b: ByteBuffer =>
             val bytes = new Array[Byte](b.remaining)
             b.get(bytes)
+            // ByteBuffer#get advances the position; rewind it so the buffer can be read again
+            b.rewind()
             bytes
           case b: Array[Byte] => b
           case other =>
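
As an aside (not part of this commit), a non-mutating alternative would be to read through ByteBuffer#duplicate, which shares the underlying bytes but keeps an independent position, so the source buffer never needs a rewind:

    import java.nio.ByteBuffer

    // Hypothetical alternative sketch: read via a duplicate view so the
    // source buffer's position never moves and no rewind is required.
    def copyBytes(b: ByteBuffer): Array[Byte] = {
      val view = b.duplicate()
      val bytes = new Array[Byte](view.remaining())
      view.get(bytes)
      bytes
    }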