Posted to commits@seatunnel.apache.org by zo...@apache.org on 2022/09/18 07:40:16 UTC

[incubator-seatunnel] branch dev updated: [Bug][Core] Fix the bug that can not convert array and map (#2750)

This is an automated email from the ASF dual-hosted git repository.

zongwen pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/incubator-seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 6db4d7595 [Bug][Core] Fix the bug that can not convert array and map (#2750)
6db4d7595 is described below

commit 6db4d7595d68e6ebbe8f436258f9948b9f930ca3
Author: TyrantLucifer <Ty...@gmail.com>
AuthorDate: Sun Sep 18 15:40:11 2022 +0800

    [Bug][Core] Fix the bug that can not convert array and map (#2750)
    
    * [Bug][Core] Fix the bug that can not convert array and map
    
    * [Hotfix][SeaTunnel-Schema] Fix the bug that can not parse decimal
    
    * [Bug][Connector-V2] Fix the bug that can not convert map and array correctly
    
    * [Bug][Core] Optimize code
---
 .../seatunnel/common/schema/SeaTunnelSchema.java   | 14 +++-
 .../common/serialization/InternalRowConverter.java | 79 +++++++++++-----------
 .../spark/common/utils/TypeConverterUtils.java     |  2 +-
 3 files changed, 53 insertions(+), 42 deletions(-)
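
For context, a hedged illustration (not part of the commit) of the array problem addressed in InternalRowConverter below: Spark's catalyst layer stores strings as UTF8String, so a SeaTunnel string[] field has to be converted element by element before it is wrapped in ArrayData. The field value here is made up.

    // Illustrative sketch only: turning a SeaTunnel string array into Spark's internal ArrayData.
    // Assumes imports of java.util.Arrays, org.apache.spark.sql.catalyst.util.ArrayData and
    // org.apache.spark.unsafe.types.UTF8String, as in the patched converter.
    String[] field = new String[]{"foo", "bar"};                                   // hypothetical field value
    Object[] objects = Arrays.stream(field).map(UTF8String::fromString).toArray();
    ArrayData arrayData = ArrayData.toArrayData(objects);                          // what catalyst expects for string elements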

diff --git a/seatunnel-connectors-v2/connector-common/src/main/java/org/apache/seatunnel/connectors/seatunnel/common/schema/SeaTunnelSchema.java b/seatunnel-connectors-v2/connector-common/src/main/java/org/apache/seatunnel/connectors/seatunnel/common/schema/SeaTunnelSchema.java
index 76d3d76ec..9a9e035fc 100644
--- a/seatunnel-connectors-v2/connector-common/src/main/java/org/apache/seatunnel/connectors/seatunnel/common/schema/SeaTunnelSchema.java
+++ b/seatunnel-connectors-v2/connector-common/src/main/java/org/apache/seatunnel/connectors/seatunnel/common/schema/SeaTunnelSchema.java
@@ -54,7 +54,15 @@ public class SeaTunnelSchema implements Serializable {
                 .substring(start + 1, end)
                 // replace the space between key and value
                 .replace(" ", "");
-        int index = genericType.indexOf(",");
+        int index;
+        if (genericType.startsWith(SqlType.DECIMAL.name())) {
+            // if the map key type is decimal, find the index of the second ','
+            index = genericType.indexOf(",");
+            index = genericType.indexOf(",", index + 1);
+        } else {
+            // otherwise, find the index of the first ','
+            index = genericType.indexOf(",");
+        }
         String keyGenericType = genericType.substring(0, index);
         String valueGenericType = genericType.substring(index + 1);
         return new String[]{keyGenericType, valueGenericType};
@@ -102,12 +110,12 @@ public class SeaTunnelSchema implements Serializable {
         type = type.toUpperCase();
         if (type.contains("<") || type.contains(">")) {
             // Map type or Array type
-            if (type.contains(SqlType.MAP.name())) {
+            if (type.startsWith(SqlType.MAP.name())) {
                 String[] genericTypes = parseMapGeneric(type);
                 keyGenericType = genericTypes[0];
                 valueGenericType = genericTypes[1];
                 type = SqlType.MAP.name();
-            } else {
+            } else if (type.startsWith(SqlType.ARRAY.name())) {
                 genericType = parseArrayGeneric(type);
                 type = SqlType.ARRAY.name();
             }
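
A hedged walk-through of the first hunk above (values are illustrative): for a declared type like MAP<DECIMAL(30, 8), STRING>, the generic part with spaces stripped is "DECIMAL(30,8),STRING"; splitting at the first ',' would produce the invalid key type "DECIMAL(30", so the fix skips the comma inside the decimal's precision and scale.

    // Illustrative only: splitting the generic part of "MAP<DECIMAL(30, 8), STRING>".
    String genericType = "DECIMAL(30,8),STRING";                  // spaces already stripped by the caller
    int index = genericType.indexOf(",");                         // index 10 -> old key type "DECIMAL(30"
    index = genericType.indexOf(",", index + 1);                  // index 13 -> fixed key type
    String keyGenericType = genericType.substring(0, index);      // "DECIMAL(30,8)"
    String valueGenericType = genericType.substring(index + 1);   // "STRING"
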
diff --git a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/common/serialization/InternalRowConverter.java b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/common/serialization/InternalRowConverter.java
index 31673c595..70545bcd1 100644
--- a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/common/serialization/InternalRowConverter.java
+++ b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/common/serialization/InternalRowConverter.java
@@ -38,7 +38,9 @@ import org.apache.spark.sql.catalyst.expressions.MutableLong;
 import org.apache.spark.sql.catalyst.expressions.MutableShort;
 import org.apache.spark.sql.catalyst.expressions.MutableValue;
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow;
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData;
 import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.types.UTF8String;
 
@@ -47,9 +49,10 @@ import java.math.BigDecimal;
 import java.sql.Timestamp;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.function.BiFunction;
 
 public final class InternalRowConverter extends RowConverter<InternalRow> {
 
@@ -80,12 +83,19 @@ public final class InternalRowConverter extends RowConverter<InternalRow> {
             case TIMESTAMP:
                 return InstantConverterUtils.toEpochMicro(Timestamp.valueOf((LocalDateTime) field).toInstant());
             case MAP:
-                return convertMap((Map<?, ?>) field, (MapType<?, ?>) dataType, InternalRowConverter::convert);
+                return convertMap((Map<?, ?>) field, (MapType<?, ?>) dataType);
             case STRING:
                 return UTF8String.fromString((String) field);
             case DECIMAL:
                 return Decimal.apply((BigDecimal) field);
             case ARRAY:
+                // if it is a string array, convert every item from String to UTF8String
+                if (((ArrayType<?, ?>) dataType).getElementType().equals(BasicType.STRING_TYPE)) {
+                    String[] fields = (String[]) field;
+                    Object[] objects = Arrays.stream(fields).map(UTF8String::fromString).toArray();
+                    return ArrayData.toArrayData(objects);
+                }
+                // besides string, only boolean, tinyint, smallint, int, bigint, float and double are supported, because SeaTunnel Array only supports these element types
                 return ArrayData.toArrayData(field);
             default:
                 return field;
@@ -109,25 +119,35 @@ public final class InternalRowConverter extends RowConverter<InternalRow> {
         return new SpecificInternalRow(values);
     }
 
-    private static Object convertMap(Map<?, ?> mapData, MapType<?, ?> mapType, BiFunction<Object, SeaTunnelDataType<?>, Object> convertFunction) {
+    private static Object convertMap(Map<?, ?> mapData, MapType<?, ?> mapType) {
         if (mapData == null || mapData.size() == 0) {
-            return mapData;
+            return ArrayBasedMapData.apply(new Object[]{}, new Object[]{});
         }
-        switch (mapType.getValueType().getSqlType()) {
-            case MAP:
-            case ROW:
-            case DATE:
-            case TIME:
-            case TIMESTAMP:
-                Map<Object, Object> newMap = new HashMap<>(mapData.size());
-                mapData.forEach((key, value) -> {
-                    SeaTunnelDataType<?> valueType = mapType.getValueType();
-                    newMap.put(key, convertFunction.apply(value, valueType));
-                });
-                return newMap;
-            default:
-                return mapData;
+        SeaTunnelDataType<?> keyType = mapType.getKeyType();
+        SeaTunnelDataType<?> valueType = mapType.getValueType();
+        Map<Object, Object> newMap = new HashMap<>(mapData.size());
+        mapData.forEach((key, value) -> newMap.put(convert(key, keyType), convert(value, valueType)));
+        Object[] keys = newMap.keySet().toArray();
+        Object[] values = newMap.values().toArray();
+        return ArrayBasedMapData.apply(keys, values);
+    }
+
+    private static Object reconvertMap(MapData mapData, MapType<?, ?> mapType) {
+        if (mapData == null || mapData.numElements() == 0) {
+            return Collections.emptyMap();
         }
+        Map<Object, Object> newMap = new HashMap<>(mapData.numElements());
+        int num = mapData.numElements();
+        SeaTunnelDataType<?> keyType = mapType.getKeyType();
+        SeaTunnelDataType<?> valueType = mapType.getValueType();
+        Object[] keys = mapData.keyArray().toObjectArray(TypeConverterUtils.convert(keyType));
+        Object[] values = mapData.valueArray().toObjectArray(TypeConverterUtils.convert(valueType));
+        for (int i = 0; i < num; i++) {
+            keys[i] = reconvert(keys[i], keyType);
+            values[i] = reconvert(values[i], valueType);
+            newMap.put(keys[i], values[i]);
+        }
+        return newMap;
     }
 
     private static MutableValue createMutableValue(SeaTunnelDataType<?> dataType) {
@@ -170,11 +190,11 @@ public final class InternalRowConverter extends RowConverter<InternalRow> {
                 return LocalDate.ofEpochDay((int) field);
             case TIME:
                 // TODO: Support TIME Type
-                throw new RuntimeException("time type is not supported now, but will be supported in the future.");
+                throw new RuntimeException("SeaTunnel not support time type, it will be supported in the future.");
             case TIMESTAMP:
                 return Timestamp.from(InstantConverterUtils.ofEpochMicro((long) field)).toLocalDateTime();
             case MAP:
-                return convertMap((Map<?, ?>) field, (MapType<?, ?>) dataType, InternalRowConverter::reconvert);
+                return reconvertMap((MapData) field, (MapType<?, ?>) dataType);
             case STRING:
                 return field.toString();
             case DECIMAL:
@@ -182,24 +202,7 @@ public final class InternalRowConverter extends RowConverter<InternalRow> {
             case ARRAY:
                 ArrayData arrayData = (ArrayData) field;
                 BasicType<?> elementType = ((ArrayType<?, ?>) dataType).getElementType();
-                switch (elementType.getSqlType()) {
-                    case INT:
-                        return arrayData.toIntArray();
-                    case TINYINT:
-                        return arrayData.toByteArray();
-                    case SMALLINT:
-                        return arrayData.toShortArray();
-                    case BIGINT:
-                        return arrayData.toLongArray();
-                    case BOOLEAN:
-                        return arrayData.toBooleanArray();
-                    case FLOAT:
-                        return arrayData.toFloatArray();
-                    case DOUBLE:
-                        return arrayData.toDoubleArray();
-                    default:
-                        return arrayData.array();
-                }
+                return arrayData.toObjectArray(TypeConverterUtils.convert(elementType));
             default:
                 return field;
         }
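
A hedged usage sketch of the new map handling above (MapType, BasicType and MapData are the SeaTunnel/Spark types used in the hunks; the field value is made up): on the way in, convert(...) now builds a catalyst ArrayBasedMapData with converted keys and values, and on the way out reconvert(...) rebuilds a plain java.util.Map from the MapData.

    // Illustrative sketch only: round-tripping a map<string, int> field through the converter.
    // BasicType.INT_TYPE is assumed by analogy with BasicType.STRING_TYPE from the hunk above;
    // java.util.Map and java.util.HashMap are assumed to be imported.
    MapType<String, Integer> mapType = new MapType<>(BasicType.STRING_TYPE, BasicType.INT_TYPE);
    Map<String, Integer> field = new HashMap<>();
    field.put("a", 1);
    // convertMap(field, mapType) now returns an ArrayBasedMapData whose keys are UTF8String values
    // and whose values remain ints, instead of handing Spark a raw java.util.Map.
    // reconvertMap((MapData) sparkValue, mapType) turns that MapData back into a java.util.Map with
    // plain String keys, so connectors never see catalyst-internal types.
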
diff --git a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/common/utils/TypeConverterUtils.java b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/common/utils/TypeConverterUtils.java
index 493fd9a9c..b28eedd24 100644
--- a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/common/utils/TypeConverterUtils.java
+++ b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/common/utils/TypeConverterUtils.java
@@ -125,7 +125,7 @@ public class TypeConverterUtils {
         }
         if (sparkType instanceof org.apache.spark.sql.types.MapType) {
             org.apache.spark.sql.types.MapType mapType = (org.apache.spark.sql.types.MapType) sparkType;
-            return new MapType<>(convert(mapType.valueType()), convert(mapType.valueType()));
+            return new MapType<>(convert(mapType.keyType()), convert(mapType.valueType()));
         }
         if (sparkType instanceof org.apache.spark.sql.types.DecimalType) {
             org.apache.spark.sql.types.DecimalType decimalType = (org.apache.spark.sql.types.DecimalType) sparkType;
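
The TypeConverterUtils hunk above fixes a copy-paste bug: the key type of a Spark map was also being converted from valueType(). A hedged illustration using Spark's public DataTypes factory (the construction here is an assumption for illustration, not code from the patch):

    // Illustrative only: converting Spark's map<string, int> back to a SeaTunnel MapType.
    org.apache.spark.sql.types.MapType sparkMapType =
            org.apache.spark.sql.types.DataTypes.createMapType(
                    org.apache.spark.sql.types.DataTypes.StringType,
                    org.apache.spark.sql.types.DataTypes.IntegerType);
    // before the fix: new MapType<>(convert(mapType.valueType()), convert(mapType.valueType()))  -> map<int, int>
    // after the fix:  new MapType<>(convert(mapType.keyType()),   convert(mapType.valueType()))  -> map<string, int>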