You are viewing a plain-text version of this content. The canonical link to the original message was not preserved in this plain-text copy.
Posted to commits@flink.apache.org by di...@apache.org on 2020/07/17 12:15:08 UTC
[flink] branch master updated: [FLINK-18463][python] Make the
"input_types" parameter of the Python UDF/UDTF decorator optional.
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new ef46460 [FLINK-18463][python] Make the "input_types" parameter of the Python UDF/UDTF decorator optional.
ef46460 is described below
commit ef46460312b34284428398266db6f4cedf03052e
Author: Wei Zhong <we...@gmail.com>
AuthorDate: Fri Jul 17 16:58:39 2020 +0800
[FLINK-18463][python] Make the "input_types" parameter of the Python UDF/UDTF decorator optional.
This closes #12921.
---
docs/dev/table/python/python_udfs.md | 21 ++--
docs/dev/table/python/python_udfs.zh.md | 21 ++--
docs/dev/table/python/vectorized_python_udfs.md | 2 +-
docs/dev/table/python/vectorized_python_udfs.zh.md | 2 +-
flink-python/pyflink/table/table_environment.py | 21 ++--
.../pyflink/table/tests/test_dependency.py | 23 ++---
.../pyflink/table/tests/test_pandas_udf.py | 59 +++++------
flink-python/pyflink/table/tests/test_udf.py | 113 +++++++++------------
flink-python/pyflink/table/tests/test_udtf.py | 10 +-
flink-python/pyflink/table/udf.py | 41 ++++----
.../functions/python/PythonScalarFunction.java | 6 +-
.../functions/python/PythonTableFunction.java | 6 +-
12 files changed, 150 insertions(+), 175 deletions(-)
diff --git a/docs/dev/table/python/python_udfs.md b/docs/dev/table/python/python_udfs.md
index 98ce5f0..02ed376 100644
--- a/docs/dev/table/python/python_udfs.md
+++ b/docs/dev/table/python/python_udfs.md
@@ -52,7 +52,7 @@ table_env = BatchTableEnvironment.create(env)
table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m')
# register the Python function
-table_env.register_function("hash_code", udf(HashCode(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+table_env.register_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT()))
# use the Python function in Python Table API
my_table.select("string, bigint, bigint.hash_code(), hash_code(bigint)")
@@ -110,29 +110,28 @@ class Add(ScalarFunction):
def eval(self, i, j):
return i + j
-add = udf(Add(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
+add = udf(Add(), result_type=DataTypes.BIGINT())
# option 2: Python function
-@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT())
+@udf(result_type=DataTypes.BIGINT())
def add(i, j):
return i + j
# option 3: lambda function
-add = udf(lambda i, j: i + j, [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
+add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT())
# option 4: callable function
class CallableAdd(object):
def __call__(self, i, j):
return i + j
-add = udf(CallableAdd(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
+add = udf(CallableAdd(), result_type=DataTypes.BIGINT())
# option 5: partial function
def partial_add(i, j, k):
return i + j + k
-add = udf(functools.partial(partial_add, k=1), [DataTypes.BIGINT(), DataTypes.BIGINT()],
- DataTypes.BIGINT())
+add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT())
# register the Python function
table_env.register_function("add", add)
@@ -163,7 +162,7 @@ my_table = ... # type: Table, table schema: [a: String]
table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m')
# register the Python Table Function
-table_env.register_function("split", udtf(Split(), DataTypes.STRING(), [DataTypes.STRING(), DataTypes.INT()]))
+table_env.register_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()]))
# use the Python Table Function in Python Table API
my_table.join_lateral("split(a) as (word, length)")
@@ -231,18 +230,18 @@ Like Python scalar functions, you can use the above five ways to define Python T
{% highlight python %}
# option 1: generator function
-@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
+@udtf(result_types=DataTypes.BIGINT())
def generator_func(x):
yield 1
yield 2
# option 2: return iterator
-@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
+@udtf(result_types=DataTypes.BIGINT())
def iterator_func(x):
return range(5)
# option 3: return iterable
-@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
+@udtf(result_types=DataTypes.BIGINT())
def iterable_func(x):
result = [1, 2, 3]
return result
diff --git a/docs/dev/table/python/python_udfs.zh.md b/docs/dev/table/python/python_udfs.zh.md
index 1d83da5..c1618fb 100644
--- a/docs/dev/table/python/python_udfs.zh.md
+++ b/docs/dev/table/python/python_udfs.zh.md
@@ -52,7 +52,7 @@ table_env = BatchTableEnvironment.create(env)
table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m')
# register the Python function
-table_env.register_function("hash_code", udf(HashCode(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+table_env.register_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT()))
# use the Python function in Python Table API
my_table.select("string, bigint, bigint.hash_code(), hash_code(bigint)")
@@ -110,29 +110,28 @@ class Add(ScalarFunction):
def eval(self, i, j):
return i + j
-add = udf(Add(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
+add = udf(Add(), result_type=DataTypes.BIGINT())
# option 2: Python function
-@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT())
+@udf(result_type=DataTypes.BIGINT())
def add(i, j):
return i + j
# option 3: lambda function
-add = udf(lambda i, j: i + j, [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
+add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT())
# option 4: callable function
class CallableAdd(object):
def __call__(self, i, j):
return i + j
-add = udf(CallableAdd(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())
+add = udf(CallableAdd(), result_type=DataTypes.BIGINT())
# option 5: partial function
def partial_add(i, j, k):
return i + j + k
-add = udf(functools.partial(partial_add, k=1), [DataTypes.BIGINT(), DataTypes.BIGINT()],
- DataTypes.BIGINT())
+add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT())
# register the Python function
table_env.register_function("add", add)
@@ -163,7 +162,7 @@ my_table = ... # type: Table, table schema: [a: String]
table_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", '80m')
# register the Python Table Function
-table_env.register_function("split", udtf(Split(), DataTypes.STRING(), [DataTypes.STRING(), DataTypes.INT()]))
+table_env.register_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()]))
# use the Python Table Function in Python Table API
my_table.join_lateral("split(a) as (word, length)")
@@ -231,18 +230,18 @@ Like Python scalar functions, you can use the above five ways to define Python T
{% highlight python %}
# option 1: generator function
-@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
+@udtf(result_types=DataTypes.BIGINT())
def generator_func(x):
yield 1
yield 2
# option 2: return iterator
-@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
+@udtf(result_types=DataTypes.BIGINT())
def iterator_func(x):
return range(5)
# option 3: return iterable
-@udtf(input_types=DataTypes.BIGINT(), result_types=DataTypes.BIGINT())
+@udtf(result_types=DataTypes.BIGINT())
def iterable_func(x):
result = [1, 2, 3]
return result
diff --git a/docs/dev/table/python/vectorized_python_udfs.md b/docs/dev/table/python/vectorized_python_udfs.md
index ee5f03f..d467438 100644
--- a/docs/dev/table/python/vectorized_python_udfs.md
+++ b/docs/dev/table/python/vectorized_python_udfs.md
@@ -48,7 +48,7 @@ The following example shows how to define your own vectorized Python scalar func
and use it in a query:
{% highlight python %}
-@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT(), udf_type="pandas")
+@udf(result_type=DataTypes.BIGINT(), udf_type="pandas")
def add(i, j):
return i + j
diff --git a/docs/dev/table/python/vectorized_python_udfs.zh.md b/docs/dev/table/python/vectorized_python_udfs.zh.md
index b2d2ed9..c6381c1 100644
--- a/docs/dev/table/python/vectorized_python_udfs.zh.md
+++ b/docs/dev/table/python/vectorized_python_udfs.zh.md
@@ -48,7 +48,7 @@ The following example shows how to define your own vectorized Python scalar func
and use it in a query:
{% highlight python %}
-@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT(), udf_type="pandas")
+@udf(result_type=DataTypes.BIGINT(), udf_type="pandas")
def add(i, j):
return i + j
diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py
index fbdead0..aa05b13 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -221,10 +221,9 @@ class TableEnvironment(object):
::
>>> table_env.create_temporary_system_function(
- ... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ ... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
- >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
- ... result_type=DataTypes.BIGINT())
+ >>> @udf(result_type=DataTypes.BIGINT())
... def add(i, j):
... return i + j
>>> table_env.create_temporary_system_function("add", add)
@@ -233,7 +232,7 @@ class TableEnvironment(object):
... def eval(self, i):
... return i - 1
>>> table_env.create_temporary_system_function(
- ... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ ... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
:param name: The name under which the function will be registered globally.
:param function: The function class containing the implementation. The function must have a
@@ -356,10 +355,9 @@ class TableEnvironment(object):
::
>>> table_env.create_temporary_function(
- ... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ ... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
- >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
- ... result_type=DataTypes.BIGINT())
+ >>> @udf(result_type=DataTypes.BIGINT())
... def add(i, j):
... return i + j
>>> table_env.create_temporary_function("add", add)
@@ -368,7 +366,7 @@ class TableEnvironment(object):
... def eval(self, i):
... return i - 1
>>> table_env.create_temporary_function(
- ... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ ... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
:param path: The path under which the function will be registered.
See also the :class:`~pyflink.table.TableEnvironment` class description for
@@ -1101,10 +1099,9 @@ class TableEnvironment(object):
::
>>> table_env.register_function(
- ... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ ... "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
- >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
- ... result_type=DataTypes.BIGINT())
+ >>> @udf(result_type=DataTypes.BIGINT())
... def add(i, j):
... return i + j
>>> table_env.register_function("add", add)
@@ -1113,7 +1110,7 @@ class TableEnvironment(object):
... def eval(self, i):
... return i - 1
>>> table_env.register_function(
- ... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ ... "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
:param name: The name under which the function is registered.
:type name: str
diff --git a/flink-python/pyflink/table/tests/test_dependency.py b/flink-python/pyflink/table/tests/test_dependency.py
index 4b63784..2e5abb0 100644
--- a/flink-python/pyflink/table/tests/test_dependency.py
+++ b/flink-python/pyflink/table/tests/test_dependency.py
@@ -45,8 +45,7 @@ class DependencyTests(object):
from test_dependency_manage_lib import add_two
return add_two(i)
- self.t_env.register_function("add_two", udf(plus_two, DataTypes.BIGINT(),
- DataTypes.BIGINT()))
+ self.t_env.register_function("add_two", udf(plus_two, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
@@ -77,8 +76,7 @@ class FlinkBatchDependencyTests(PyFlinkBatchTableTestCase):
from test_dependency_manage_lib import add_two
return add_two(i)
- self.t_env.register_function("add_two", udf(plus_two, DataTypes.BIGINT(),
- DataTypes.BIGINT()))
+ self.t_env.register_function("add_two", udf(plus_two, result_type=DataTypes.BIGINT()))
t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])\
.select("add_two(a), a")
@@ -107,8 +105,7 @@ class BlinkStreamDependencyTests(DependencyTests, PyFlinkBlinkStreamTableTestCas
return i
self.t_env.register_function("check_requirements",
- udf(check_requirements, DataTypes.BIGINT(),
- DataTypes.BIGINT()))
+ udf(check_requirements, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
@@ -153,9 +150,7 @@ class BlinkStreamDependencyTests(DependencyTests, PyFlinkBlinkStreamTableTestCas
from python_package1 import plus
return plus(i, 1)
- self.t_env.register_function("add_one",
- udf(add_one, DataTypes.BIGINT(),
- DataTypes.BIGINT()))
+ self.t_env.register_function("add_one", udf(add_one, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
@@ -181,8 +176,7 @@ class BlinkStreamDependencyTests(DependencyTests, PyFlinkBlinkStreamTableTestCas
return i + int(f.read())
self.t_env.register_function("add_from_file",
- udf(add_from_file, DataTypes.BIGINT(),
- DataTypes.BIGINT()))
+ udf(add_from_file, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
@@ -207,8 +201,7 @@ class BlinkStreamDependencyTests(DependencyTests, PyFlinkBlinkStreamTableTestCas
return i
self.t_env.register_function("check_python_exec",
- udf(check_python_exec, DataTypes.BIGINT(),
- DataTypes.BIGINT()))
+ udf(check_python_exec, result_type=DataTypes.BIGINT()))
def check_pyflink_gateway_disabled(i):
try:
@@ -222,8 +215,8 @@ class BlinkStreamDependencyTests(DependencyTests, PyFlinkBlinkStreamTableTestCas
return i
self.t_env.register_function("check_pyflink_gateway_disabled",
- udf(check_pyflink_gateway_disabled, DataTypes.BIGINT(),
- DataTypes.BIGINT()))
+ udf(check_pyflink_gateway_disabled,
+ result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
diff --git a/flink-python/pyflink/table/tests/test_pandas_udf.py b/flink-python/pyflink/table/tests/test_pandas_udf.py
index 03249ae..66941aa 100644
--- a/flink-python/pyflink/table/tests/test_pandas_udf.py
+++ b/flink-python/pyflink/table/tests/test_pandas_udf.py
@@ -34,7 +34,7 @@ class PandasUDFTests(unittest.TestCase):
def test_non_exist_udf_type(self):
with self.assertRaisesRegex(ValueError,
'The udf_type must be one of \'general, pandas\''):
- udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), udf_type="non-exist")
+ udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), udf_type="non-exist")
class PandasUDFITTests(object):
@@ -43,13 +43,13 @@ class PandasUDFITTests(object):
# pandas UDF
self.t_env.register_function(
"add_one",
- udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), udf_type="pandas"))
+ udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), udf_type="pandas"))
self.t_env.register_function("add", add)
# general Python UDF
self.t_env.register_function(
- "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd'],
@@ -173,77 +173,72 @@ class PandasUDFITTests(object):
self.t_env.register_function(
"tinyint_func",
- udf(tinyint_func, [DataTypes.TINYINT()], DataTypes.TINYINT(), udf_type="pandas"))
+ udf(tinyint_func, result_type=DataTypes.TINYINT(), udf_type="pandas"))
self.t_env.register_function(
"smallint_func",
- udf(smallint_func, [DataTypes.SMALLINT()], DataTypes.SMALLINT(), udf_type="pandas"))
+ udf(smallint_func, result_type=DataTypes.SMALLINT(), udf_type="pandas"))
self.t_env.register_function(
"int_func",
- udf(int_func, [DataTypes.INT()], DataTypes.INT(), udf_type="pandas"))
+ udf(int_func, result_type=DataTypes.INT(), udf_type="pandas"))
self.t_env.register_function(
"bigint_func",
- udf(bigint_func, [DataTypes.BIGINT()], DataTypes.BIGINT(), udf_type="pandas"))
+ udf(bigint_func, result_type=DataTypes.BIGINT(), udf_type="pandas"))
self.t_env.register_function(
"boolean_func",
- udf(boolean_func, [DataTypes.BOOLEAN()], DataTypes.BOOLEAN(), udf_type="pandas"))
+ udf(boolean_func, result_type=DataTypes.BOOLEAN(), udf_type="pandas"))
self.t_env.register_function(
"float_func",
- udf(float_func, [DataTypes.FLOAT()], DataTypes.FLOAT(), udf_type="pandas"))
+ udf(float_func, result_type=DataTypes.FLOAT(), udf_type="pandas"))
self.t_env.register_function(
"double_func",
- udf(double_func, [DataTypes.DOUBLE()], DataTypes.DOUBLE(), udf_type="pandas"))
+ udf(double_func, result_type=DataTypes.DOUBLE(), udf_type="pandas"))
self.t_env.register_function(
"varchar_func",
- udf(varchar_func, [DataTypes.STRING()], DataTypes.STRING(), udf_type="pandas"))
+ udf(varchar_func, result_type=DataTypes.STRING(), udf_type="pandas"))
self.t_env.register_function(
"varbinary_func",
- udf(varbinary_func, [DataTypes.BYTES()], DataTypes.BYTES(), udf_type="pandas"))
+ udf(varbinary_func, result_type=DataTypes.BYTES(), udf_type="pandas"))
self.t_env.register_function(
"decimal_func",
- udf(decimal_func, [DataTypes.DECIMAL(38, 18)], DataTypes.DECIMAL(38, 18),
- udf_type="pandas"))
+ udf(decimal_func, result_type=DataTypes.DECIMAL(38, 18), udf_type="pandas"))
self.t_env.register_function(
"date_func",
- udf(date_func, [DataTypes.DATE()], DataTypes.DATE(), udf_type="pandas"))
+ udf(date_func, result_type=DataTypes.DATE(), udf_type="pandas"))
self.t_env.register_function(
"time_func",
- udf(time_func, [DataTypes.TIME()], DataTypes.TIME(), udf_type="pandas"))
+ udf(time_func, result_type=DataTypes.TIME(), udf_type="pandas"))
self.t_env.register_function(
"timestamp_func",
- udf(timestamp_func, [DataTypes.TIMESTAMP(3)], DataTypes.TIMESTAMP(3),
- udf_type="pandas"))
+ udf(timestamp_func, result_type=DataTypes.TIMESTAMP(3), udf_type="pandas"))
self.t_env.register_function(
"array_str_func",
- udf(array_func, [DataTypes.ARRAY(DataTypes.STRING())],
- DataTypes.ARRAY(DataTypes.STRING()), udf_type="pandas"))
+ udf(array_func, result_type=DataTypes.ARRAY(DataTypes.STRING()), udf_type="pandas"))
self.t_env.register_function(
"array_timestamp_func",
- udf(array_func, [DataTypes.ARRAY(DataTypes.TIMESTAMP(3))],
- DataTypes.ARRAY(DataTypes.TIMESTAMP(3)), udf_type="pandas"))
+ udf(array_func, result_type=DataTypes.ARRAY(DataTypes.TIMESTAMP(3)), udf_type="pandas"))
self.t_env.register_function(
"array_int_func",
- udf(array_func, [DataTypes.ARRAY(DataTypes.INT())],
- DataTypes.ARRAY(DataTypes.INT()), udf_type="pandas"))
+ udf(array_func, result_type=DataTypes.ARRAY(DataTypes.INT()), udf_type="pandas"))
self.t_env.register_function(
"nested_array_func",
- udf(nested_array_func, [DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))],
- DataTypes.ARRAY(DataTypes.STRING()), udf_type="pandas"))
+ udf(nested_array_func,
+ result_type=DataTypes.ARRAY(DataTypes.STRING()), udf_type="pandas"))
row_type = DataTypes.ROW(
[DataTypes.FIELD("f1", DataTypes.INT()),
@@ -252,7 +247,7 @@ class PandasUDFITTests(object):
DataTypes.FIELD("f4", DataTypes.ARRAY(DataTypes.INT()))])
self.t_env.register_function(
"row_func",
- udf(row_func, [row_type], row_type, udf_type="pandas"))
+ udf(row_func, result_type=row_type, udf_type="pandas"))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
@@ -350,8 +345,7 @@ class BlinkPandasUDFITTests(object):
self.t_env.register_function(
"local_zoned_timestamp_func",
udf(local_zoned_timestamp_func,
- [DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)],
- DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3),
+ result_type=DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3),
udf_type="pandas"))
table_sink = source_sink_utils.TestAppendSink(
@@ -379,13 +373,13 @@ class BatchPandasUDFITTests(PyFlinkBatchTableTestCase):
def test_basic_functionality(self):
self.t_env.register_function(
"add_one",
- udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), udf_type="pandas"))
+ udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), udf_type="pandas"))
self.t_env.register_function("add", add)
# general Python UDF
self.t_env.register_function(
- "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
t = t.where("add_one(b) <= 3") \
@@ -406,8 +400,7 @@ class BlinkStreamPandasUDFITTests(PandasUDFITTests,
pass
-@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT(),
- udf_type='pandas')
+@udf(result_type=DataTypes.BIGINT(), udf_type='pandas')
def add(i, j):
return i + j
diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
index fbe0bd8..c9ca376 100644
--- a/flink-python/pyflink/table/tests/test_udf.py
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -35,18 +35,18 @@ class UserDefinedFunctionTests(object):
self.t_env.get_config().get_configuration().set_string('python.metric.enabled', 'false')
# test lambda function
self.t_env.register_function(
- "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
# test Python ScalarFunction
self.t_env.register_function(
- "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
# test Python function
self.t_env.register_function("add", add)
# test callable function
self.t_env.register_function(
- "add_one_callable", udf(CallablePlus(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "add_one_callable", udf(CallablePlus(), result_type=DataTypes.BIGINT()))
def partial_func(col, param):
return col + param
@@ -55,7 +55,7 @@ class UserDefinedFunctionTests(object):
import functools
self.t_env.register_function(
"add_one_partial",
- udf(functools.partial(partial_func, param=1), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ udf(functools.partial(partial_func, param=1), result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd', 'e', 'f'],
@@ -74,9 +74,9 @@ class UserDefinedFunctionTests(object):
def test_chaining_scalar_function(self):
self.t_env.register_function(
- "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
self.t_env.register_function(
- "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
self.t_env.register_function("add", add)
table_sink = source_sink_utils.TestAppendSink(
@@ -95,7 +95,7 @@ class UserDefinedFunctionTests(object):
t1 = self.t_env.from_elements([(2, "Hi")], ['a', 'b'])
t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd'])
- self.t_env.register_function("f", udf(lambda i: i, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ self.t_env.register_function("f", udf(lambda i: i, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd'],
@@ -111,7 +111,7 @@ class UserDefinedFunctionTests(object):
t1 = self.t_env.from_elements([(1, "Hi"), (2, "Hi")], ['a', 'b'])
t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd'])
- self.t_env.register_function("f", udf(lambda i: i, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ self.t_env.register_function("f", udf(lambda i: i, result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd'],
@@ -177,26 +177,11 @@ class UserDefinedFunctionTests(object):
self.t_env.register_function("udf_with_constant_params",
udf(udf_with_constant_params,
- input_types=[DataTypes.BIGINT(),
- DataTypes.BIGINT(),
- DataTypes.TINYINT(),
- DataTypes.SMALLINT(),
- DataTypes.INT(),
- DataTypes.BIGINT(),
- DataTypes.DECIMAL(38, 18),
- DataTypes.FLOAT(),
- DataTypes.DOUBLE(),
- DataTypes.BOOLEAN(),
- DataTypes.STRING(),
- DataTypes.DATE(),
- DataTypes.TIME(),
- DataTypes.TIMESTAMP(3)],
result_type=DataTypes.BIGINT()))
self.t_env.register_function(
"udf_with_all_constant_params", udf(lambda i, j: i + j,
- [DataTypes.BIGINT(), DataTypes.BIGINT()],
- DataTypes.BIGINT()))
+ result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(['a', 'b'],
[DataTypes.BIGINT(), DataTypes.BIGINT()])
@@ -229,7 +214,7 @@ class UserDefinedFunctionTests(object):
def test_overwrite_builtin_function(self):
self.t_env.register_function(
"plus", udf(lambda i, j: i + j - 1,
- [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT()))
+ result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(['a'], [DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
@@ -243,7 +228,7 @@ class UserDefinedFunctionTests(object):
def test_open(self):
self.t_env.get_config().get_configuration().set_string('python.metric.enabled', 'true')
self.t_env.register_function(
- "subtract", udf(Subtract(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "subtract", udf(Subtract(), result_type=DataTypes.BIGINT()))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_table_sink("Results", table_sink)
@@ -256,9 +241,9 @@ class UserDefinedFunctionTests(object):
def test_udf_without_arguments(self):
self.t_env.register_function("one", udf(
- lambda: 1, input_types=[], result_type=DataTypes.BIGINT(), deterministic=True))
+ lambda: 1, result_type=DataTypes.BIGINT(), deterministic=True))
self.t_env.register_function("two", udf(
- lambda: 2, input_types=[], result_type=DataTypes.BIGINT(), deterministic=False))
+ lambda: 2, result_type=DataTypes.BIGINT(), deterministic=False))
table_sink = source_sink_utils.TestAppendSink(['a', 'b'],
[DataTypes.BIGINT(), DataTypes.BIGINT()])
@@ -363,59 +348,56 @@ class UserDefinedFunctionTests(object):
return decimal_param
self.t_env.register_function(
- "boolean_func", udf(boolean_func, [DataTypes.BOOLEAN()], DataTypes.BOOLEAN()))
+ "boolean_func", udf(boolean_func, result_type=DataTypes.BOOLEAN()))
self.t_env.register_function(
- "tinyint_func", udf(tinyint_func, [DataTypes.TINYINT()], DataTypes.TINYINT()))
+ "tinyint_func", udf(tinyint_func, result_type=DataTypes.TINYINT()))
self.t_env.register_function(
- "smallint_func", udf(smallint_func, [DataTypes.SMALLINT()], DataTypes.SMALLINT()))
+ "smallint_func", udf(smallint_func, result_type=DataTypes.SMALLINT()))
self.t_env.register_function(
- "int_func", udf(int_func, [DataTypes.INT()], DataTypes.INT()))
+ "int_func", udf(int_func, result_type=DataTypes.INT()))
self.t_env.register_function(
- "bigint_func", udf(bigint_func, [DataTypes.BIGINT()], DataTypes.BIGINT()))
+ "bigint_func", udf(bigint_func, result_type=DataTypes.BIGINT()))
self.t_env.register_function(
- "bigint_func_none", udf(bigint_func_none, [DataTypes.BIGINT()], DataTypes.BIGINT()))
+ "bigint_func_none", udf(bigint_func_none, result_type=DataTypes.BIGINT()))
self.t_env.register_function(
- "float_func", udf(float_func, [DataTypes.FLOAT()], DataTypes.FLOAT()))
+ "float_func", udf(float_func, result_type=DataTypes.FLOAT()))
self.t_env.register_function(
- "double_func", udf(double_func, [DataTypes.DOUBLE()], DataTypes.DOUBLE()))
+ "double_func", udf(double_func, result_type=DataTypes.DOUBLE()))
self.t_env.register_function(
- "bytes_func", udf(bytes_func, [DataTypes.BYTES()], DataTypes.BYTES()))
+ "bytes_func", udf(bytes_func, result_type=DataTypes.BYTES()))
self.t_env.register_function(
- "str_func", udf(str_func, [DataTypes.STRING()], DataTypes.STRING()))
+ "str_func", udf(str_func, result_type=DataTypes.STRING()))
self.t_env.register_function(
- "date_func", udf(date_func, [DataTypes.DATE()], DataTypes.DATE()))
+ "date_func", udf(date_func, result_type=DataTypes.DATE()))
self.t_env.register_function(
- "time_func", udf(time_func, [DataTypes.TIME()], DataTypes.TIME()))
+ "time_func", udf(time_func, result_type=DataTypes.TIME()))
self.t_env.register_function(
- "timestamp_func", udf(timestamp_func, [DataTypes.TIMESTAMP(3)], DataTypes.TIMESTAMP(3)))
+ "timestamp_func", udf(timestamp_func, result_type=DataTypes.TIMESTAMP(3)))
self.t_env.register_function(
- "array_func", udf(array_func, [DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT()))],
- DataTypes.ARRAY(DataTypes.BIGINT())))
+ "array_func", udf(array_func, result_type=DataTypes.ARRAY(DataTypes.BIGINT())))
self.t_env.register_function(
- "map_func", udf(map_func, [DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())],
- DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())))
+ "map_func", udf(map_func,
+ result_type=DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING())))
self.t_env.register_function(
- "decimal_func", udf(decimal_func, [DataTypes.DECIMAL(38, 18)],
- DataTypes.DECIMAL(38, 18)))
+ "decimal_func", udf(decimal_func, result_type=DataTypes.DECIMAL(38, 18)))
self.t_env.register_function(
- "decimal_cut_func", udf(decimal_cut_func, [DataTypes.DECIMAL(38, 18)],
- DataTypes.DECIMAL(38, 18)))
+ "decimal_cut_func", udf(decimal_cut_func, result_type=DataTypes.DECIMAL(38, 18)))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q'],
@@ -480,9 +462,9 @@ class UserDefinedFunctionTests(object):
t_env = self.t_env
t_env.create_temporary_system_function(
- "add_one_func", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "add_one_func", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
t_env.create_temporary_function(
- "subtract_one_func", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "subtract_one_func", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
self.assert_equals(t_env.list_user_defined_functions(),
['add_one_func', 'subtract_one_func'])
@@ -505,9 +487,9 @@ class PyFlinkBatchUserDefinedFunctionTests(PyFlinkBatchTableTestCase):
def test_chaining_scalar_function(self):
self.t_env.register_function(
- "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "add_one", udf(lambda i: i + 1, result_type=DataTypes.BIGINT()))
self.t_env.register_function(
- "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ "subtract_one", udf(SubtractOne(), result_type=DataTypes.BIGINT()))
self.t_env.register_function("add", add)
t = self.t_env.from_elements([(1, 2, 1), (2, 5, 2), (3, 1, 3)], ['a', 'b', 'c'])\
@@ -520,43 +502,42 @@ class PyFlinkBatchUserDefinedFunctionTests(PyFlinkBatchTableTestCase):
class PyFlinkBlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
PyFlinkBlinkStreamTableTestCase):
def test_deterministic(self):
- add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
+ add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
self.assertTrue(add_one._deterministic)
- add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), deterministic=False)
+ add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), deterministic=False)
self.assertFalse(add_one._deterministic)
- subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())
+ subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT())
self.assertTrue(subtract_one._deterministic)
with self.assertRaises(ValueError, msg="Inconsistent deterministic: False and True"):
- udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT(), deterministic=False)
+ udf(SubtractOne(), result_type=DataTypes.BIGINT(), deterministic=False)
self.assertTrue(add._deterministic)
- @udf(input_types=DataTypes.BIGINT(), result_type=DataTypes.BIGINT(), deterministic=False)
+ @udf(result_type=DataTypes.BIGINT(), deterministic=False)
def non_deterministic_udf(i):
return i
self.assertFalse(non_deterministic_udf._deterministic)
def test_name(self):
- add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
+ add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
self.assertEqual("<lambda>", add_one._name)
- add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), name="add_one")
+ add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT(), name="add_one")
self.assertEqual("add_one", add_one._name)
- subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())
+ subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT())
self.assertEqual("SubtractOne", subtract_one._name)
- subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT(),
- name="subtract_one")
+ subtract_one = udf(SubtractOne(), result_type=DataTypes.BIGINT(), name="subtract_one")
self.assertEqual("subtract_one", subtract_one._name)
self.assertEqual("add", add._name)
- @udf(input_types=DataTypes.BIGINT(), result_type=DataTypes.BIGINT(), name="named")
+ @udf(result_type=DataTypes.BIGINT(), name="named")
def named_udf(i):
return i
@@ -597,8 +578,7 @@ class PyFlinkBlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
self.t_env.register_function(
"local_zoned_timestamp_func",
udf(local_zoned_timestamp_func,
- [DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)],
- DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)))
+ result_type=DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)))
table_sink = source_sink_utils.TestAppendSink(
['a'], [DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)])
@@ -620,6 +600,7 @@ class PyFlinkBlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests,
pass
+# test specifying the input_types explicitly
@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT())
def add(i, j):
return i + j
diff --git a/flink-python/pyflink/table/tests/test_udtf.py b/flink-python/pyflink/table/tests/test_udtf.py
index 76223a6..829b891 100644
--- a/flink-python/pyflink/table/tests/test_udtf.py
+++ b/flink-python/pyflink/table/tests/test_udtf.py
@@ -31,14 +31,12 @@ class UserDefinedTableFunctionTests(object):
[DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_function(
- "multi_emit", udtf(MultiEmit(), [DataTypes.BIGINT(), DataTypes.BIGINT()],
- [DataTypes.BIGINT(), DataTypes.BIGINT()]))
+ "multi_emit", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]))
self.t_env.register_function("condition_multi_emit", condition_multi_emit)
self.t_env.register_function(
- "multi_num", udf(MultiNum(), [DataTypes.BIGINT()],
- DataTypes.BIGINT()))
+ "multi_num", udf(MultiNum(), result_type=DataTypes.BIGINT()))
t = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
t = t.join_lateral("multi_emit(a, multi_num(b)) as (x, y)") \
@@ -55,8 +53,7 @@ class UserDefinedTableFunctionTests(object):
[DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT()])
self.t_env.register_function(
- "multi_emit", udtf(MultiEmit(), [DataTypes.BIGINT(), DataTypes.BIGINT()],
- [DataTypes.BIGINT(), DataTypes.BIGINT()]))
+ "multi_emit", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]))
t = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
self.t_env.register_table("MyTable", t)
@@ -115,6 +112,7 @@ class MultiEmit(TableFunction, unittest.TestCase):
yield x, i
+# test specifying the input_types explicitly
@udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
result_types=DataTypes.BIGINT())
def condition_multi_emit(x, y):
diff --git a/flink-python/pyflink/table/udf.py b/flink-python/pyflink/table/udf.py
index b6b3e7e..3ad002f 100644
--- a/flink-python/pyflink/table/udf.py
+++ b/flink-python/pyflink/table/udf.py
@@ -160,14 +160,15 @@ class UserDefinedFunctionWrapper(object):
"Invalid function: not a function or callable (__call__ is not defined): {0}"
.format(type(func)))
- if not isinstance(input_types, collections.Iterable):
- input_types = [input_types]
+ if input_types is not None:
+ if not isinstance(input_types, collections.Iterable):
+ input_types = [input_types]
- for input_type in input_types:
- if not isinstance(input_type, DataType):
- raise TypeError(
- "Invalid input_type: input_type should be DataType but contains {}".format(
- input_type))
+ for input_type in input_types:
+ if not isinstance(input_type, DataType):
+ raise TypeError(
+ "Invalid input_type: input_type should be DataType but contains {}".format(
+ input_type))
self._func = func
self._input_types = input_types
@@ -228,8 +229,11 @@ class UserDefinedScalarFunctionWrapper(UserDefinedFunctionWrapper):
import cloudpickle
serialized_func = cloudpickle.dumps(func)
- j_input_types = utils.to_jarray(gateway.jvm.TypeInformation,
- [_to_java_type(i) for i in self._input_types])
+ if self._input_types is not None:
+ j_input_types = utils.to_jarray(
+ gateway.jvm.TypeInformation, [_to_java_type(i) for i in self._input_types])
+ else:
+ j_input_types = None
j_result_type = _to_java_type(self._result_type)
j_function_kind = get_python_function_kind(self._udf_type)
PythonScalarFunction = gateway.jvm \
@@ -280,8 +284,11 @@ class UserDefinedTableFunctionWrapper(UserDefinedFunctionWrapper):
serialized_func = cloudpickle.dumps(func)
gateway = get_gateway()
- j_input_types = utils.to_jarray(gateway.jvm.TypeInformation,
- [_to_java_type(i) for i in self._input_types])
+ if self._input_types is not None:
+ j_input_types = utils.to_jarray(
+ gateway.jvm.TypeInformation, [_to_java_type(i) for i in self._input_types])
+ else:
+ j_input_types = None
j_result_types = utils.to_jarray(gateway.jvm.TypeInformation,
[_to_java_type(i) for i in self._result_types])
@@ -327,8 +334,8 @@ def udf(f=None, input_types=None, result_type=None, deterministic=None, name=Non
>>> add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
- >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
- ... result_type=DataTypes.BIGINT())
+    >>> # The input_types parameter is optional.
+ >>> @udf(result_type=DataTypes.BIGINT())
... def add(i, j):
... return i + j
@@ -339,7 +346,7 @@ def udf(f=None, input_types=None, result_type=None, deterministic=None, name=Non
:param f: lambda function or user-defined function.
:type f: function or UserDefinedFunction or type
- :param input_types: the input data types.
+ :param input_types: optional, the input data types.
:type input_types: list[DataType] or DataType
:param result_type: the result data type.
:type result_type: DataType
@@ -375,8 +382,8 @@ def udtf(f=None, input_types=None, result_types=None, deterministic=None, name=N
Example:
::
- >>> @udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
- ... result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
+    >>> # The input_types parameter is optional.
+ >>> @udtf(result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
... def range_emit(s, e):
... for i in range(e):
... yield s, i
@@ -388,7 +395,7 @@ def udtf(f=None, input_types=None, result_types=None, deterministic=None, name=N
:param f: user-defined table function.
:type f: function or UserDefinedFunction or type
- :param input_types: the input data types.
+ :param input_types: optional, the input data types.
:type input_types: list[DataType] or DataType
:param result_types: the result data types.
:type result_types: list[DataType] or DataType
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java
index 72df5e3..d00eb51 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonScalarFunction.java
@@ -91,7 +91,11 @@ public class PythonScalarFunction extends ScalarFunction implements PythonFuncti
@Override
public TypeInformation[] getParameterTypes(Class[] signature) {
- return inputTypes;
+ if (inputTypes != null) {
+ return inputTypes;
+ } else {
+ return super.getParameterTypes(signature);
+ }
}
@Override
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java
index 3560170..2ed57af 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableFunction.java
@@ -93,7 +93,11 @@ public class PythonTableFunction extends TableFunction<Row> implements PythonFun
@Override
public TypeInformation[] getParameterTypes(Class[] signature) {
- return inputTypes;
+ if (inputTypes != null) {
+ return inputTypes;
+ } else {
+ return super.getParameterTypes(signature);
+ }
}
@Override