Posted to commits@iceberg.apache.org by et...@apache.org on 2023/04/26 07:13:36 UTC

[iceberg] 01/03: Spark: Add read/write support for UUIDs

This is an automated email from the ASF dual-hosted git repository.

etudenhoefner pushed a commit to branch spark-uuid-read-write-support-3.4
in repository https://gitbox.apache.org/repos/asf/iceberg.git

commit 07ddc3c777b7dccda24098a9918e3746a27bc500
Author: Eduard Tudenhoefner <et...@gmail.com>
AuthorDate: Fri Apr 21 09:18:32 2023 +0200

    Spark: Add read/write support for UUIDs
---
 .../apache/iceberg/spark/data/SparkAvroWriter.java |  2 +-
 .../apache/iceberg/spark/data/SparkOrcReader.java  |  3 ++
 .../iceberg/spark/data/SparkOrcValueReaders.java   | 32 +++++++++++++++++++
 .../iceberg/spark/data/SparkOrcValueWriters.java   | 17 ++++++++++
 .../apache/iceberg/spark/data/SparkOrcWriter.java  | 11 ++++++-
 .../iceberg/spark/data/SparkParquetReaders.java    | 17 ++++++++++
 .../iceberg/spark/data/SparkParquetWriters.java    | 36 ++++++++++++++++++++++
 .../data/vectorized/VectorizedSparkOrcReaders.java |  5 ++-
 .../apache/iceberg/spark/data/AvroDataTest.java    |  2 +-
 .../org/apache/iceberg/spark/data/RandomData.java  |  2 ++
 .../iceberg/spark/data/TestSparkParquetWriter.java |  2 +-
 11 files changed, 124 insertions(+), 5 deletions(-)
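
For context on the conversions this patch adds (not code from the commit, just a self-contained sketch using only java.util.UUID and java.nio.ByteBuffer): Iceberg stores its UUID type as 16 big-endian bytes, while Spark has no UUID type and carries the value as a string (UTF8String), so each reader/writer below translates between the two representations. The real code delegates to org.apache.iceberg.util.UUIDUtil; the class and method names in this sketch are illustrative only.

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.util.UUID;

    // Illustrative sketch of the UUID <-> 16-byte conversion used by the new
    // readers and writers. The patch itself relies on org.apache.iceberg.util.UUIDUtil.
    public class UuidBytesSketch {

      // Encode a UUID string (as Spark hands it over) into the 16 big-endian
      // bytes that Iceberg stores in Avro/ORC/Parquet.
      static byte[] toBytes(String uuidString) {
        UUID uuid = UUID.fromString(uuidString);
        ByteBuffer buffer = ByteBuffer.allocate(16).order(ByteOrder.BIG_ENDIAN);
        buffer.putLong(uuid.getMostSignificantBits());
        buffer.putLong(uuid.getLeastSignificantBits());
        return buffer.array();
      }

      // Decode the 16 stored bytes back into the canonical string form that the
      // Spark readers return as UTF8String.
      static String fromBytes(byte[] bytes) {
        ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN);
        return new UUID(buffer.getLong(), buffer.getLong()).toString();
      }

      public static void main(String[] args) {
        String original = "123e4567-e89b-12d3-a456-426614174000";
        String roundTripped = fromBytes(toBytes(original));
        System.out.println(original.equals(roundTripped)); // true
      }
    }
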

diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java
index 15465568c2..04dfd46a18 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java
@@ -126,7 +126,7 @@ public class SparkAvroWriter implements MetricsAwareDatumWriter<InternalRow> {
             return SparkValueWriters.decimal(decimal.getPrecision(), decimal.getScale());
 
           case "uuid":
-            return ValueWriters.uuids();
+            return SparkValueWriters.uuids();
 
           default:
             throw new IllegalArgumentException("Unsupported logical type: " + logicalType);
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java
index 78db137054..c20be44f67 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java
@@ -123,6 +123,9 @@ public class SparkOrcReader implements OrcRowReader<InternalRow> {
         case STRING:
           return SparkOrcValueReaders.utf8String();
         case BINARY:
+          if (Type.TypeID.UUID == iPrimitive.typeId()) {
+            return SparkOrcValueReaders.uuids();
+          }
           return OrcValueReaders.bytes();
         default:
           throw new IllegalArgumentException("Unhandled type " + primitive);
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
index 9e9b3e53bb..2bc5ef96a3 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
@@ -19,6 +19,8 @@
 package org.apache.iceberg.spark.data;
 
 import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.util.List;
 import java.util.Map;
 import org.apache.iceberg.orc.OrcValueReader;
@@ -26,6 +28,7 @@ import org.apache.iceberg.orc.OrcValueReaders;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.UUIDUtil;
 import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
 import org.apache.orc.storage.ql.exec.vector.ColumnVector;
 import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector;
@@ -49,6 +52,10 @@ public class SparkOrcValueReaders {
     return StringReader.INSTANCE;
   }
 
+  public static OrcValueReader<UTF8String> uuids() {
+    return UUIDReader.INSTANCE;
+  }
+
   public static OrcValueReader<Long> timestampTzs() {
     return TimestampTzReader.INSTANCE;
   }
@@ -170,6 +177,31 @@ public class SparkOrcValueReaders {
     }
   }
 
+  private static class UUIDReader implements OrcValueReader<UTF8String> {
+    private static final ThreadLocal<ByteBuffer> BUFFER =
+        ThreadLocal.withInitial(
+            () -> {
+              ByteBuffer buffer = ByteBuffer.allocate(16);
+              buffer.order(ByteOrder.BIG_ENDIAN);
+              return buffer;
+            });
+
+    private static final UUIDReader INSTANCE = new UUIDReader();
+
+    private UUIDReader() {}
+
+    @Override
+    public UTF8String nonNullRead(ColumnVector vector, int row) {
+      BytesColumnVector bytesVector = (BytesColumnVector) vector;
+      ByteBuffer buffer = BUFFER.get();
+      buffer.rewind();
+      buffer.put(bytesVector.vector[row], bytesVector.start[row], bytesVector.length[row]);
+      buffer.rewind();
+
+      return UTF8String.fromString(UUIDUtil.convert(buffer).toString());
+    }
+  }
+
   private static class TimestampTzReader implements OrcValueReader<Long> {
     private static final TimestampTzReader INSTANCE = new TimestampTzReader();
 
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java
index 780090f991..9a4f1b5b48 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java
@@ -18,10 +18,13 @@
  */
 package org.apache.iceberg.spark.data;
 
+import java.nio.ByteBuffer;
 import java.util.List;
+import java.util.UUID;
 import java.util.stream.Stream;
 import org.apache.iceberg.FieldMetrics;
 import org.apache.iceberg.orc.OrcValueWriter;
+import org.apache.iceberg.util.UUIDUtil;
 import org.apache.orc.TypeDescription;
 import org.apache.orc.storage.common.type.HiveDecimal;
 import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
@@ -42,6 +45,10 @@ class SparkOrcValueWriters {
     return StringWriter.INSTANCE;
   }
 
+  static OrcValueWriter<?> uuids() {
+    return UUIDWriter.INSTANCE;
+  }
+
   static OrcValueWriter<?> timestampTz() {
     return TimestampTzWriter.INSTANCE;
   }
@@ -73,6 +80,16 @@ class SparkOrcValueWriters {
     }
   }
 
+  private static class UUIDWriter implements OrcValueWriter<UTF8String> {
+    private static final UUIDWriter INSTANCE = new UUIDWriter();
+
+    @Override
+    public void nonNullWrite(int rowId, UTF8String data, ColumnVector output) {
+      ByteBuffer buffer = UUIDUtil.convertToByteBuffer(UUID.fromString(data.toString()));
+      ((BytesColumnVector) output).setRef(rowId, buffer.array(), 0, buffer.array().length);
+    }
+  }
+
   private static class TimestampTzWriter implements OrcValueWriter<Long> {
     private static final TimestampTzWriter INSTANCE = new TimestampTzWriter();
 
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java
index 60868b8700..c5477fac08 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java
@@ -111,6 +111,9 @@ public class SparkOrcWriter implements OrcRowWriter<InternalRow> {
         case DOUBLE:
           return GenericOrcWriters.doubles(ORCSchemaUtil.fieldId(primitive));
         case BINARY:
+          if (Type.TypeID.UUID == iPrimitive.typeId()) {
+            return SparkOrcValueWriters.uuids();
+          }
           return GenericOrcWriters.byteArrays();
         case STRING:
         case CHAR:
@@ -173,7 +176,13 @@ public class SparkOrcWriter implements OrcRowWriter<InternalRow> {
         fieldGetter = SpecializedGetters::getDouble;
         break;
       case BINARY:
-        fieldGetter = SpecializedGetters::getBinary;
+        if (ORCSchemaUtil.BinaryType.UUID
+            .toString()
+            .equals(fieldType.getAttributeValue(ORCSchemaUtil.ICEBERG_BINARY_TYPE_ATTRIBUTE))) {
+          fieldGetter = SpecializedGetters::getUTF8String;
+        } else {
+          fieldGetter = SpecializedGetters::getBinary;
+        }
         // getBinary always makes a copy, so we don't need to worry about it
         // being changed behind our back.
         break;
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
index 59f81de6ae..af16d9bbc2 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
@@ -46,6 +46,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.types.Type.TypeID;
 import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.UUIDUtil;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.schema.GroupType;
@@ -232,6 +233,7 @@ public class SparkParquetReaders {
     }
 
     @Override
+    @SuppressWarnings("checkstyle:CyclomaticComplexity")
     public ParquetValueReader<?> primitive(
         org.apache.iceberg.types.Type.PrimitiveType expected, PrimitiveType primitive) {
       ColumnDescriptor desc = type.getColumnDescription(currentPath());
@@ -282,6 +284,9 @@ public class SparkParquetReaders {
       switch (primitive.getPrimitiveTypeName()) {
         case FIXED_LEN_BYTE_ARRAY:
         case BINARY:
+          if (expected != null && expected.typeId() == TypeID.UUID) {
+            return new UUIDReader(desc);
+          }
           return new ParquetValueReaders.ByteArrayReader(desc);
         case INT32:
           if (expected != null && expected.typeId() == TypeID.LONG) {
@@ -413,6 +418,18 @@ public class SparkParquetReaders {
     }
   }
 
+  private static class UUIDReader extends PrimitiveReader<UTF8String> {
+    UUIDReader(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    @SuppressWarnings("ByteBufferBackingArray")
+    public UTF8String read(UTF8String ignored) {
+      return UTF8String.fromString(UUIDUtil.convert(column.nextBinary().toByteBuffer()).toString());
+    }
+  }
+
   private static class ArrayReader<E> extends RepeatedReader<ArrayData, ReusableArrayData, E> {
     private int readPos = 0;
     private int writePos = 0;
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
index 3637fa4a26..c1abec96cd 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java
@@ -18,10 +18,13 @@
  */
 package org.apache.iceberg.spark.data;
 
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
+import java.util.UUID;
 import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry;
 import org.apache.iceberg.parquet.ParquetValueWriter;
 import org.apache.iceberg.parquet.ParquetValueWriters;
@@ -35,6 +38,7 @@ import org.apache.iceberg.util.DecimalUtil;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
 import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
 import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.PrimitiveType;
@@ -176,6 +180,9 @@ public class SparkParquetWriters {
       switch (primitive.getPrimitiveTypeName()) {
         case FIXED_LEN_BYTE_ARRAY:
         case BINARY:
+          if (LogicalTypeAnnotation.uuidType().equals(primitive.getLogicalTypeAnnotation())) {
+            return uuids(desc);
+          }
           return byteArrays(desc);
         case BOOLEAN:
           return ParquetValueWriters.booleans(desc);
@@ -316,6 +323,35 @@ public class SparkParquetWriters {
     }
   }
 
+  private static PrimitiveWriter<UTF8String> uuids(ColumnDescriptor desc) {
+    return new UUIDWriter(desc);
+  }
+
+  private static class UUIDWriter extends PrimitiveWriter<UTF8String> {
+    private static final ThreadLocal<ByteBuffer> BUFFER =
+        ThreadLocal.withInitial(
+            () -> {
+              ByteBuffer buffer = ByteBuffer.allocate(16);
+              buffer.order(ByteOrder.BIG_ENDIAN);
+              return buffer;
+            });
+
+    private UUIDWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, UTF8String string) {
+      UUID uuid = UUID.fromString(string.toString());
+      ByteBuffer buffer = BUFFER.get();
+      buffer.rewind();
+      buffer.putLong(uuid.getMostSignificantBits());
+      buffer.putLong(uuid.getLeastSignificantBits());
+      buffer.rewind();
+      column.writeBinary(repetitionLevel, Binary.fromReusedByteBuffer(buffer));
+    }
+  }
+
   private static class ByteArrayWriter extends PrimitiveWriter<byte[]> {
     private ByteArrayWriter(ColumnDescriptor desc) {
       super(desc);
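
A note on the reused-buffer pattern in the UUIDWriter above (a sketch rather than the patch's code; it assumes each writer instance is driven by one thread at a time, which is why a ThreadLocal suffices): the 16-byte buffer is rewound before the two putLong calls and again before hand-off, so the consumer sees position 0 and limit 16 without a fresh allocation per value.

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.util.UUID;

    public class ReusedBufferSketch {
      private static final ThreadLocal<ByteBuffer> BUFFER =
          ThreadLocal.withInitial(() -> ByteBuffer.allocate(16).order(ByteOrder.BIG_ENDIAN));

      // Fill the per-thread buffer with the UUID bits and return it ready to read.
      static ByteBuffer encode(UUID uuid) {
        ByteBuffer buffer = BUFFER.get();
        buffer.rewind();                     // position back to 0 before writing
        buffer.putLong(uuid.getMostSignificantBits());
        buffer.putLong(uuid.getLeastSignificantBits());
        buffer.rewind();                     // position back to 0 so the consumer reads all 16 bytes
        return buffer;
      }

      public static void main(String[] args) {
        ByteBuffer first = encode(UUID.fromString("123e4567-e89b-12d3-a456-426614174000"));
        ByteBuffer second = encode(UUID.randomUUID());
        System.out.println(first == second); // true: the same buffer instance is reused per thread
      }
    }
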
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java
index b2d8bd14be..c030311232 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java
@@ -155,7 +155,10 @@ public class VectorizedSparkOrcReaders {
           primitiveValueReader = SparkOrcValueReaders.utf8String();
           break;
         case BINARY:
-          primitiveValueReader = OrcValueReaders.bytes();
+          primitiveValueReader =
+              Type.TypeID.UUID == iPrimitive.typeId()
+                  ? SparkOrcValueReaders.uuids()
+                  : OrcValueReaders.bytes();
           break;
         default:
           throw new IllegalArgumentException("Unhandled type " + primitive);
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java
index 5fd137c536..db0d7336f1 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java
@@ -56,7 +56,7 @@ public abstract class AvroDataTest {
           optional(107, "date", Types.DateType.get()),
           required(108, "ts", Types.TimestampType.withZone()),
           required(110, "s", Types.StringType.get()),
-          // required(111, "uuid", Types.UUIDType.get()),
+          required(111, "uuid", Types.UUIDType.get()),
           required(112, "fixed", Types.FixedType.ofLength(7)),
           optional(113, "bytes", Types.BinaryType.get()),
           required(114, "dec_9_0", Types.DecimalType.of(9, 0)), // int encoded
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java
index 1c95df8ced..478afcf09a 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java
@@ -329,6 +329,8 @@ public class RandomData {
           return UTF8String.fromString((String) obj);
         case DECIMAL:
           return Decimal.apply((BigDecimal) obj);
+        case UUID:
+          return UTF8String.fromString(UUID.nameUUIDFromBytes((byte[]) obj).toString());
         default:
           return obj;
       }
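
The RandomData change above relies on UUID.nameUUIDFromBytes, which is deterministic, so the random byte[] produced by the generic generator always maps to the same UUID string in the expected Spark rows. A small standalone sketch (not from the patch) of that behaviour:

    import java.util.UUID;

    public class NameUuidSketch {
      public static void main(String[] args) {
        // nameUUIDFromBytes derives a type-3 (MD5, name-based) UUID from arbitrary bytes.
        byte[] raw = new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
        UUID first = UUID.nameUUIDFromBytes(raw);
        UUID second = UUID.nameUUIDFromBytes(raw);
        System.out.println(first.equals(second)); // true: same bytes, same UUID
      }
    }
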
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java
index 261fb8838a..467d8a27a2 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java
@@ -79,7 +79,7 @@ public class TestSparkParquetWriter {
                   Types.StringType.get(),
                   Types.StructType.of(
                       optional(22, "jumpy", Types.DoubleType.get()),
-                      required(23, "koala", Types.IntegerType.get()),
+                      required(23, "koala", Types.UUIDType.get()),
                       required(24, "couch rope", Types.IntegerType.get())))),
           optional(2, "slide", Types.StringType.get()));