Posted to commits@iceberg.apache.org by bl...@apache.org on 2021/11/01 19:47:24 UTC

[iceberg] 04/06: Spark: Fix ClassCastException when using bucket UDF (#3368)

This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch 0.12.x
in repository https://gitbox.apache.org/repos/asf/iceberg.git

commit 4ba2157620d1bbc19c968e69ad43849f3706ba16
Author: Chen Zhang <67...@users.noreply.github.com>
AuthorDate: Wed Oct 27 07:46:09 2021 +0800

    Spark: Fix ClassCastException when using bucket UDF (#3368)
---
 .../org/apache/iceberg/spark/IcebergSpark.java     |   3 +-
 .../apache/iceberg/spark/SparkValueConverter.java  |   3 +-
 .../iceberg/spark/source/TestIcebergSpark.java     | 132 +++++++++++++++++++--
 3 files changed, 129 insertions(+), 9 deletions(-)
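
Context for the change: registerBucketUDF previously passed the raw value
Spark hands the UDF straight into the bucket transform. The transform
returned by Transforms.bucket casts its argument to the Java type of the
Iceberg source type, so a SHORT or BYTE column (surfaced by Spark as
java.lang.Short / java.lang.Byte) threw ClassCastException, and DATE,
TIMESTAMP and BINARY values arrived as java.sql.Date, java.sql.Timestamp
and byte[] rather than the internal representations the transforms expect.
A minimal reproduction sketch (the class name is made up; the APIs are the
ones touched by this patch):

    import org.apache.iceberg.spark.IcebergSpark;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;

    public class BucketUdfRepro {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[2]").getOrCreate();
        // Register a 16-way bucket UDF over a SHORT source type.
        IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16);
        // Before this commit:
        //   java.lang.ClassCastException: java.lang.Short cannot be cast to java.lang.Integer
        spark.sql("SELECT iceberg_bucket_short_16(1S)").show();
      }
    }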

diff --git a/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java b/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java
index ac659f6..862626d 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java
@@ -34,6 +34,7 @@ public class IcebergSpark {
     SparkTypeToType typeConverter = new SparkTypeToType();
     Type sourceIcebergType = typeConverter.atomic(sourceType);
     Transform<Object, Integer> bucket = Transforms.bucket(sourceIcebergType, numBuckets);
-    session.udf().register(funcName, bucket::apply, DataTypes.IntegerType);
+    session.udf().register(funcName,
+        value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), DataTypes.IntegerType);
   }
 }
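
The registered function now normalizes each incoming value with
SparkValueConverter.convert before applying the transform. A standalone
sketch of the failure mode this avoids, using the same Transform API as the
patch (the class name is hypothetical):

    import org.apache.iceberg.transforms.Transform;
    import org.apache.iceberg.transforms.Transforms;
    import org.apache.iceberg.types.Types;

    public class WidenBeforeBucket {
      public static void main(String[] args) {
        Transform<Object, Integer> bucket = Transforms.bucket(Types.IntegerType.get(), 16);
        Object sparkValue = (short) 1;  // what Spark hands a UDF over a SHORT column
        // bucket.apply(sparkValue);    // would throw: the transform casts to Integer
        int bucketId = bucket.apply(((Number) sparkValue).intValue());  // converted path
        System.out.println(bucketId);
      }
    }
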
diff --git a/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java b/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
index 92c812a..ef453c0 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
@@ -79,8 +79,9 @@ public class SparkValueConverter {
         return DateTimeUtils.fromJavaTimestamp((Timestamp) object);
       case BINARY:
         return ByteBuffer.wrap((byte[]) object);
-      case BOOLEAN:
       case INTEGER:
+        return ((Number) object).intValue();
+      case BOOLEAN:
       case LONG:
       case FLOAT:
       case DOUBLE:
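
Only the INTEGER case changes here: SparkTypeToType maps Spark's ByteType,
ShortType and IntegerType all to Iceberg's int type (the short and byte
tests below compare against an int bucket for exactly that reason), so this
branch can receive a Byte, Short or Integer and must widen through Number
instead of falling through and returning the value unchanged. A minimal
standalone sketch of the new behavior (class and method names are made up):

    public class IntWidening {
      static Integer toInt(Object object) {
        // Widen any integral Number to Integer rather than casting.
        return object == null ? null : ((Number) object).intValue();
      }

      public static void main(String[] args) {
        System.out.println(toInt((byte) 1));   // 1
        System.out.println(toInt((short) 1));  // 1
        System.out.println(toInt(1));          // 1
        // The old fall-through returned the value as-is, so a Short or Byte
        // later failed the bucket transform's cast to Integer.
      }
    }
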
diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java
index 14785d7..d85b114 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java
@@ -19,13 +19,22 @@
 
 package org.apache.iceberg.spark.source;
 
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.sql.Date;
+import java.sql.Timestamp;
 import java.util.List;
 import org.apache.iceberg.spark.IcebergSpark;
 import org.apache.iceberg.transforms.Transforms;
 import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.types.CharType;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.VarcharType;
+import org.assertj.core.api.Assertions;
 import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
@@ -48,23 +57,132 @@ public abstract class TestIcebergSpark {
   }
 
   @Test
-  public void testRegisterBucketUDF() {
+  public void testRegisterIntegerBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16);
     List<Row> results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList();
     Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
         results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterShortBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterByteBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
+        results.get(0).getInt(0));
+  }
 
+  @Test
+  public void testRegisterLongBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16);
-    List<Row> results2 = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
-    Assert.assertEquals(1, results2.size());
+    List<Row> results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
+    Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.LongType.get(), 16).apply(1L),
-        results2.get(0).getInt(0));
+        results.get(0).getInt(0));
+  }
 
+  @Test
+  public void testRegisterStringBucketUDF() {
     IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16);
-    List<Row> results3 = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
-    Assert.assertEquals(1, results3.size());
+    List<Row> results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterCharBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterVarCharBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList();
+    Assert.assertEquals(1, results.size());
     Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
-        results3.get(0).getInt(0));
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterDateBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16);
+    List<Row> results = spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.DateType.get(), 16)
+            .apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30"))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterTimestampBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.TimestampType.withZone(), 16)
+            .apply(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000"))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterBinaryBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.BinaryType.get(), 16)
+            .apply(ByteBuffer.wrap((new byte[]{0x00, 0x20, 0x00, 0x1F}))),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterDecimalBucketUDF() {
+    IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16);
+    List<Row> results =
+        spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList();
+    Assert.assertEquals(1, results.size());
+    Assert.assertEquals((int) Transforms.bucket(Types.DecimalType.of(4, 2), 16)
+            .apply(new BigDecimal("11.11")),
+        results.get(0).getInt(0));
+  }
+
+  @Test
+  public void testRegisterBooleanBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+            IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: boolean");
+  }
+
+  @Test
+  public void testRegisterDoubleBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+            IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: double");
+  }
+
+  @Test
+  public void testRegisterFloatBucketUDF() {
+    Assertions.assertThatThrownBy(() ->
+            IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Cannot bucket by type: float");
   }
 }