Posted to commits@flink.apache.org by di...@apache.org on 2023/01/17 15:08:09 UTC
[flink] branch release-1.15 updated: [FLINK-28526][python] Fix Python UDF to support time indicator inputs
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch release-1.15
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.15 by this push:
new 4c7001c0dcf [FLINK-28526][python] Fix Python UDF to support time indicator inputs
4c7001c0dcf is described below
commit 4c7001c0dcff400ff2824bf3e75df184c686f2cd
Author: Dian Fu <di...@apache.org>
AuthorDate: Mon Jan 16 14:05:59 2023 +0800
[FLINK-28526][python] Fix Python UDF to support time indicator inputs
This closes #21686.
---
flink-python/pyflink/table/tests/test_udf.py | 43 ++++++++++++++++++++++
.../plan/nodes/exec/utils/CommonPythonUtil.java | 33 ++++++++++++-----
2 files changed, 66 insertions(+), 10 deletions(-)
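
For context, a "time indicator input" is a rowtime or proctime attribute that reaches a Python UDF, for example when a row that carries an event-time column is passed to Table.map. The following is a minimal sketch of that scenario, condensed from the test added below; the column names, the keep_b function and the print()-based output are illustrative only and not part of the commit:

from pyflink.common import Row
from pyflink.common.typeinfo import Types
from pyflink.common.watermark_strategy import WatermarkStrategy, TimestampAssigner
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import DataTypes, Schema, StreamTableEnvironment
from pyflink.table.udf import udf


class FirstFieldTimestampAssigner(TimestampAssigner):
    """Treats the first field of each record as its event timestamp (epoch millis)."""

    def extract_timestamp(self, value, record_timestamp) -> int:
        return int(value[0])


env = StreamExecutionEnvironment.get_execution_environment()
t_env = StreamTableEnvironment.create(env)

ds = env.from_collection(
    [(1, 42, "a"), (2, 5, "a")],
    Types.ROW_NAMED(["a", "b", "c"], [Types.LONG(), Types.INT(), Types.STRING()]))
ds = ds.assign_timestamps_and_watermarks(
    WatermarkStrategy.for_monotonous_timestamps()
    .with_timestamp_assigner(FirstFieldTimestampAssigner()))

# "rowtime" becomes an event-time attribute (a time indicator) of the table.
table = t_env.from_data_stream(
    ds,
    Schema.new_builder()
    .column_by_metadata("rowtime", "TIMESTAMP_LTZ(3)")
    .watermark("rowtime", "SOURCE_WATERMARK()")
    .build())


@udf(result_type=DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.INT())]))
def keep_b(input_row):
    # The whole row, including the rowtime attribute, is the UDF input.
    return Row(input_row.b)


table.map(keep_b).execute().print()

Before this fix, translating such a plan failed because the planner wraps the time indicator column in a CAST that the Python operator translation did not recognize; the planner-side change at the end of this diff addresses exactly that.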
diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
index 135f1d3ddc7..509989a0313 100644
--- a/flink-python/pyflink/table/tests/test_udf.py
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -23,6 +23,7 @@ import unittest
import pytest
import pytz
+from pyflink.common import Row
from pyflink.table import DataTypes, expressions as expr
from pyflink.table.udf import ScalarFunction, udf
from pyflink.testing import source_sink_utils
@@ -788,6 +789,48 @@ class PyFlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
         lines.sort()
         self.assertEqual(lines, ['1,2', '2,3', '3,4'])
 
+    def test_udf_with_rowtime_arguments(self):
+        from pyflink.common import WatermarkStrategy
+        from pyflink.common.typeinfo import Types
+        from pyflink.common.watermark_strategy import TimestampAssigner
+        from pyflink.datastream import StreamExecutionEnvironment
+        from pyflink.table import Schema, StreamTableEnvironment
+
+        class MyTimestampAssigner(TimestampAssigner):
+
+            def extract_timestamp(self, value, record_timestamp) -> int:
+                return int(value[0])
+
+        env = StreamExecutionEnvironment.get_execution_environment()
+        t_env = StreamTableEnvironment.create(env)
+
+        ds = env.from_collection(
+            [(1, 42, "a"), (2, 5, "a"), (3, 1000, "c"), (100, 1000, "c")],
+            Types.ROW_NAMED(["a", "b", "c"], [Types.LONG(), Types.INT(), Types.STRING()]))
+
+        ds = ds.assign_timestamps_and_watermarks(
+            WatermarkStrategy.for_monotonous_timestamps()
+            .with_timestamp_assigner(MyTimestampAssigner()))
+
+        table = t_env.from_data_stream(
+            ds,
+            Schema.new_builder()
+            .column_by_metadata("rowtime", "TIMESTAMP_LTZ(3)")
+            .watermark("rowtime", "SOURCE_WATERMARK()")
+            .build())
+
+        @udf(result_type=DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.INT())]))
+        def inc(input_row):
+            return Row(input_row.b)
+
+        table_sink = source_sink_utils.TestAppendSink(
+            ['a'], [DataTypes.INT()])
+        t_env.register_table_sink("Results", table_sink)
+        table.map(inc).execute_insert("Results").wait()
+
+        actual = source_sink_utils.results()
+        self.assert_equals(actual, ['+I[42]', '+I[5]', '+I[1000]', '+I[1000]'])
+
 
 class PyFlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests,
                                            PyFlinkBatchTableTestCase):
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
index a949ad2afa6..c26a1ad1dc6 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
@@ -47,6 +47,7 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.functions.utils.AggSqlFunction;
import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction;
import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
+import org.apache.flink.table.planner.plan.schema.TimeIndicatorRelDataType;
import org.apache.flink.table.planner.plan.utils.AggregateInfo;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment;
@@ -70,10 +71,12 @@ import org.apache.flink.table.types.logical.StructuredType;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.type.SqlTypeName;
import java.lang.reflect.Field;
@@ -425,20 +428,30 @@ public class CommonPythonUtil {
         for (RexNode operand : pythonRexCall.getOperands()) {
             if (operand instanceof RexCall) {
                 RexCall childPythonRexCall = (RexCall) operand;
-                PythonFunctionInfo argPythonInfo =
-                        createPythonFunctionInfo(childPythonRexCall, inputNodes);
-                inputs.add(argPythonInfo);
+                if (childPythonRexCall.getOperator() instanceof SqlCastFunction
+                        && childPythonRexCall.getOperands().get(0) instanceof RexInputRef
+                        && childPythonRexCall.getOperands().get(0).getType()
+                                instanceof TimeIndicatorRelDataType) {
+                    operand = childPythonRexCall.getOperands().get(0);
+                } else {
+                    PythonFunctionInfo argPythonInfo =
+                            createPythonFunctionInfo(childPythonRexCall, inputNodes);
+                    inputs.add(argPythonInfo);
+                    continue;
+                }
             } else if (operand instanceof RexLiteral) {
                 RexLiteral literal = (RexLiteral) operand;
                 inputs.add(convertLiteralToPython(literal, literal.getType().getSqlTypeName()));
+                continue;
+            }
+
+            assert operand instanceof RexInputRef;
+            if (inputNodes.containsKey(operand)) {
+                inputs.add(inputNodes.get(operand));
             } else {
-                if (inputNodes.containsKey(operand)) {
-                    inputs.add(inputNodes.get(operand));
-                } else {
-                    Integer inputOffset = inputNodes.size();
-                    inputs.add(inputOffset);
-                    inputNodes.put(operand, inputOffset);
-                }
+                Integer inputOffset = inputNodes.size();
+                inputs.add(inputOffset);
+                inputNodes.put(operand, inputOffset);
             }
         }
         return new PythonFunctionInfo((PythonFunction) functionDefinition, inputs.toArray());
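
For reference, the planner change above amounts to the following: createPythonFunctionInfo used to treat every RexCall operand of a Python function call as a nested Python function and recursed into it. When a time indicator column is fed into a Python UDF, the planner materializes it as a CAST around the input reference (the case the new test exercises), so that recursion broke. The new branch detects a CAST whose single operand is a RexInputRef of TimeIndicatorRelDataType, unwraps it to the underlying input reference, and registers it in inputNodes like any other column reference; genuine nested Python calls and literals keep their previous handling via the added continue statements.

Another way to hit the same code path, sketched here purely for illustration and not part of the commit, is to pass the time attribute to a Python UDF explicitly. The snippet assumes the table variable with the rowtime attribute from the sketch near the top of this mail:

from pyflink.table import DataTypes
from pyflink.table.expressions import col
from pyflink.table.udf import udf


@udf(result_type=DataTypes.STRING())
def fmt_ts(ts):
    # The event-time value arrives as an ordinary timestamp value here.
    return str(ts)


# The planner casts the rowtime attribute to a plain TIMESTAMP before handing
# it to the Python UDF; that CAST is what the change above now unwraps.
table.select(fmt_ts(col("rowtime"))).execute().print()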