Posted to commits@flink.apache.org by hx...@apache.org on 2022/02/22 01:57:58 UTC

[flink] branch release-1.14 updated: [FLINK-25856][python][BP-1.14] Fix use of UserDefinedType in from_elements

This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.14 by this push:
     new cfc54ca  [FLINK-25856][python][BP-1.14] Fix use of UserDefinedType in from_elements
cfc54ca is described below

commit cfc54caf4882a7339b02d06b3c43df64feb02ad5
Author: huangxingbo <hx...@gmail.com>
AuthorDate: Mon Feb 21 21:41:08 2022 +0800

    [FLINK-25856][python][BP-1.14] Fix use of UserDefinedType in from_elements
    
    This closes #18864.
---
 .../table/tests/test_table_environment_api.py      |  74 ++++-
 flink-python/pyflink/table/types.py                |   9 +-
 .../flink/api/common/python/PythonBridgeUtils.java |  14 +-
 .../planner/utils/python/PythonTableUtils.scala    | 335 +++++++++++----------
 4 files changed, 268 insertions(+), 164 deletions(-)
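Editor's note: the new test in test_table_environment_api.py (first diff below) exercises the use case this backport fixes, namely passing objects backed by a UserDefinedType straight to TableEnvironment.from_elements. A minimal usage sketch, assuming the VectorUDT and DenseVector classes defined in that test are available:

    from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment

    # VectorUDT and DenseVector (with DenseVector.__UDT__ = VectorUDT()) are taken
    # from the test added below; they are not shipped with PyFlink 1.14 itself.
    t_env = TableEnvironment.create(EnvironmentSettings.in_batch_mode())
    table = t_env.from_elements(
        [(DenseVector([1, 2, 3, 4]), 0., 1.),
         (DenseVector([2, 2, 3, 4]), 0., 2.)],
        DataTypes.ROW([
            DataTypes.FIELD("features", VectorUDT()),        # UDT-typed field
            DataTypes.FIELD("label", DataTypes.DOUBLE()),
            DataTypes.FIELD("weight", DataTypes.DOUBLE())]))

Before this fix, constructing such a table failed because the UDT values were not converted correctly between Python and the JVM.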

diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py b/flink-python/pyflink/table/tests/test_table_environment_api.py
index 4c5dfe6..e47a686 100644
--- a/flink-python/pyflink/table/tests/test_table_environment_api.py
+++ b/flink-python/pyflink/table/tests/test_table_environment_api.py
@@ -38,7 +38,7 @@ from pyflink.table.catalog import ObjectPath, CatalogBaseTable
 from pyflink.table.explain_detail import ExplainDetail
 from pyflink.table.expressions import col, source_watermark
 from pyflink.table.table_descriptor import TableDescriptor
-from pyflink.table.types import RowType, Row
+from pyflink.table.types import RowType, Row, UserDefinedType
 from pyflink.table.udf import udf
 from pyflink.testing import source_sink_utils
 from pyflink.testing.test_case_utils import (
@@ -540,8 +540,80 @@ class StreamTableEnvironmentTests(TableEnvironmentTest, PyFlinkStreamTableTestCa
             self.assertEqual(expected_result, collected_result)
 
 
+class VectorUDT(UserDefinedType):
+
+    @classmethod
+    def sql_type(cls):
+        return DataTypes.ROW(
+            [
+                DataTypes.FIELD("type", DataTypes.TINYINT()),
+                DataTypes.FIELD("size", DataTypes.INT()),
+                DataTypes.FIELD("indices", DataTypes.ARRAY(DataTypes.INT())),
+                DataTypes.FIELD("values", DataTypes.ARRAY(DataTypes.DOUBLE())),
+            ]
+        )
+
+    @classmethod
+    def module(cls):
+        return "pyflink.ml.core.linalg"
+
+    def serialize(self, obj):
+        if isinstance(obj, DenseVector):
+            values = [float(v) for v in obj._values]
+            return 1, None, None, values
+        else:
+            raise TypeError("Cannot serialize %r of type %r" % (obj, type(obj)))
+
+    def deserialize(self, datum):
+        pass
+
+
+class DenseVector(object):
+    __UDT__ = VectorUDT()
+
+    def __init__(self, values):
+        self._values = values
+
+    def size(self) -> int:
+        return len(self._values)
+
+    def get(self, i: int):
+        return self._values[i]
+
+    def to_array(self):
+        return self._values
+
+    @property
+    def values(self):
+        return self._values
+
+    def __str__(self):
+        return "[" + ",".join([str(v) for v in self._values]) + "]"
+
+    def __repr__(self):
+        return "DenseVector([%s])" % (", ".join(str(i) for i in self._values))
+
+
 class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
 
+    def test_udt(self):
+        self.t_env.from_elements([
+            (DenseVector([1, 2, 3, 4]), 0., 1.),
+            (DenseVector([2, 2, 3, 4]), 0., 2.),
+            (DenseVector([3, 2, 3, 4]), 0., 3.),
+            (DenseVector([4, 2, 3, 4]), 0., 4.),
+            (DenseVector([5, 2, 3, 4]), 0., 5.),
+            (DenseVector([11, 2, 3, 4]), 1., 1.),
+            (DenseVector([12, 2, 3, 4]), 1., 2.),
+            (DenseVector([13, 2, 3, 4]), 1., 3.),
+            (DenseVector([14, 2, 3, 4]), 1., 4.),
+            (DenseVector([15, 2, 3, 4]), 1., 5.),
+        ],
+            DataTypes.ROW([
+                DataTypes.FIELD("features", VectorUDT()),
+                DataTypes.FIELD("label", DataTypes.DOUBLE()),
+                DataTypes.FIELD("weight", DataTypes.DOUBLE())]))
+
     def test_explain_with_multi_sinks(self):
         t_env = self.t_env
         source = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index 1856d21..44458da 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -2202,12 +2202,17 @@ def _create_type_verifier(data_type: DataType, name: str = None):
         verify_value = verify_varbinary
 
     elif isinstance(data_type, UserDefinedType):
-        verifier = _create_type_verifier(data_type.sql_type(), name=name)
+        sql_type = data_type.sql_type()
+        verifier = _create_type_verifier(sql_type, name=name)
 
         def verify_udf(obj):
             if not (hasattr(obj, '__UDT__') and obj.__UDT__ == data_type):
                 raise ValueError(new_msg("%r is not an instance of type %r" % (obj, data_type)))
-            verifier(data_type.to_sql_type(obj))
+            data = data_type.to_sql_type(obj)
+            if isinstance(sql_type, RowType):
+                # remove the RowKind value in the first position.
+                data = data[1:]
+            verifier(data)
 
         verify_value = verify_udf
 
diff --git a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
index baa83f0..7f83849 100644
--- a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
@@ -24,12 +24,16 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
 import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple;
 import org.apache.flink.api.java.typeutils.ListTypeInfo;
 import org.apache.flink.api.java.typeutils.MapTypeInfo;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.typeinfo.python.PickledByteArrayTypeInfo;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.logical.ArrayType;
 import org.apache.flink.table.types.logical.DateType;
@@ -364,8 +368,16 @@ public final class PythonBridgeUtils {
                     && BasicTypeInfo.getInfoFor(dataType.getTypeClass()) == FLOAT_TYPE_INFO) {
                 // Serialization of float type with pickler loses precision.
                 return pickler.dumps(String.valueOf(obj));
-            } else {
+            } else if (dataType instanceof PickledByteArrayTypeInfo
+                    || dataType instanceof BasicTypeInfo) {
                 return pickler.dumps(obj);
+            } else {
+                // other typeinfos will use the corresponding serializer to serialize data.
+                TypeSerializer serializer = dataType.createSerializer(null);
+                ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos();
+                DataOutputViewStreamWrapper baosWrapper = new DataOutputViewStreamWrapper(baos);
+                serializer.serialize(obj, baosWrapper);
+                return pickler.dumps(baos.toByteArray());
             }
         }
     }
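Editor's note: taken together with the Scala change that follows, the Java change above moves non-basic, non-pickled-byte-array types to a serialize-then-pickle scheme: the value is first written to bytes with its Flink TypeSerializer, only those bytes are pickled for transport, and the receiving side unpickles them and runs the matching deserializer. A rough, purely conceptual Python sketch of that round trip (the Utf8Serializer class is a stand-in, not a Flink API):

    import pickle

    class Utf8Serializer:                      # stand-in for a Flink TypeSerializer
        def serialize(self, value: str) -> bytes:
            return value.encode("utf-8")
        def deserialize(self, data: bytes) -> str:
            return data.decode("utf-8")

    serializer = Utf8Serializer()
    wire = pickle.dumps(serializer.serialize("DenseVector([1.0, 2.0])"))   # sender side
    value = serializer.deserialize(pickle.loads(wire))                     # receiver side
    assert value == "DenseVector([1.0, 2.0])"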
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/python/PythonTableUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/python/PythonTableUtils.scala
index da39d48..288f63d 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/python/PythonTableUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/python/PythonTableUtils.scala
@@ -31,6 +31,7 @@ import org.apache.flink.api.java.io.CollectionInputFormat
 import org.apache.flink.api.java.tuple.Tuple
 import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo, TupleTypeInfo}
 import org.apache.flink.core.io.InputSplit
+import org.apache.flink.core.memory.{ByteArrayInputStreamWithPos, DataInputViewStreamWrapper}
 import org.apache.flink.table.api.{TableSchema, Types}
 import org.apache.flink.table.sources.InputFormatTableSource
 import org.apache.flink.types.{Row, RowKind}
@@ -53,7 +54,7 @@ object PythonTableUtils {
       data: java.util.List[Array[Object]],
       dataType: TypeInformation[Row],
       config: ExecutionConfig): InputFormat[Row, _] = {
-    val converter = convertTo(dataType)
+    val converter = convertTo(dataType, config)
     new CollectionInputFormat(data.map(converter(_).asInstanceOf[Row]),
       dataType.createSerializer(config))
   }
@@ -71,7 +72,7 @@ object PythonTableUtils {
     data: java.util.List[T],
     dataType: TypeInformation[T],
     config: ExecutionConfig): InputFormat[T, _] ={
-    val converter = convertTo(dataType)
+    val converter = convertTo(dataType, config)
     new CollectionInputFormat[T](data.map(converter(_).asInstanceOf[T]),
       dataType.createSerializer(config)
     )
@@ -81,190 +82,204 @@ object PythonTableUtils {
     * Creates a converter that converts `obj` to the type specified by the data type, or returns
     * null if the type of obj is unexpected because Python doesn't enforce the type.
     */
-  private def convertTo(dataType: TypeInformation[_]): Any => Any = dataType match {
-    case _ if dataType == Types.BOOLEAN => (obj: Any) => nullSafeConvert(obj) {
-      case b: Boolean => b
-    }
+  private def convertTo(dataType: TypeInformation[_], config: ExecutionConfig): Any => Any =
+    dataType match {
+      case _ if dataType == Types.BOOLEAN => (obj: Any) => nullSafeConvert(obj) {
+        case b: Boolean => b
+      }
 
-    case _ if dataType == Types.BYTE => (obj: Any) => nullSafeConvert(obj) {
-      case c: Byte => c
-      case c: Short => c.toByte
-      case c: Int => c.toByte
-      case c: Long => c.toByte
-    }
+      case _ if dataType == Types.BYTE => (obj: Any) => nullSafeConvert(obj) {
+        case c: Byte => c
+        case c: Short => c.toByte
+        case c: Int => c.toByte
+        case c: Long => c.toByte
+      }
 
-    case _ if dataType == Types.SHORT => (obj: Any) => nullSafeConvert(obj) {
-      case c: Byte => c.toShort
-      case c: Short => c
-      case c: Int => c.toShort
-      case c: Long => c.toShort
-    }
+      case _ if dataType == Types.SHORT => (obj: Any) => nullSafeConvert(obj) {
+        case c: Byte => c.toShort
+        case c: Short => c
+        case c: Int => c.toShort
+        case c: Long => c.toShort
+      }
 
-    case _ if dataType == Types.INT => (obj: Any) => nullSafeConvert(obj) {
-      case c: Byte => c.toInt
-      case c: Short => c.toInt
-      case c: Int => c
-      case c: Long => c.toInt
-    }
+      case _ if dataType == Types.INT => (obj: Any) => nullSafeConvert(obj) {
+        case c: Byte => c.toInt
+        case c: Short => c.toInt
+        case c: Int => c
+        case c: Long => c.toInt
+      }
 
-    case _ if dataType == Types.LONG => (obj: Any) => nullSafeConvert(obj) {
-      case c: Byte => c.toLong
-      case c: Short => c.toLong
-      case c: Int => c.toLong
-      case c: Long => c
-    }
+      case _ if dataType == Types.LONG => (obj: Any) => nullSafeConvert(obj) {
+        case c: Byte => c.toLong
+        case c: Short => c.toLong
+        case c: Int => c.toLong
+        case c: Long => c
+      }
 
-    case _ if dataType == Types.FLOAT => (obj: Any) => nullSafeConvert(obj) {
-      case c: Float => c
-      case c: Double => c.toFloat
-    }
+      case _ if dataType == Types.FLOAT => (obj: Any) => nullSafeConvert(obj) {
+        case c: Float => c
+        case c: Double => c.toFloat
+      }
 
-    case _ if dataType == Types.DOUBLE => (obj: Any) => nullSafeConvert(obj) {
-      case c: Float => c.toDouble
-      case c: Double => c
-    }
+      case _ if dataType == Types.DOUBLE => (obj: Any) => nullSafeConvert(obj) {
+        case c: Float => c.toDouble
+        case c: Double => c
+      }
 
-    case _ if dataType == Types.DECIMAL => (obj: Any) => nullSafeConvert(obj) {
-      case c: java.math.BigDecimal => c
-    }
+      case _ if dataType == Types.DECIMAL => (obj: Any) => nullSafeConvert(obj) {
+        case c: java.math.BigDecimal => c
+      }
 
-    case _ if dataType == Types.SQL_DATE => (obj: Any) => nullSafeConvert(obj) {
-      case c: Int =>
-        val millisLocal = c.toLong * 86400000
-        val millisUtc = millisLocal - getOffsetFromLocalMillis(millisLocal)
-        new Date(millisUtc)
-    }
+      case _ if dataType == Types.SQL_DATE => (obj: Any) => nullSafeConvert(obj) {
+        case c: Int =>
+          val millisLocal = c.toLong * 86400000
+          val millisUtc = millisLocal - getOffsetFromLocalMillis(millisLocal)
+          new Date(millisUtc)
+      }
 
-    case _ if dataType == Types.SQL_TIME => (obj: Any) => nullSafeConvert(obj) {
-      case c: Long => new Time(c / 1000)
-      case c: Int => new Time(c.toLong / 1000)
-    }
+      case _ if dataType == Types.SQL_TIME => (obj: Any) => nullSafeConvert(obj) {
+        case c: Long => new Time(c / 1000)
+        case c: Int => new Time(c.toLong / 1000)
+      }
 
-    case _ if dataType == Types.SQL_TIMESTAMP => (obj: Any) => nullSafeConvert(obj) {
-      case c: Long => new Timestamp(c / 1000)
-      case c: Int => new Timestamp(c.toLong / 1000)
-    }
+      case _ if dataType == Types.SQL_TIMESTAMP => (obj: Any) => nullSafeConvert(obj) {
+        case c: Long => new Timestamp(c / 1000)
+        case c: Int => new Timestamp(c.toLong / 1000)
+      }
 
-    case _ if dataType == org.apache.flink.api.common.typeinfo.Types.INSTANT =>
-      (obj: Any) => nullSafeConvert(obj) {
-        case c: Long => Instant.ofEpochMilli(c / 1000)
-        case c: Int => Instant.ofEpochMilli(c.toLong / 1000)
-    }
+      case _ if dataType == org.apache.flink.api.common.typeinfo.Types.INSTANT =>
+        (obj: Any) => nullSafeConvert(obj) {
+          case c: Long => Instant.ofEpochMilli(c / 1000)
+          case c: Int => Instant.ofEpochMilli(c.toLong / 1000)
+      }
 
-    case _ if dataType == Types.INTERVAL_MILLIS() => (obj: Any) => nullSafeConvert(obj) {
-      case c: Long => c / 1000
-      case c: Int => c.toLong / 1000
-    }
+      case _ if dataType == Types.INTERVAL_MILLIS() => (obj: Any) => nullSafeConvert(obj) {
+        case c: Long => c / 1000
+        case c: Int => c.toLong / 1000
+      }
 
-    case _ if dataType == Types.STRING => (obj: Any) => nullSafeConvert(obj) {
-      case _ => obj.toString
-    }
+      case _ if dataType == Types.STRING => (obj: Any) => nullSafeConvert(obj) {
+        case _ => obj.toString
+      }
+
+      case PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO =>
+        (obj: Any) =>
+          nullSafeConvert(obj) {
+            case c: String => c.getBytes(StandardCharsets.UTF_8)
+            case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+          }
 
-    case PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO =>
-      (obj: Any) =>
-        nullSafeConvert(obj) {
-          case c: String => c.getBytes(StandardCharsets.UTF_8)
-          case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+      case _: PrimitiveArrayTypeInfo[_] |
+           _: BasicArrayTypeInfo[_, _] |
+           _: ObjectArrayTypeInfo[_, _] =>
+        var boxed = false
+        val elementType = dataType match {
+          case p: PrimitiveArrayTypeInfo[_] =>
+            p.getComponentType
+          case b: BasicArrayTypeInfo[_, _] =>
+            boxed = true
+            b.getComponentInfo
+          case o: ObjectArrayTypeInfo[_, _] =>
+            boxed = true
+            o.getComponentInfo
+        }
+        val elementFromJava = convertTo(elementType, config)
+
+        (obj: Any) => nullSafeConvert(obj) {
+          case c: java.util.List[_] =>
+            createArray(elementType,
+                        c.size(),
+                        i => elementFromJava(c.get(i)),
+                        boxed)
+          case c if c.getClass.isArray =>
+            createArray(elementType,
+                        c.asInstanceOf[Array[_]].length,
+                        i => elementFromJava(c.asInstanceOf[Array[_]](i)),
+                        boxed)
         }
 
-    case _: PrimitiveArrayTypeInfo[_] |
-         _: BasicArrayTypeInfo[_, _] |
-         _: ObjectArrayTypeInfo[_, _] =>
-      var boxed = false
-      val elementType = dataType match {
-        case p: PrimitiveArrayTypeInfo[_] =>
-          p.getComponentType
-        case b: BasicArrayTypeInfo[_, _] =>
-          boxed = true
-          b.getComponentInfo
-        case o: ObjectArrayTypeInfo[_, _] =>
-          boxed = true
-          o.getComponentInfo
-      }
-      val elementFromJava = convertTo(elementType)
-
-      (obj: Any) => nullSafeConvert(obj) {
-        case c: java.util.List[_] =>
-          createArray(elementType,
-                      c.size(),
-                      i => elementFromJava(c.get(i)),
-                      boxed)
-        case c if c.getClass.isArray =>
-          createArray(elementType,
-                      c.asInstanceOf[Array[_]].length,
-                      i => elementFromJava(c.asInstanceOf[Array[_]](i)),
-                      boxed)
-      }
+      case m: MapTypeInfo[_, _] =>
+        val keyFromJava = convertTo(m.getKeyTypeInfo, config)
+        val valueFromJava = convertTo(m.getValueTypeInfo, config)
+
+        (obj: Any) => nullSafeConvert(obj) {
+          case javaMap: java.util.Map[_, _] =>
+            val map = new java.util.HashMap[Any, Any]
+            javaMap.forEach(new BiConsumer[Any, Any] {
+              override def accept(k: Any, v: Any): Unit =
+                map.put(keyFromJava(k), valueFromJava(v))
+            })
+            map
+        }
 
-    case m: MapTypeInfo[_, _] =>
-      val keyFromJava = convertTo(m.getKeyTypeInfo)
-      val valueFromJava = convertTo(m.getValueTypeInfo)
-
-      (obj: Any) => nullSafeConvert(obj) {
-        case javaMap: java.util.Map[_, _] =>
-          val map = new java.util.HashMap[Any, Any]
-          javaMap.forEach(new BiConsumer[Any, Any] {
-            override def accept(k: Any, v: Any): Unit =
-              map.put(keyFromJava(k), valueFromJava(v))
-          })
-          map
-      }
+      case rowType: RowTypeInfo =>
+        val fieldsFromJava = rowType.getFieldTypes.map(f => convertTo(f, config))
+
+        (obj: Any) => nullSafeConvert(obj) {
+          case c if c.getClass.isArray =>
+            val r = c.asInstanceOf[Array[_]]
+            if (r.length - 1 != rowType.getFieldTypes.length) {
+              throw new IllegalStateException(
+                s"Input row doesn't have expected number of values required by the schema. " +
+                  s"${rowType.getFieldTypes.length} fields are required while ${r.length - 1} " +
+                  s"values are provided."
+                )
+            }
 
-    case rowType: RowTypeInfo =>
-      val fieldsFromJava = rowType.getFieldTypes.map(f => convertTo(f))
-
-      (obj: Any) => nullSafeConvert(obj) {
-        case c if c.getClass.isArray =>
-          val r = c.asInstanceOf[Array[_]]
-          if (r.length - 1 != rowType.getFieldTypes.length) {
-            throw new IllegalStateException(
-              s"Input row doesn't have expected number of values required by the schema. " +
-                s"${rowType.getFieldTypes.length} fields are required while ${r.length - 1} " +
-                s"values are provided."
+            val row = new Row(r.length - 1)
+            row.setKind(RowKind.fromByteValue(r(0).asInstanceOf[Integer].byteValue()))
+            var i = 1
+            while (i < r.length) {
+              row.setField(i - 1, fieldsFromJava(i - 1)(r(i)))
+              i += 1
+            }
+            row
+        }
+
+      case tupleType: TupleTypeInfo[_] =>
+        val fieldsTypes: Array[TypeInformation[_]] =
+          new Array[TypeInformation[_]](tupleType.getArity)
+        for ( i <- 0 until tupleType.getArity) {
+          fieldsTypes(i) = tupleType.getTypeAt(i)
+        }
+
+        val fieldsFromJava: Array[Any => Any] = fieldsTypes.map(f => convertTo(f, config))
+
+        (obj: Any) => nullSafeConvert(obj) {
+          case c if c.getClass.isArray =>
+            val r = c.asInstanceOf[Array[_]]
+            if (r.length != tupleType.getArity) {
+              throw new IllegalStateException(
+                s"Input tuple doesn't have expected number of values required by the schema. " +
+                  s"${tupleType.getArity} fields are required while ${r.length} " +
+                  s"values are provided."
               )
-          }
+            }
 
-          val row = new Row(r.length - 1)
-          row.setKind(RowKind.fromByteValue(r(0).asInstanceOf[Integer].byteValue()))
-          var i = 1
-          while (i < r.length) {
-            row.setField(i - 1, fieldsFromJava(i - 1)(r(i)))
+          val tuple = Tuple.newInstance(r.length)
+          var i: Int = 0
+          while(i < r.length){
+            tuple.setField(fieldsFromJava(i)(r(i)), i)
             i += 1
           }
-          row
-      }
+          tuple
+        }
 
-    case tupleType: TupleTypeInfo[_] =>
-      val fieldsTypes: Array[TypeInformation[_]] = new Array[TypeInformation[_]](tupleType.getArity)
-      for ( i <- 0 until tupleType.getArity) {
-        fieldsTypes(i) = tupleType.getTypeAt(i)
-      }
-      
-      val fieldsFromJava: Array[Any => Any] = fieldsTypes.map(f => convertTo(f))
-      
-      (obj: Any) => nullSafeConvert(obj) {
-        case c if c.getClass.isArray =>
-          val r = c.asInstanceOf[Array[_]]
-          if (r.length != tupleType.getArity) {
-            throw new IllegalStateException(
-              s"Input tuple doesn't have expected number of values required by the schema. " +
-                s"${tupleType.getArity} fields are required while ${r.length} " +
-                s"values are provided."
-            )
+      // UserDefinedType
+      case _ => (obj: Any) => {
+        obj match {
+          case b: Array[Byte] => if (dataType.getTypeClass == classOf[Array[Byte]]) {
+            obj
+          } else {
+            val dataSerializer = dataType.createSerializer(config)
+            val bais = new ByteArrayInputStreamWithPos()
+            val baisWrapper = new DataInputViewStreamWrapper(bais)
+            bais.setBuffer(b, 0, b.length)
+            dataSerializer.deserialize(baisWrapper)
           }
-          
-        val tuple = Tuple.newInstance(r.length)
-        var i: Int = 0
-        while(i < r.length){
-          tuple.setField(fieldsFromJava(i)(r(i)), i)
-          i += 1
+          case _ => obj
         }
-        tuple
       }
-    
-
-    // UserDefinedType
-    case _ => (obj: Any) => obj
   }
 
   private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = {