You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by hx...@apache.org on 2022/12/12 03:25:30 UTC
[flink] branch master updated: [FLINK-21223][python] Support to specify the output types of Python UDFs via string
This is an automated email from the ASF dual-hosted git repository.
hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 6cc00a707b2 [FLINK-21223][python] Support to specify the output types of Python UDFs via string
6cc00a707b2 is described below
commit 6cc00a707b238facbf5bf88a9fd727c8f9daab89
Author: huangxingbo <hx...@apache.org>
AuthorDate: Wed Nov 16 17:56:28 2022 +0800
[FLINK-21223][python] Support to specify the output types of Python UDFs via string
This closes #21332.
---
.../table/operations/row_based_operations.md | 36 ++---
.../docs/dev/python/table/udfs/python_udfs.md | 33 ++--
.../python/table/udfs/vectorized_python_udfs.md | 19 ++-
.../table/operations/row_based_operations.md | 36 ++---
.../docs/dev/python/table/udfs/python_udfs.md | 33 ++--
.../python/table/udfs/vectorized_python_udfs.md | 19 ++-
.../pyflink/table/tests/test_pandas_udaf.py | 2 +-
flink-python/pyflink/table/tests/test_udaf.py | 20 +--
flink-python/pyflink/table/tests/test_udf.py | 174 +++++++++++++++++++++
flink-python/pyflink/table/tests/test_udtf.py | 2 +-
flink-python/pyflink/table/udf.py | 148 +++++++++++++-----
.../functions/python/PythonAggregateFunction.java | 80 +++++++++-
.../functions/python/PythonScalarFunction.java | 66 +++++++-
.../python/PythonTableAggregateFunction.java | 79 +++++++++-
.../functions/python/PythonTableFunction.java | 70 ++++++++-
15 files changed, 639 insertions(+), 178 deletions(-)
diff --git a/docs/content.zh/docs/dev/python/table/operations/row_based_operations.md b/docs/content.zh/docs/dev/python/table/operations/row_based_operations.md
index 59e9fa03a24..027b400ea60 100644
--- a/docs/content.zh/docs/dev/python/table/operations/row_based_operations.md
+++ b/docs/content.zh/docs/dev/python/table/operations/row_based_operations.md
@@ -35,7 +35,6 @@ The output will be flattened if the output type is a composite type.
from pyflink.common import Row
from pyflink.table import EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import col
-from pyflink.table.types import DataTypes
from pyflink.table.udf import udf
env_settings = EnvironmentSettings.in_batch_mode()
@@ -43,8 +42,7 @@ table_env = TableEnvironment.create(env_settings)
table = table_env.from_elements([(1, 'Hi'), (2, 'Hello')], ['id', 'data'])
-@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
- DataTypes.FIELD("data", DataTypes.STRING())]))
+@udf(result_type='ROW<id BIGINT, data STRING>')
def func1(id: int, data: str) -> Row:
return Row(id, data * 2)
@@ -62,8 +60,7 @@ table.map(func1(col('id'), col('data'))).execute().print()
It also supports taking a Row object (containing all the columns of the input table) as input.
```python
-@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
- DataTypes.FIELD("data", DataTypes.STRING())]))
+@udf(result_type='ROW<id BIGINT, data STRING>')
def func2(data: Row) -> Row:
return Row(data.id, data.data * 2)
@@ -85,9 +82,7 @@ It should be noted that the input type and output type should be pandas.DataFram
```python
import pandas as pd
-@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
- DataTypes.FIELD("data", DataTypes.STRING())]),
- func_type='pandas')
+@udf(result_type='ROW<id BIGINT, data STRING>', func_type='pandas')
def func3(data: pd.DataFrame) -> pd.DataFrame:
res = pd.concat([data.id, data.data * 2], axis=1)
return res
@@ -109,14 +104,14 @@ Performs a `flat_map` operation with a python [table function]({{< ref "docs/dev
```python
from pyflink.common import Row
from pyflink.table.udf import udtf
-from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
+from pyflink.table import EnvironmentSettings, TableEnvironment
env_settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(env_settings)
table = table_env.from_elements([(1, 'Hi,Flink'), (2, 'Hello')], ['id', 'data'])
-@udtf(result_types=[DataTypes.INT(), DataTypes.STRING()])
+@udtf(result_types=['INT', 'STRING'])
def split(x: Row) -> Row:
for s in x.data.split(","):
yield x.id, s
@@ -154,7 +149,7 @@ Performs an `aggregate` operation with a python [general aggregate function]({{<
```python
from pyflink.common import Row
-from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
+from pyflink.table import EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import col
from pyflink.table.udf import AggregateFunction, udaf
@@ -180,14 +175,10 @@ class CountAndSumAggregateFunction(AggregateFunction):
accumulator[1] += other_acc[1]
def get_accumulator_type(self):
- return DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.BIGINT()),
- DataTypes.FIELD("b", DataTypes.BIGINT())])
+ return 'ROW<a BIGINT, b BIGINT>'
def get_result_type(self):
- return DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.BIGINT()),
- DataTypes.FIELD("b", DataTypes.BIGINT())])
+ return 'ROW<a BIGINT, b BIGINT>'
function = CountAndSumAggregateFunction()
agg = udaf(function,
@@ -221,9 +212,7 @@ table_env = TableEnvironment.create(env_settings)
t = table_env.from_elements([(1, 2), (2, 1), (1, 3)], ['a', 'b'])
pandas_udaf = udaf(lambda pd: (pd.b.mean(), pd.b.max()),
- result_type=DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.FLOAT()),
- DataTypes.FIELD("b", DataTypes.INT())]),
+ result_type='ROW<a FLOAT, b INT>',
func_type="pandas")
t.aggregate(pandas_udaf.alias("a", "b")) \
.select(col('a'), col('b')).execute().print()
@@ -250,7 +239,7 @@ Similar to `aggregate`, you have to close the `flat_aggregate` with a select sta
```python
from pyflink.common import Row
-from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
+from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col
from pyflink.table.udf import udtaf, TableAggregateFunction
@@ -272,11 +261,10 @@ class Top2(TableAggregateFunction):
accumulator[1] = row.a
def get_accumulator_type(self):
- return DataTypes.ARRAY(DataTypes.BIGINT())
+ return 'ARRAY<BIGINT>'
def get_result_type(self):
- return DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.BIGINT())])
+ return 'ROW<a BIGINT>'
env_settings = EnvironmentSettings.in_streaming_mode()
diff --git a/docs/content.zh/docs/dev/python/table/udfs/python_udfs.md b/docs/content.zh/docs/dev/python/table/udfs/python_udfs.md
index 7194eb00323..b56b6db7f2a 100644
--- a/docs/content.zh/docs/dev/python/table/udfs/python_udfs.md
+++ b/docs/content.zh/docs/dev/python/table/udfs/python_udfs.md
@@ -55,13 +55,13 @@ class HashCode(ScalarFunction):
settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(settings)
-hash_code = udf(HashCode(), result_type=DataTypes.BIGINT())
+hash_code = udf(HashCode(), result_type='BIGINT')
# 在 Python Table API 中使用 Python 自定义函数
my_table.select(col("string"), col("bigint"), hash_code(col("bigint")), call(hash_code, col("bigint")))
# 在 SQL API 中使用 Python 自定义函数
-table_env.create_temporary_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT()))
+table_env.create_temporary_function("hash_code", udf(HashCode(), result_type='BIGINT'))
table_env.sql_query("SELECT string, bigint, hash_code(bigint) FROM MyTable")
```
@@ -108,25 +108,25 @@ class Add(ScalarFunction):
add = udf(Add(), result_type=DataTypes.BIGINT())
# 方式二:普通 Python 函数
-@udf(result_type=DataTypes.BIGINT())
+@udf(result_type='BIGINT')
def add(i, j):
return i + j
# 方式三:lambda 函数
-add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT())
+add = udf(lambda i, j: i + j, result_type='BIGINT')
# 方式四:callable 函数
class CallableAdd(object):
def __call__(self, i, j):
return i + j
-add = udf(CallableAdd(), result_type=DataTypes.BIGINT())
+add = udf(CallableAdd(), result_type='BIGINT')
# 方式五:partial 函数
def partial_add(i, j, k):
return i + j + k
-add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT())
+add = udf(functools.partial(partial_add, k=1), result_type='BIGINT')
# 注册 Python 自定义函数
table_env.create_temporary_function("add", add)
@@ -160,14 +160,14 @@ table_env = TableEnvironment.create(env_settings)
my_table = ... # type: Table, table schema: [a: String]
# 注册 Python 表值函数
-split = udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()])
+split = udtf(Split(), result_types=['STRING', 'INT'])
# 在 Python Table API 中使用 Python 表值函数
my_table.join_lateral(split(col("a")).alias("word", "length"))
my_table.left_outer_join_lateral(split(col("a")).alias("word", "length"))
# 在 SQL API 中使用 Python 表值函数
-table_env.create_temporary_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()]))
+table_env.create_temporary_function("split", udtf(Split(), result_types=['STRING', 'INT']))
table_env.sql_query("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
table_env.sql_query("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")
```
@@ -219,18 +219,18 @@ table_env.sql_query("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE
```python
# 方式一:生成器函数
-@udtf(result_types=DataTypes.BIGINT())
+@udtf(result_types='BIGINT')
def generator_func(x):
yield 1
yield 2
# 方式二:返回迭代器
-@udtf(result_types=DataTypes.BIGINT())
+@udtf(result_types='BIGINT')
def iterator_func(x):
return range(5)
# 方式三:返回可迭代子类
-@udtf(result_types=DataTypes.BIGINT())
+@udtf(result_types='BIGINT')
def iterable_func(x):
result = [1, 2, 3]
return result
@@ -300,12 +300,10 @@ class WeightedAvg(AggregateFunction):
accumulator[1] -= weight
def get_result_type(self):
- return DataTypes.BIGINT()
+ return 'BIGINT'
def get_accumulator_type(self):
- return DataTypes.ROW([
- DataTypes.FIELD("f0", DataTypes.BIGINT()),
- DataTypes.FIELD("f1", DataTypes.BIGINT())])
+ return 'ROW<f0 BIGINT, f1 BIGINT>'
env_settings = EnvironmentSettings.in_streaming_mode()
@@ -475,11 +473,10 @@ class Top2(TableAggregateFunction):
accumulator[1] = row[0]
def get_accumulator_type(self):
- return DataTypes.ARRAY(DataTypes.BIGINT())
+ return 'ARRAY<BIGINT>'
def get_result_type(self):
- return DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.BIGINT())])
+ return 'ROW<a BIGINT>'
env_settings = EnvironmentSettings.in_streaming_mode()
diff --git a/docs/content.zh/docs/dev/python/table/udfs/vectorized_python_udfs.md b/docs/content.zh/docs/dev/python/table/udfs/vectorized_python_udfs.md
index 5454af1057d..48e94645b89 100644
--- a/docs/content.zh/docs/dev/python/table/udfs/vectorized_python_udfs.md
+++ b/docs/content.zh/docs/dev/python/table/udfs/vectorized_python_udfs.md
@@ -49,11 +49,11 @@ under the License.
以下示例显示了如何定义自己的向量化 Python 标量函数,该函数计算两列的总和,并在查询中使用它:
```python
-from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
+from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col
from pyflink.table.udf import udf
-@udf(result_type=DataTypes.BIGINT(), func_type="pandas")
+@udf(result_type='BIGINT', func_type="pandas")
def add(i, j):
return i + j
@@ -85,12 +85,12 @@ table_env.sql_query("SELECT add(bigint, bigint) FROM MyTable")
and `Over Window Aggregation` 使用它:
```python
-from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
+from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col, lit
from pyflink.table.udf import udaf
from pyflink.table.window import Tumble
-@udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
+@udaf(result_type='FLOAT', func_type="pandas")
def mean_udaf(v):
return v.mean()
@@ -126,7 +126,6 @@ table_env.sql_query("""
以下示例显示了多种定义向量化 Python 聚合函数的方式。该函数需要两个类型为 bigint 的参数作为输入参数,并返回它们的最大值的和作为结果。
```python
-from pyflink.table import DataTypes
from pyflink.table.udf import AggregateFunction, udaf
# 方式一:扩展基类 `AggregateFunction`
@@ -152,26 +151,26 @@ class MaxAdd(AggregateFunction):
result += arg.max()
accumulator.append(result)
-max_add = udaf(MaxAdd(), result_type=DataTypes.BIGINT(), func_type="pandas")
+max_add = udaf(MaxAdd(), result_type='BIGINT', func_type="pandas")
# 方式二:普通 Python 函数
-@udaf(result_type=DataTypes.BIGINT(), func_type="pandas")
+@udaf(result_type='BIGINT', func_type="pandas")
def max_add(i, j):
return i.max() + j.max()
# 方式三:lambda 函数
-max_add = udaf(lambda i, j: i.max() + j.max(), result_type=DataTypes.BIGINT(), func_type="pandas")
+max_add = udaf(lambda i, j: i.max() + j.max(), result_type='BIGINT', func_type="pandas")
# 方式四:callable 函数
class CallableMaxAdd(object):
def __call__(self, i, j):
return i.max() + j.max()
-max_add = udaf(CallableMaxAdd(), result_type=DataTypes.BIGINT(), func_type="pandas")
+max_add = udaf(CallableMaxAdd(), result_type='BIGINT', func_type="pandas")
# 方式五:partial 函数
def partial_max_add(i, j, k):
return i.max() + j.max() + k
-max_add = udaf(functools.partial(partial_max_add, k=1), result_type=DataTypes.BIGINT(), func_type="pandas")
+max_add = udaf(functools.partial(partial_max_add, k=1), result_type='BIGINT', func_type="pandas")
```
diff --git a/docs/content/docs/dev/python/table/operations/row_based_operations.md b/docs/content/docs/dev/python/table/operations/row_based_operations.md
index 640d8a05ddd..c11aa980d5c 100644
--- a/docs/content/docs/dev/python/table/operations/row_based_operations.md
+++ b/docs/content/docs/dev/python/table/operations/row_based_operations.md
@@ -35,7 +35,6 @@ The output will be flattened if the output type is a composite type.
from pyflink.common import Row
from pyflink.table import EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import col
-from pyflink.table.types import DataTypes
from pyflink.table.udf import udf
env_settings = EnvironmentSettings.in_batch_mode()
@@ -43,8 +42,7 @@ table_env = TableEnvironment.create(env_settings)
table = table_env.from_elements([(1, 'Hi'), (2, 'Hello')], ['id', 'data'])
-@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
- DataTypes.FIELD("data", DataTypes.STRING())]))
+@udf(result_type='ROW<id BIGINT, data STRING>')
def func1(id: int, data: str) -> Row:
return Row(id, data * 2)
@@ -62,8 +60,7 @@ table.map(func1(col('id'), col('data'))).execute().print()
It also supports taking a Row object (containing all the columns of the input table) as input.
```python
-@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
- DataTypes.FIELD("data", DataTypes.STRING())]))
+@udf(result_type='ROW<id BIGINT, data STRING>')
def func2(data: Row) -> Row:
return Row(data.id, data.data * 2)
@@ -85,9 +82,7 @@ It should be noted that the input type and output type should be pandas.DataFram
```python
import pandas as pd
-@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
- DataTypes.FIELD("data", DataTypes.STRING())]),
- func_type='pandas')
+@udf(result_type='ROW<id BIGINT, data STRING>', func_type='pandas')
def func3(data: pd.DataFrame) -> pd.DataFrame:
res = pd.concat([data.id, data.data * 2], axis=1)
return res
@@ -109,14 +104,14 @@ Performs a `flat_map` operation with a python [table function]({{< ref "docs/dev
```python
from pyflink.common import Row
from pyflink.table.udf import udtf
-from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
+from pyflink.table import EnvironmentSettings, TableEnvironment
env_settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(env_settings)
table = table_env.from_elements([(1, 'Hi,Flink'), (2, 'Hello')], ['id', 'data'])
-@udtf(result_types=[DataTypes.INT(), DataTypes.STRING()])
+@udtf(result_types=['INT', 'STRING'])
def split(x: Row) -> Row:
for s in x.data.split(","):
yield x.id, s
@@ -154,7 +149,7 @@ Performs an `aggregate` operation with a python [general aggregate function]({{<
```python
from pyflink.common import Row
-from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
+from pyflink.table import EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import col
from pyflink.table.udf import AggregateFunction, udaf
@@ -180,14 +175,10 @@ class CountAndSumAggregateFunction(AggregateFunction):
accumulator[1] += other_acc[1]
def get_accumulator_type(self):
- return DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.BIGINT()),
- DataTypes.FIELD("b", DataTypes.BIGINT())])
+ return 'ROW<a BIGINT, b BIGINT>'
def get_result_type(self):
- return DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.BIGINT()),
- DataTypes.FIELD("b", DataTypes.BIGINT())])
+ return 'ROW<a BIGINT, b BIGINT>'
function = CountAndSumAggregateFunction()
agg = udaf(function,
@@ -221,9 +212,7 @@ table_env = TableEnvironment.create(env_settings)
t = table_env.from_elements([(1, 2), (2, 1), (1, 3)], ['a', 'b'])
pandas_udaf = udaf(lambda pd: (pd.b.mean(), pd.b.max()),
- result_type=DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.FLOAT()),
- DataTypes.FIELD("b", DataTypes.INT())]),
+ result_type='ROW<a FLOAT, b INT>',
func_type="pandas")
t.aggregate(pandas_udaf.alias("a", "b")) \
.select(col('a'), col('b')).execute().print()
@@ -250,7 +239,7 @@ Similar to `aggregate`, you have to close the `flat_aggregate` with a select sta
```python
from pyflink.common import Row
-from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
+from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col
from pyflink.table.udf import udtaf, TableAggregateFunction
@@ -272,11 +261,10 @@ class Top2(TableAggregateFunction):
accumulator[1] = row.a
def get_accumulator_type(self):
- return DataTypes.ARRAY(DataTypes.BIGINT())
+ return 'ARRAY<BIGINT>'
def get_result_type(self):
- return DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.BIGINT())])
+ return 'ROW<a BIGINT>'
env_settings = EnvironmentSettings.in_streaming_mode()
diff --git a/docs/content/docs/dev/python/table/udfs/python_udfs.md b/docs/content/docs/dev/python/table/udfs/python_udfs.md
index 0fc3183c091..0fd08e47f7a 100644
--- a/docs/content/docs/dev/python/table/udfs/python_udfs.md
+++ b/docs/content/docs/dev/python/table/udfs/python_udfs.md
@@ -55,13 +55,13 @@ class HashCode(ScalarFunction):
settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(settings)
-hash_code = udf(HashCode(), result_type=DataTypes.BIGINT())
+hash_code = udf(HashCode(), result_type='BIGINT')
# use the Python function in Python Table API
my_table.select(col("string"), col("bigint"), hash_code(col("bigint")), call(hash_code, col("bigint")))
# use the Python function in SQL API
-table_env.create_temporary_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT()))
+table_env.create_temporary_function("hash_code", udf(HashCode(), result_type='BIGINT'))
table_env.sql_query("SELECT string, bigint, hash_code(bigint) FROM MyTable")
```
@@ -109,25 +109,25 @@ class Add(ScalarFunction):
add = udf(Add(), result_type=DataTypes.BIGINT())
# option 2: Python function
-@udf(result_type=DataTypes.BIGINT())
+@udf(result_type='BIGINT')
def add(i, j):
return i + j
# option 3: lambda function
-add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT())
+add = udf(lambda i, j: i + j, result_type='BIGINT')
# option 4: callable function
class CallableAdd(object):
def __call__(self, i, j):
return i + j
-add = udf(CallableAdd(), result_type=DataTypes.BIGINT())
+add = udf(CallableAdd(), result_type='BIGINT')
# option 5: partial function
def partial_add(i, j, k):
return i + j + k
-add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT())
+add = udf(functools.partial(partial_add, k=1), result_type='BIGINT')
# register the Python function
table_env.create_temporary_function("add", add)
@@ -162,14 +162,14 @@ table_env = TableEnvironment.create(env_settings)
my_table = ... # type: Table, table schema: [a: String]
# register the Python Table Function
-split = udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()])
+split = udtf(Split(), result_types=['STRING', 'INT'])
# use the Python Table Function in Python Table API
my_table.join_lateral(split(col("a")).alias("word", "length"))
my_table.left_outer_join_lateral(split(col("a")).alias("word", "length"))
# use the Python Table function in SQL API
-table_env.create_temporary_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()]))
+table_env.create_temporary_function("split", udtf(Split(), result_types=['STRING', 'INT']))
table_env.sql_query("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
table_env.sql_query("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")
@@ -222,18 +222,18 @@ Like Python scalar functions, you can use the above five ways to define Python T
```python
# option 1: generator function
-@udtf(result_types=DataTypes.BIGINT())
+@udtf(result_types='BIGINT')
def generator_func(x):
yield 1
yield 2
# option 2: return iterator
-@udtf(result_types=DataTypes.BIGINT())
+@udtf(result_types='BIGINT')
def iterator_func(x):
return range(5)
# option 3: return iterable
-@udtf(result_types=DataTypes.BIGINT())
+@udtf(result_types='BIGINT')
def iterable_func(x):
result = [1, 2, 3]
return result
@@ -302,12 +302,10 @@ class WeightedAvg(AggregateFunction):
accumulator[1] -= weight
def get_result_type(self):
- return DataTypes.BIGINT()
+ return 'BIGINT'
def get_accumulator_type(self):
- return DataTypes.ROW([
- DataTypes.FIELD("f0", DataTypes.BIGINT()),
- DataTypes.FIELD("f1", DataTypes.BIGINT())])
+ return 'ROW<f0 BIGINT, f1 BIGINT>'
env_settings = EnvironmentSettings.in_streaming_mode()
@@ -477,11 +475,10 @@ class Top2(TableAggregateFunction):
accumulator[1] = row[0]
def get_accumulator_type(self):
- return DataTypes.ARRAY(DataTypes.BIGINT())
+ return 'ARRAY<BIGINT>'
def get_result_type(self):
- return DataTypes.ROW(
- [DataTypes.FIELD("a", DataTypes.BIGINT())])
+ return 'ROW<a BIGINT>'
env_settings = EnvironmentSettings.in_streaming_mode()
diff --git a/docs/content/docs/dev/python/table/udfs/vectorized_python_udfs.md b/docs/content/docs/dev/python/table/udfs/vectorized_python_udfs.md
index ccad98e65fc..15dc17c1b7e 100644
--- a/docs/content/docs/dev/python/table/udfs/vectorized_python_udfs.md
+++ b/docs/content/docs/dev/python/table/udfs/vectorized_python_udfs.md
@@ -49,11 +49,11 @@ The following example shows how to define your own vectorized Python scalar func
and use it in a query:
```python
-from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
+from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col
from pyflink.table.udf import udf
-@udf(result_type=DataTypes.BIGINT(), func_type="pandas")
+@udf(result_type='BIGINT', func_type="pandas")
def add(i, j):
return i + j
@@ -84,12 +84,12 @@ The following example shows how to define your own vectorized Python aggregate f
and use it in `GroupBy Aggregation`, `GroupBy Window Aggregation` and `Over Window Aggregation`:
```python
-from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
+from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col, lit
from pyflink.table.udf import udaf
from pyflink.table.window import Tumble
-@udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
+@udaf(result_type='FLOAT', func_type="pandas")
def mean_udaf(v):
return v.mean()
@@ -126,7 +126,6 @@ The following examples show the different ways to define a vectorized Python agg
which takes two columns of bigint as the inputs and returns the sum of the maximum of them as the result.
```python
-from pyflink.table import DataTypes
from pyflink.table.udf import AggregateFunction, udaf
# option 1: extending the base class `AggregateFunction`
@@ -152,27 +151,27 @@ class MaxAdd(AggregateFunction):
result += arg.max()
accumulator.append(result)
-max_add = udaf(MaxAdd(), result_type=DataTypes.BIGINT(), func_type="pandas")
+max_add = udaf(MaxAdd(), result_type='BIGINT', func_type="pandas")
# option 2: Python function
-@udaf(result_type=DataTypes.BIGINT(), func_type="pandas")
+@udaf(result_type='BIGINT', func_type="pandas")
def max_add(i, j):
return i.max() + j.max()
# option 3: lambda function
-max_add = udaf(lambda i, j: i.max() + j.max(), result_type=DataTypes.BIGINT(), func_type="pandas")
+max_add = udaf(lambda i, j: i.max() + j.max(), result_type='BIGINT', func_type="pandas")
# option 4: callable function
class CallableMaxAdd(object):
def __call__(self, i, j):
return i.max() + j.max()
-max_add = udaf(CallableMaxAdd(), result_type=DataTypes.BIGINT(), func_type="pandas")
+max_add = udaf(CallableMaxAdd(), result_type='BIGINT', func_type="pandas")
# option 5: partial function
def partial_max_add(i, j, k):
return i.max() + j.max() + k
-max_add = udaf(functools.partial(partial_max_add, k=1), result_type=DataTypes.BIGINT(), func_type="pandas")
+max_add = udaf(functools.partial(partial_max_add, k=1), result_type='BIGINT', func_type="pandas")
```
diff --git a/flink-python/pyflink/table/tests/test_pandas_udaf.py b/flink-python/pyflink/table/tests/test_pandas_udaf.py
index b19f4cd9190..1791add00a4 100644
--- a/flink-python/pyflink/table/tests/test_pandas_udaf.py
+++ b/flink-python/pyflink/table/tests/test_pandas_udaf.py
@@ -343,7 +343,7 @@ class StreamPandasUDAFITTests(PyFlinkStreamTableTestCase):
super(StreamPandasUDAFITTests, cls).setUpClass()
cls.t_env.create_temporary_system_function("mean_udaf", mean_udaf)
max_add_min_udaf = udaf(lambda a: a.max() + a.min(),
- result_type=DataTypes.SMALLINT(),
+ result_type='SMALLINT',
func_type='pandas')
cls.t_env.create_temporary_system_function("max_add_min_udaf", max_add_min_udaf)
diff --git a/flink-python/pyflink/table/tests/test_udaf.py b/flink-python/pyflink/table/tests/test_udaf.py
index 1173825f1fa..cc922c08cb4 100644
--- a/flink-python/pyflink/table/tests/test_udaf.py
+++ b/flink-python/pyflink/table/tests/test_udaf.py
@@ -56,10 +56,10 @@ class CountAggregateFunction(AggregateFunction):
accumulator[0] = accumulator[0] + other_acc[0]
def get_accumulator_type(self):
- return DataTypes.ARRAY(DataTypes.BIGINT())
+ return 'ARRAY<BIGINT>'
def get_result_type(self):
- return DataTypes.BIGINT()
+ return 'BIGINT'
class SumAggregateFunction(AggregateFunction):
@@ -81,10 +81,10 @@ class SumAggregateFunction(AggregateFunction):
accumulator[0] = accumulator[0] + other_acc[0]
def get_accumulator_type(self):
- return DataTypes.ARRAY(DataTypes.BIGINT())
+ return 'ARRAY<BIGINT>'
def get_result_type(self):
- return DataTypes.BIGINT()
+ return 'BIGINT'
class ConcatAggregateFunction(AggregateFunction):
@@ -107,12 +107,10 @@ class ConcatAggregateFunction(AggregateFunction):
accumulator[0].remove(args[0])
def get_accumulator_type(self):
- return DataTypes.ROW([
- DataTypes.FIELD("f0", DataTypes.ARRAY(DataTypes.STRING())),
- DataTypes.FIELD("f1", DataTypes.BIGINT())])
+ return 'ROW<f0 STRING, f1 BIGINT>'
def get_result_type(self):
- return DataTypes.STRING()
+ return 'STRING'
class ListViewConcatAggregateFunction(AggregateFunction):
@@ -169,12 +167,10 @@ class CountDistinctAggregateFunction(AggregateFunction):
accumulator[0][input_str] = None
def get_accumulator_type(self):
- return DataTypes.ROW([
- DataTypes.FIELD("f0", DataTypes.MAP(DataTypes.STRING(), DataTypes.STRING())),
- DataTypes.FIELD("f1", DataTypes.BIGINT())])
+ return 'ROW<f0 MAP<STRING, STRING>, f1 BIGINT>'
def get_result_type(self):
- return DataTypes.BIGINT()
+ return 'BIGINT'
class CustomIterateAggregateFunction(AggregateFunction):
diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
index 80ecacbbb8c..699786b73d2 100644
--- a/flink-python/pyflink/table/tests/test_udf.py
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -823,6 +823,180 @@ class PyFlinkEmbeddedThreadTests(UserDefinedFunctionTests, PyFlinkBatchTableTest
super(PyFlinkEmbeddedThreadTests, self).setUp()
self.t_env.get_config().set("python.execution-mode", "thread")
+ def test_all_data_types_string(self):
+ @udf(result_type='BOOLEAN')
+ def boolean_func(bool_param):
+ assert isinstance(bool_param, bool), 'bool_param of wrong type %s !' \
+ % type(bool_param)
+ return bool_param
+
+ @udf(result_type='TINYINT')
+ def tinyint_func(tinyint_param):
+ assert isinstance(tinyint_param, int), 'tinyint_param of wrong type %s !' \
+ % type(tinyint_param)
+ return tinyint_param
+
+ @udf(result_type='SMALLINT')
+ def smallint_func(smallint_param):
+ assert isinstance(smallint_param, int), 'smallint_param of wrong type %s !' \
+ % type(smallint_param)
+ assert smallint_param == 32767, 'smallint_param of wrong value %s' % smallint_param
+ return smallint_param
+
+ @udf(result_type='INT')
+ def int_func(int_param):
+ assert isinstance(int_param, int), 'int_param of wrong type %s !' \
+ % type(int_param)
+ assert int_param == -2147483648, 'int_param of wrong value %s' % int_param
+ return int_param
+
+ @udf(result_type='BIGINT')
+ def bigint_func(bigint_param):
+ assert isinstance(bigint_param, int), 'bigint_param of wrong type %s !' \
+ % type(bigint_param)
+ return bigint_param
+
+ @udf(result_type='BIGINT')
+ def bigint_func_none(bigint_param):
+ assert bigint_param is None, 'bigint_param %s should be None!' % bigint_param
+ return bigint_param
+
+ @udf(result_type='FLOAT')
+ def float_func(float_param):
+ assert isinstance(float_param, float) and float_equal(float_param, 1.23, 1e-6), \
+ 'float_param is wrong value %s !' % float_param
+ return float_param
+
+ @udf(result_type='DOUBLE')
+ def double_func(double_param):
+ assert isinstance(double_param, float) and float_equal(double_param, 1.98932, 1e-7), \
+ 'double_param is wrong value %s !' % double_param
+ return double_param
+
+ @udf(result_type='BYTES')
+ def bytes_func(bytes_param):
+ assert bytes_param == b'flink', \
+ 'bytes_param is wrong value %s !' % bytes_param
+ return bytes_param
+
+ @udf(result_type='STRING')
+ def str_func(str_param):
+ assert str_param == 'pyflink', \
+ 'str_param is wrong value %s !' % str_param
+ return str_param
+
+ @udf(result_type='DATE')
+ def date_func(date_param):
+ from datetime import date
+ assert date_param == date(year=2014, month=9, day=13), \
+ 'date_param is wrong value %s !' % date_param
+ return date_param
+
+ @udf(result_type='TIME')
+ def time_func(time_param):
+ from datetime import time
+ assert time_param == time(hour=12, minute=0, second=0, microsecond=123000), \
+ 'time_param is wrong value %s !' % time_param
+ return time_param
+
+ @udf(result_type='TIMESTAMP(3)')
+ def timestamp_func(timestamp_param):
+ from datetime import datetime
+ assert timestamp_param == datetime(2018, 3, 11, 3, 0, 0, 123000), \
+ 'timestamp_param is wrong value %s !' % timestamp_param
+ return timestamp_param
+
+ @udf(result_type='ARRAY<BIGINT>')
+ def array_func(array_param):
+ assert array_param == [[1, 2, 3]] or array_param == ((1, 2, 3),), \
+ 'array_param is wrong value %s !' % array_param
+ return array_param[0]
+
+ @udf(result_type='MAP<BIGINT, STRING>')
+ def map_func(map_param):
+ assert map_param == {1: 'flink', 2: 'pyflink'}, \
+ 'map_param is wrong value %s !' % map_param
+ return map_param
+
+ @udf(result_type='DECIMAL(38, 18)')
+ def decimal_func(decimal_param):
+ from decimal import Decimal
+ assert decimal_param == Decimal('1000000000000000000.050000000000000000'), \
+ 'decimal_param is wrong value %s !' % decimal_param
+ return decimal_param
+
+ @udf(result_type='DECIMAL(38, 18)')
+ def decimal_cut_func(decimal_param):
+ from decimal import Decimal
+ assert decimal_param == Decimal('1000000000000000000.059999999999999999'), \
+ 'decimal_param is wrong value %s !' % decimal_param
+ return decimal_param
+
+ sink_table = generate_random_table_name()
+ sink_table_ddl = f"""
+ CREATE TABLE {sink_table}(
+ a BIGINT, b BIGINT, c TINYINT, d BOOLEAN, e SMALLINT, f INT, g FLOAT, h DOUBLE, i BYTES,
+ j STRING, k DATE, l TIME, m TIMESTAMP(3), n ARRAY<BIGINT>, o MAP<BIGINT, STRING>,
+ p DECIMAL(38, 18), q DECIMAL(38, 18)) WITH ('connector'='test-sink')
+ """
+ self.t_env.execute_sql(sink_table_ddl)
+
+ import datetime
+ import decimal
+ t = self.t_env.from_elements(
+ [(1, None, 1, True, 32767, -2147483648, 1.23, 1.98932,
+ bytearray(b'flink'), 'pyflink', datetime.date(2014, 9, 13),
+ datetime.time(hour=12, minute=0, second=0, microsecond=123000),
+ datetime.datetime(2018, 3, 11, 3, 0, 0, 123000), [[1, 2, 3]],
+ {1: 'flink', 2: 'pyflink'}, decimal.Decimal('1000000000000000000.05'),
+ decimal.Decimal('1000000000000000000.05999999999999999899999999999'))],
+ DataTypes.ROW(
+ [DataTypes.FIELD("a", DataTypes.BIGINT()),
+ DataTypes.FIELD("b", DataTypes.BIGINT()),
+ DataTypes.FIELD("c", DataTypes.TINYINT()),
+ DataTypes.FIELD("d", DataTypes.BOOLEAN()),
+ DataTypes.FIELD("e", DataTypes.SMALLINT()),
+ DataTypes.FIELD("f", DataTypes.INT()),
+ DataTypes.FIELD("g", DataTypes.FLOAT()),
+ DataTypes.FIELD("h", DataTypes.DOUBLE()),
+ DataTypes.FIELD("i", DataTypes.BYTES()),
+ DataTypes.FIELD("j", DataTypes.STRING()),
+ DataTypes.FIELD("k", DataTypes.DATE()),
+ DataTypes.FIELD("l", DataTypes.TIME()),
+ DataTypes.FIELD("m", DataTypes.TIMESTAMP(3)),
+ DataTypes.FIELD("n", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT()))),
+ DataTypes.FIELD("o", DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())),
+ DataTypes.FIELD("p", DataTypes.DECIMAL(38, 18)),
+ DataTypes.FIELD("q", DataTypes.DECIMAL(38, 18))]))
+
+ t.select(
+ bigint_func(t.a),
+ bigint_func_none(t.b),
+ tinyint_func(t.c),
+ boolean_func(t.d),
+ smallint_func(t.e),
+ int_func(t.f),
+ float_func(t.g),
+ double_func(t.h),
+ bytes_func(t.i),
+ str_func(t.j),
+ date_func(t.k),
+ time_func(t.l),
+ timestamp_func(t.m),
+ array_func(t.n),
+ map_func(t.o),
+ decimal_func(t.p),
+ decimal_cut_func(t.q)) \
+ .execute_insert(sink_table).wait()
+ actual = source_sink_utils.results()
+ # Currently the sink result precision of DataTypes.TIME(precision) only supports 0.
+ self.assert_equals(actual,
+ ["+I[1, null, 1, true, 32767, -2147483648, 1.23, 1.98932, "
+ "[102, 108, 105, 110, 107], pyflink, 2014-09-13, 12:00:00.123, "
+ "2018-03-11T03:00:00.123, [1, 2, 3], "
+ "{1=flink, 2=pyflink}, 1000000000000000000.050000000000000000, "
+ "1000000000000000000.059999999999999999]"])
+
# test specify the input_types
@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT())
diff --git a/flink-python/pyflink/table/tests/test_udtf.py b/flink-python/pyflink/table/tests/test_udtf.py
index 91ee519edb8..1053fdef5f4 100644
--- a/flink-python/pyflink/table/tests/test_udtf.py
+++ b/flink-python/pyflink/table/tests/test_udtf.py
@@ -152,7 +152,7 @@ class MultiEmit(TableFunction, unittest.TestCase):
yield x, i
-@udtf(result_types=[DataTypes.BIGINT()])
+@udtf(result_types=['bigint'])
def identity(x):
if x is not None:
from pyflink.common import Row
diff --git a/flink-python/pyflink/table/udf.py b/flink-python/pyflink/table/udf.py
index 50ff1839b51..07840423321 100644
--- a/flink-python/pyflink/table/udf.py
+++ b/flink-python/pyflink/table/udf.py
@@ -16,7 +16,6 @@
# limitations under the License.
################################################################################
import abc
-import collections
import functools
import inspect
from typing import Union, List, Type, Callable, TypeVar, Generic, Iterable
@@ -180,7 +179,7 @@ class ImperativeAggregateFunction(UserDefinedFunction, Generic[T, ACC]):
"""
raise RuntimeError("Method merge is not implemented")
- def get_result_type(self) -> DataType:
+ def get_result_type(self) -> Union[DataType, str]:
"""
Returns the DataType of the AggregateFunction's result.
@@ -189,7 +188,7 @@ class ImperativeAggregateFunction(UserDefinedFunction, Generic[T, ACC]):
"""
raise RuntimeError("Method get_result_type is not implemented")
- def get_accumulator_type(self) -> DataType:
+ def get_accumulator_type(self) -> Union[DataType, str]:
"""
Returns the DataType of the AggregateFunction's accumulator.
@@ -323,15 +322,18 @@ class UserDefinedFunctionWrapper(object):
if input_types is not None:
from pyflink.table.types import RowType
- if not isinstance(input_types, collections.abc.Iterable) \
- or isinstance(input_types, RowType):
+ if isinstance(input_types, RowType):
+ input_types = input_types.field_types()
+ elif isinstance(input_types, (DataType, str)):
input_types = [input_types]
+ else:
+ input_types = list(input_types)
for input_type in input_types:
- if not isinstance(input_type, DataType):
+ if not isinstance(input_type, (DataType, str)):
raise TypeError(
- "Invalid input_type: input_type should be DataType but contains {}".format(
- input_type))
+ "Invalid input_type: input_type should be DataType or str but contains {}"
+ .format(input_type))
self._func = func
self._input_types = input_types
@@ -377,8 +379,11 @@ class UserDefinedFunctionWrapper(object):
raise TypeError("Unsupported func_type: %s." % self._func_type)
if self._input_types is not None:
- j_input_types = java_utils.to_jarray(
- gateway.jvm.DataType, [_to_java_data_type(i) for i in self._input_types])
+ if isinstance(self._input_types[0], str):
+ j_input_types = java_utils.to_jarray(gateway.jvm.String, self._input_types)
+ else:
+ j_input_types = java_utils.to_jarray(
+ gateway.jvm.DataType, [_to_java_data_type(i) for i in self._input_types])
else:
j_input_types = None
j_function_kind = get_python_function_kind()
@@ -408,15 +413,19 @@ class UserDefinedScalarFunctionWrapper(UserDefinedFunctionWrapper):
super(UserDefinedScalarFunctionWrapper, self).__init__(
func, input_types, func_type, deterministic, name)
- if not isinstance(result_type, DataType):
+ if not isinstance(result_type, (DataType, str)):
raise TypeError(
- "Invalid returnType: returnType should be DataType but is {}".format(result_type))
+ "Invalid returnType: returnType should be DataType or str but is {}".format(
+ result_type))
self._result_type = result_type
self._judf_placeholder = None
def _create_judf(self, serialized_func, j_input_types, j_function_kind):
gateway = get_gateway()
- j_result_type = _to_java_data_type(self._result_type)
+ if isinstance(self._result_type, DataType):
+ j_result_type = _to_java_data_type(self._result_type)
+ else:
+ j_result_type = self._result_type
PythonScalarFunction = gateway.jvm \
.org.apache.flink.table.functions.python.PythonScalarFunction
j_scalar_function = PythonScalarFunction(
@@ -444,23 +453,41 @@ class UserDefinedTableFunctionWrapper(UserDefinedFunctionWrapper):
func, input_types, "general", deterministic, name)
from pyflink.table.types import RowType
- if not isinstance(result_types, collections.abc.Iterable) \
- or isinstance(result_types, RowType):
+ if isinstance(result_types, RowType):
+ # DataTypes.ROW([DataTypes.FIELD("f0", DataTypes.INT()),
+ # DataTypes.FIELD("f1", DataTypes.BIGINT())])
+ result_types = result_types.field_types()
+ elif isinstance(result_types, str):
+ # ROW<f0 INT, f1 BIGINT>
+ result_types = result_types
+ elif isinstance(result_types, DataType):
+ # DataTypes.INT()
result_types = [result_types]
+ else:
+ # [DataTypes.INT(), DataTypes.BIGINT()]
+ result_types = list(result_types)
for result_type in result_types:
- if not isinstance(result_type, DataType):
+ if not isinstance(result_type, (DataType, str)):
raise TypeError(
- "Invalid result_type: result_type should be DataType but contains {}".format(
- result_type))
+ "Invalid result_type: result_type should be DataType or str but contains {}"
+ .format(result_type))
self._result_types = result_types
def _create_judf(self, serialized_func, j_input_types, j_function_kind):
gateway = get_gateway()
- j_result_types = java_utils.to_jarray(gateway.jvm.DataType,
- [_to_java_data_type(i) for i in self._result_types])
- j_result_type = gateway.jvm.DataTypes.ROW(j_result_types)
+
+ if isinstance(self._result_types, str):
+ j_result_type = self._result_types
+ elif isinstance(self._result_types[0], DataType):
+ j_result_types = java_utils.to_jarray(
+ gateway.jvm.DataType, [_to_java_data_type(i) for i in self._result_types])
+ j_result_type = gateway.jvm.DataTypes.ROW(j_result_types)
+ else:
+ j_result_type = 'Row<{0}>'.format(','.join(
+ ['f{0} {1}'.format(i, result_type)
+ for i, result_type in enumerate(self._result_types)]))
PythonTableFunction = gateway.jvm \
.org.apache.flink.table.functions.python.PythonTableFunction
j_table_function = PythonTableFunction(
@@ -491,33 +518,48 @@ class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
accumulator_type = func.get_accumulator_type()
if result_type is None:
result_type = func.get_result_type()
- if not isinstance(result_type, DataType):
+ if not isinstance(result_type, (DataType, str)):
raise TypeError(
- "Invalid returnType: returnType should be DataType but is {}".format(result_type))
+ "Invalid returnType: returnType should be DataType or str but is {}"
+ .format(result_type))
from pyflink.table.types import MapType
if func_type == 'pandas' and isinstance(result_type, MapType):
raise TypeError(
"Invalid returnType: Pandas UDAF doesn't support DataType type {} currently"
.format(result_type))
- if accumulator_type is not None and not isinstance(accumulator_type, DataType):
+ if accumulator_type is not None and not isinstance(accumulator_type, (DataType, str)):
raise TypeError(
- "Invalid accumulator_type: accumulator_type should be DataType but is {}".format(
- accumulator_type))
+ "Invalid accumulator_type: accumulator_type should be DataType or str but is {}"
+ .format(accumulator_type))
+ if (func_type == "general" and
+ not (isinstance(result_type, str) and isinstance(accumulator_type, str) or
+ isinstance(result_type, DataType) and isinstance(accumulator_type, DataType))):
+ raise TypeError("result_type and accumulator_type should be DataType or str "
+ "at the same time.")
self._result_type = result_type
self._accumulator_type = accumulator_type
self._is_table_aggregate = is_table_aggregate
def _create_judf(self, serialized_func, j_input_types, j_function_kind):
if self._func_type == "pandas":
- from pyflink.table.types import DataTypes
- self._accumulator_type = DataTypes.ARRAY(self._result_type)
+ if isinstance(self._result_type, DataType):
+ from pyflink.table.types import DataTypes
+ self._accumulator_type = DataTypes.ARRAY(self._result_type)
+ else:
+ self._accumulator_type = 'ARRAY<{0}>'.format(self._result_type)
if j_input_types is not None:
gateway = get_gateway()
j_input_types = java_utils.to_jarray(
gateway.jvm.DataType, [_to_java_data_type(i) for i in self._input_types])
- j_result_type = _to_java_data_type(self._result_type)
- j_accumulator_type = _to_java_data_type(self._accumulator_type)
+ if isinstance(self._result_type, DataType):
+ j_result_type = _to_java_data_type(self._result_type)
+ else:
+ j_result_type = self._result_type
+ if isinstance(self._accumulator_type, DataType):
+ j_accumulator_type = _to_java_data_type(self._accumulator_type)
+ else:
+ j_accumulator_type = self._accumulator_type
gateway = get_gateway()
if self._is_table_aggregate:
@@ -570,7 +612,8 @@ def _create_udtaf(f, input_types, result_type, accumulator_type, func_type, dete
def udf(f: Union[Callable, ScalarFunction, Type] = None,
- input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
+ input_types: Union[List[DataType], DataType, str, List[str]] = None,
+ result_type: Union[DataType, str] = None,
deterministic: bool = None, name: str = None, func_type: str = "general",
udf_type: str = None) -> Union[UserDefinedScalarFunctionWrapper, Callable]:
"""
@@ -586,6 +629,11 @@ def udf(f: Union[Callable, ScalarFunction, Type] = None,
... def add(i, j):
... return i + j
+ >>> # Specify result_type via string.
+ >>> @udf(result_type='BIGINT')
+ ... def add(i, j):
+ ... return i + j
+
>>> class SubtractOne(ScalarFunction):
... def eval(self, i):
... return i - 1
@@ -625,8 +673,9 @@ def udf(f: Union[Callable, ScalarFunction, Type] = None,
def udtf(f: Union[Callable, TableFunction, Type] = None,
- input_types: Union[List[DataType], DataType] = None,
- result_types: Union[List[DataType], DataType] = None, deterministic: bool = None,
+ input_types: Union[List[DataType], DataType, str, List[str]] = None,
+ result_types: Union[List[DataType], DataType, str, List[str]] = None,
+ deterministic: bool = None,
name: str = None) -> Union[UserDefinedTableFunctionWrapper, Callable]:
"""
Helper method for creating a user-defined table function.
@@ -640,6 +689,18 @@ def udtf(f: Union[Callable, TableFunction, Type] = None,
... for i in range(e):
... yield s, i
+ >>> # Specify result_types via string
+ >>> @udtf(result_types=['BIGINT', 'BIGINT'])
+ ... def range_emit(s, e):
+ ... for i in range(e):
+ ... yield s, i
+
+ >>> # Specify result_types via row string
+ >>> @udtf(result_types='Row<a BIGINT, b BIGINT>')
+ ... def range_emit(s, e):
+ ... for i in range(e):
+ ... yield s, i
+
>>> class MultiEmit(TableFunction):
... def eval(self, i):
... return range(i)
@@ -665,8 +726,9 @@ def udtf(f: Union[Callable, TableFunction, Type] = None,
def udaf(f: Union[Callable, AggregateFunction, Type] = None,
- input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
- accumulator_type: DataType = None, deterministic: bool = None, name: str = None,
+ input_types: Union[List[DataType], DataType, str, List[str]] = None,
+ result_type: Union[DataType, str] = None, accumulator_type: Union[DataType, str] = None,
+ deterministic: bool = None, name: str = None,
func_type: str = "general") -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
"""
Helper method for creating a user-defined aggregate function.
@@ -679,6 +741,11 @@ def udaf(f: Union[Callable, AggregateFunction, Type] = None,
... def mean_udaf(v):
... return v.mean()
+ >>> # Specify result_type via string
+ >>> @udaf(result_type='FLOAT', func_type="pandas")
+ ... def mean_udaf(v):
+ ... return v.mean()
+
:param f: user-defined aggregate function.
:param input_types: optional, the input data types.
:param result_type: the result data type.
@@ -707,8 +774,10 @@ def udaf(f: Union[Callable, AggregateFunction, Type] = None,
def udtaf(f: Union[Callable, TableAggregateFunction, Type] = None,
- input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
- accumulator_type: DataType = None, deterministic: bool = None, name: str = None,
+ input_types: Union[List[DataType], DataType, str, List[str]] = None,
+ result_type: Union[DataType, str] = None,
+ accumulator_type: Union[DataType, str] = None,
+ deterministic: bool = None, name: str = None,
func_type: str = 'general') -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
"""
Helper method for creating a user-defined table aggregate function.
@@ -742,11 +811,10 @@ def udtaf(f: Union[Callable, TableAggregateFunction, Type] = None,
... self.accumulate(accumulator, other_acc[1])
...
... def get_accumulator_type(self):
- ... return DataTypes.ARRAY(DataTypes.BIGINT())
+ ... return 'ARRAY<BIGINT>'
...
... def get_result_type(self):
- ... return DataTypes.ROW(
- ... [DataTypes.FIELD("a", DataTypes.BIGINT())])
+ ... return 'ROW<a BIGINT>'
>>> top2 = udtaf(Top2())
:param f: user-defined table aggregate function.
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonAggregateFunction.java
index 0443776903f..c518c3ddcbd 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonAggregateFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonAggregateFunction.java
@@ -27,6 +27,8 @@ import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.utils.TypeConversions;
+import java.util.Arrays;
+
/** The wrapper of user defined python aggregate function. */
@Internal
public class PythonAggregateFunction extends AggregateFunction implements PythonFunction {
@@ -35,14 +37,18 @@ public class PythonAggregateFunction extends AggregateFunction implements Python
private final String name;
private final byte[] serializedAggregateFunction;
- private final DataType[] inputTypes;
- private final DataType resultType;
- private final DataType accumulatorType;
private final PythonFunctionKind pythonFunctionKind;
private final boolean deterministic;
private final PythonEnv pythonEnv;
private final boolean takesRowAsInput;
+ private DataType[] inputTypes;
+ private String[] inputTypesString;
+ private DataType resultType;
+ private String resultTypeString;
+ private DataType accumulatorType;
+ private String accumulatorTypeString;
+
public PythonAggregateFunction(
String name,
byte[] serializedAggregateFunction,
@@ -53,11 +59,49 @@ public class PythonAggregateFunction extends AggregateFunction implements Python
boolean deterministic,
boolean takesRowAsInput,
PythonEnv pythonEnv) {
- this.name = name;
- this.serializedAggregateFunction = serializedAggregateFunction;
+ this(
+ name,
+ serializedAggregateFunction,
+ pythonFunctionKind,
+ deterministic,
+ takesRowAsInput,
+ pythonEnv);
this.inputTypes = inputTypes;
this.resultType = resultType;
this.accumulatorType = accumulatorType;
+ }
+
+ public PythonAggregateFunction(
+ String name,
+ byte[] serializedAggregateFunction,
+ String[] inputTypesString,
+ String resultTypeString,
+ String accumulatorTypeString,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ boolean takesRowAsInput,
+ PythonEnv pythonEnv) {
+ this(
+ name,
+ serializedAggregateFunction,
+ pythonFunctionKind,
+ deterministic,
+ takesRowAsInput,
+ pythonEnv);
+ this.inputTypesString = inputTypesString;
+ this.resultTypeString = resultTypeString;
+ this.accumulatorTypeString = accumulatorTypeString;
+ }
+
+ public PythonAggregateFunction(
+ String name,
+ byte[] serializedAggregateFunction,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ boolean takesRowAsInput,
+ PythonEnv pythonEnv) {
+ this.name = name;
+ this.serializedAggregateFunction = serializedAggregateFunction;
this.pythonFunctionKind = pythonFunctionKind;
this.deterministic = deterministic;
this.pythonEnv = pythonEnv;
@@ -106,20 +150,46 @@ public class PythonAggregateFunction extends AggregateFunction implements Python
@Override
public TypeInformation getResultType() {
+ if (resultType == null && resultTypeString != null) {
+ throw new RuntimeException(
+ "String format result type is not supported in old type system.");
+ }
return TypeConversions.fromDataTypeToLegacyInfo(resultType);
}
@Override
public TypeInformation getAccumulatorType() {
+ if (accumulatorType == null && accumulatorTypeString != null) {
+ throw new RuntimeException(
+ "String format accumulator type is not supported in old type system.");
+ }
return TypeConversions.fromDataTypeToLegacyInfo(accumulatorType);
}
@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
TypeInference.Builder builder = TypeInference.newBuilder();
+
+ if (inputTypesString != null) {
+ inputTypes =
+ (DataType[])
+ Arrays.stream(inputTypesString)
+ .map(typeFactory::createDataType)
+ .toArray(DataType[]::new);
+ }
+
if (inputTypes != null) {
builder.typedArguments(inputTypes);
}
+
+ if (resultType == null) {
+ resultType = typeFactory.createDataType(resultTypeString);
+ }
+
+ if (accumulatorType == null) {
+ accumulatorType = typeFactory.createDataType(accumulatorTypeString);
+ }
+
return builder.outputTypeStrategy(TypeStrategies.explicit(resultType))
.accumulatorTypeStrategy(TypeStrategies.explicit(accumulatorType))
.build();
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java
index b3dfb12730e..fe3c1fc9f75 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java
@@ -27,6 +27,7 @@ import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.utils.TypeConversions;
+import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -39,13 +40,16 @@ public class PythonScalarFunction extends ScalarFunction implements PythonFuncti
private final String name;
private final byte[] serializedScalarFunction;
- private final DataType[] inputTypes;
- private final DataType resultType;
private final PythonFunctionKind pythonFunctionKind;
private final boolean deterministic;
private final PythonEnv pythonEnv;
private final boolean takesRowAsInput;
+ private DataType[] inputTypes;
+ private String[] inputTypesString;
+ private DataType resultType;
+ private String resultTypeString;
+
public PythonScalarFunction(
String name,
byte[] serializedScalarFunction,
@@ -55,10 +59,46 @@ public class PythonScalarFunction extends ScalarFunction implements PythonFuncti
boolean deterministic,
boolean takesRowAsInput,
PythonEnv pythonEnv) {
- this.name = name;
- this.serializedScalarFunction = serializedScalarFunction;
+ this(
+ name,
+ serializedScalarFunction,
+ pythonFunctionKind,
+ deterministic,
+ takesRowAsInput,
+ pythonEnv);
this.inputTypes = inputTypes;
this.resultType = resultType;
+ }
+
+ public PythonScalarFunction(
+ String name,
+ byte[] serializedScalarFunction,
+ String[] inputTypesString,
+ String resultTypeString,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ boolean takesRowAsInput,
+ PythonEnv pythonEnv) {
+ this(
+ name,
+ serializedScalarFunction,
+ pythonFunctionKind,
+ deterministic,
+ takesRowAsInput,
+ pythonEnv);
+ this.inputTypesString = inputTypesString;
+ this.resultTypeString = resultTypeString;
+ }
+
+ public PythonScalarFunction(
+ String name,
+ byte[] serializedScalarFunction,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ boolean takesRowAsInput,
+ PythonEnv pythonEnv) {
+ this.name = name;
+ this.serializedScalarFunction = serializedScalarFunction;
this.pythonFunctionKind = pythonFunctionKind;
this.deterministic = deterministic;
this.pythonEnv = pythonEnv;
@@ -106,17 +146,35 @@ public class PythonScalarFunction extends ScalarFunction implements PythonFuncti
@Override
public TypeInformation getResultType(Class[] signature) {
+ if (resultType == null && resultTypeString != null) {
+ throw new RuntimeException(
+ "String format result type is not supported in old type system. The `register_function` is deprecated, please Use `create_temporary_system_function` instead.");
+ }
return TypeConversions.fromDataTypeToLegacyInfo(resultType);
}
@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
TypeInference.Builder builder = TypeInference.newBuilder();
+
+ if (inputTypesString != null) {
+ inputTypes =
+ (DataType[])
+ Arrays.stream(inputTypesString)
+ .map(typeFactory::createDataType)
+ .toArray(DataType[]::new);
+ }
+
if (inputTypes != null) {
final List<DataType> argumentDataTypes =
Stream.of(inputTypes).collect(Collectors.toList());
builder.typedArguments(argumentDataTypes);
}
+
+ if (resultType == null) {
+ resultType = typeFactory.createDataType(resultTypeString);
+ }
+
return builder.outputTypeStrategy(TypeStrategies.explicit(resultType)).build();
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java
index 7e3c2db29d4..3d29d20ecb0 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java
@@ -27,6 +27,8 @@ import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.utils.TypeConversions;
+import java.util.Arrays;
+
/** The wrapper of user defined python table aggregate function. */
@Internal
public class PythonTableAggregateFunction extends TableAggregateFunction implements PythonFunction {
@@ -35,14 +37,18 @@ public class PythonTableAggregateFunction extends TableAggregateFunction impleme
private final String name;
private final byte[] serializedTableAggregateFunction;
- private final DataType[] inputTypes;
- private final DataType resultType;
- private final DataType accumulatorType;
private final PythonFunctionKind pythonFunctionKind;
private final boolean deterministic;
private final PythonEnv pythonEnv;
private final boolean takesRowAsInput;
+ private DataType[] inputTypes;
+ private String[] inputTypesString;
+ private DataType resultType;
+ private String resultTypeString;
+ private DataType accumulatorType;
+ private String accumulatorTypeString;
+
public PythonTableAggregateFunction(
String name,
byte[] serializedTableAggregateFunction,
@@ -53,11 +59,49 @@ public class PythonTableAggregateFunction extends TableAggregateFunction impleme
boolean deterministic,
boolean takesRowAsInput,
PythonEnv pythonEnv) {
- this.name = name;
- this.serializedTableAggregateFunction = serializedTableAggregateFunction;
+ this(
+ name,
+ serializedTableAggregateFunction,
+ pythonFunctionKind,
+ deterministic,
+ takesRowAsInput,
+ pythonEnv);
this.inputTypes = inputTypes;
this.resultType = resultType;
this.accumulatorType = accumulatorType;
+ }
+
+ public PythonTableAggregateFunction(
+ String name,
+ byte[] serializedTableAggregateFunction,
+ String[] inputTypesString,
+ String resultTypeString,
+ String accumulatorTypeString,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ boolean takesRowAsInput,
+ PythonEnv pythonEnv) {
+ this(
+ name,
+ serializedTableAggregateFunction,
+ pythonFunctionKind,
+ deterministic,
+ takesRowAsInput,
+ pythonEnv);
+ this.inputTypesString = inputTypesString;
+ this.resultTypeString = resultTypeString;
+ this.accumulatorTypeString = accumulatorTypeString;
+ }
+
+ public PythonTableAggregateFunction(
+ String name,
+ byte[] serializedTableAggregateFunction,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ boolean takesRowAsInput,
+ PythonEnv pythonEnv) {
+ this.name = name;
+ this.serializedTableAggregateFunction = serializedTableAggregateFunction;
this.pythonFunctionKind = pythonFunctionKind;
this.deterministic = deterministic;
this.pythonEnv = pythonEnv;
@@ -106,20 +150,45 @@ public class PythonTableAggregateFunction extends TableAggregateFunction impleme
@Override
public TypeInformation getResultType() {
+ if (resultType == null && resultTypeString != null) {
+ throw new RuntimeException(
+ "String format result type is not supported in old type system. The `register_function` is deprecated, please Use `create_temporary_system_function` instead.");
+ }
return TypeConversions.fromDataTypeToLegacyInfo(resultType);
}
@Override
public TypeInformation getAccumulatorType() {
+ if (accumulatorType == null && accumulatorTypeString != null) {
+ throw new RuntimeException(
+ "String format accumulator type is not supported in old type system. The `register_function` is deprecated, please Use `create_temporary_system_function` instead.");
+ }
return TypeConversions.fromDataTypeToLegacyInfo(accumulatorType);
}
@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
TypeInference.Builder builder = TypeInference.newBuilder();
+ if (inputTypesString != null) {
+ inputTypes =
+ (DataType[])
+ Arrays.stream(inputTypesString)
+ .map(typeFactory::createDataType)
+ .toArray(DataType[]::new);
+ }
+
if (inputTypes != null) {
builder.typedArguments(inputTypes);
}
+
+ if (resultType == null) {
+ resultType = typeFactory.createDataType(resultTypeString);
+ }
+
+ if (accumulatorType == null) {
+ accumulatorType = typeFactory.createDataType(accumulatorTypeString);
+ }
+
return builder.outputTypeStrategy(TypeStrategies.explicit(resultType))
.accumulatorTypeStrategy(TypeStrategies.explicit(accumulatorType))
.build();
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java
index 29df9416efa..19da281bdb9 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java
@@ -28,6 +28,7 @@ import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.types.Row;
+import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -39,14 +40,17 @@ public class PythonTableFunction extends TableFunction<Row> implements PythonFun
private static final long serialVersionUID = 1L;
private final String name;
- private final byte[] serializedScalarFunction;
- private final DataType[] inputTypes;
- private final DataType resultType;
+ private final byte[] serializedTableFunction;
private final PythonFunctionKind pythonFunctionKind;
private final boolean deterministic;
private final PythonEnv pythonEnv;
private final boolean takesRowAsInput;
+ private DataType[] inputTypes;
+ private String[] inputTypesString;
+ private DataType resultType;
+ private String resultTypeString;
+
public PythonTableFunction(
String name,
byte[] serializedScalarFunction,
@@ -56,10 +60,46 @@ public class PythonTableFunction extends TableFunction<Row> implements PythonFun
boolean deterministic,
boolean takesRowAsInput,
PythonEnv pythonEnv) {
- this.name = name;
- this.serializedScalarFunction = serializedScalarFunction;
+ this(
+ name,
+ serializedScalarFunction,
+ pythonFunctionKind,
+ deterministic,
+ takesRowAsInput,
+ pythonEnv);
this.inputTypes = inputTypes;
this.resultType = resultType;
+ }
+
+ public PythonTableFunction(
+ String name,
+ byte[] serializedScalarFunction,
+ String[] inputTypesString,
+ String resultTypeString,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ boolean takesRowAsInput,
+ PythonEnv pythonEnv) {
+ this(
+ name,
+ serializedScalarFunction,
+ pythonFunctionKind,
+ deterministic,
+ takesRowAsInput,
+ pythonEnv);
+ this.inputTypesString = inputTypesString;
+ this.resultTypeString = resultTypeString;
+ }
+
+ public PythonTableFunction(
+ String name,
+ byte[] serializedScalarFunction,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ boolean takesRowAsInput,
+ PythonEnv pythonEnv) {
+ this.name = name;
+ this.serializedTableFunction = serializedScalarFunction;
this.pythonFunctionKind = pythonFunctionKind;
this.deterministic = deterministic;
this.pythonEnv = pythonEnv;
@@ -73,7 +113,7 @@ public class PythonTableFunction extends TableFunction<Row> implements PythonFun
@Override
public byte[] getSerializedPythonFunction() {
- return serializedScalarFunction;
+ return serializedTableFunction;
}
@Override
@@ -107,17 +147,35 @@ public class PythonTableFunction extends TableFunction<Row> implements PythonFun
@Override
public TypeInformation<Row> getResultType() {
+ if (resultType == null && resultTypeString != null) {
+ throw new RuntimeException(
+ "String format result type is not supported in old type system. The `register_function` is deprecated, please Use `create_temporary_system_function` instead.");
+ }
return (TypeInformation<Row>) TypeConversions.fromDataTypeToLegacyInfo(resultType);
}
@Override
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
TypeInference.Builder builder = TypeInference.newBuilder();
+
+ if (inputTypesString != null) {
+ inputTypes =
+ (DataType[])
+ Arrays.stream(inputTypesString)
+ .map(typeFactory::createDataType)
+ .toArray(DataType[]::new);
+ }
+
if (inputTypes != null) {
final List<DataType> argumentDataTypes =
Stream.of(inputTypes).collect(Collectors.toList());
builder.typedArguments(argumentDataTypes);
}
+
+ if (resultType == null) {
+ resultType = typeFactory.createDataType(resultTypeString);
+ }
+
return builder.outputTypeStrategy(TypeStrategies.explicit(resultType)).build();
}