You are viewing a plain text version of this content; the canonical link to the original message is available in the HTML version of this archive page.
Posted to commits@flink.apache.org by di...@apache.org on 2023/01/17 15:08:09 UTC

[flink] branch release-1.15 updated: [FLINK-28526][python] Fix Python UDF to support time indicator inputs

This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch release-1.15
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.15 by this push:
     new 4c7001c0dcf [FLINK-28526][python] Fix Python UDF to support time indicator inputs
4c7001c0dcf is described below

commit 4c7001c0dcff400ff2824bf3e75df184c686f2cd
Author: Dian Fu <di...@apache.org>
AuthorDate: Mon Jan 16 14:05:59 2023 +0800

    [FLINK-28526][python] Fix Python UDF to support time indicator inputs
    
    This closes #21686.
---
 flink-python/pyflink/table/tests/test_udf.py       | 43 ++++++++++++++++++++++
 .../plan/nodes/exec/utils/CommonPythonUtil.java    | 33 ++++++++++++-----
 2 files changed, 66 insertions(+), 10 deletions(-)

diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
index 135f1d3ddc7..509989a0313 100644
--- a/flink-python/pyflink/table/tests/test_udf.py
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -23,6 +23,7 @@ import unittest
 import pytest
 import pytz
 
+from pyflink.common import Row
 from pyflink.table import DataTypes, expressions as expr
 from pyflink.table.udf import ScalarFunction, udf
 from pyflink.testing import source_sink_utils
@@ -788,6 +789,48 @@ class PyFlinkStreamUserDefinedFunctionTests(UserDefinedFunctionTests,
         lines.sort()
         self.assertEqual(lines, ['1,2', '2,3', '3,4'])
 
+    def test_udf_with_rowtime_arguments(self):
+        from pyflink.common import WatermarkStrategy
+        from pyflink.common.typeinfo import Types
+        from pyflink.common.watermark_strategy import TimestampAssigner
+        from pyflink.datastream import StreamExecutionEnvironment
+        from pyflink.table import Schema, StreamTableEnvironment
+
+        class MyTimestampAssigner(TimestampAssigner):
+
+            def extract_timestamp(self, value, record_timestamp) -> int:
+                return int(value[0])
+
+        env = StreamExecutionEnvironment.get_execution_environment()
+        t_env = StreamTableEnvironment.create(env)
+
+        ds = env.from_collection(
+            [(1, 42, "a"), (2, 5, "a"), (3, 1000, "c"), (100, 1000, "c")],
+            Types.ROW_NAMED(["a", "b", "c"], [Types.LONG(), Types.INT(), Types.STRING()]))
+
+        ds = ds.assign_timestamps_and_watermarks(
+            WatermarkStrategy.for_monotonous_timestamps()
+            .with_timestamp_assigner(MyTimestampAssigner()))
+
+        table = t_env.from_data_stream(
+            ds,
+            Schema.new_builder()
+                  .column_by_metadata("rowtime", "TIMESTAMP_LTZ(3)")
+                  .watermark("rowtime", "SOURCE_WATERMARK()")
+                  .build())
+
+        @udf(result_type=DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.INT())]))
+        def inc(input_row):
+            return Row(input_row.b)
+
+        table_sink = source_sink_utils.TestAppendSink(
+            ['a'], [DataTypes.INT()])
+        t_env.register_table_sink("Results", table_sink)
+        table.map(inc).execute_insert("Results").wait()
+
+        actual = source_sink_utils.results()
+        self.assert_equals(actual, ['+I[42]', '+I[5]', '+I[1000]', '+I[1000]'])
+
 
 class PyFlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests,
                                            PyFlinkBatchTableTestCase):
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
index a949ad2afa6..c26a1ad1dc6 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
@@ -47,6 +47,7 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
 import org.apache.flink.table.planner.functions.utils.AggSqlFunction;
 import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction;
 import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
+import org.apache.flink.table.planner.plan.schema.TimeIndicatorRelDataType;
 import org.apache.flink.table.planner.plan.utils.AggregateInfo;
 import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
 import org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment;
@@ -70,10 +71,12 @@ import org.apache.flink.table.types.logical.StructuredType;
 
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlCastFunction;
 import org.apache.calcite.sql.type.SqlTypeName;
 
 import java.lang.reflect.Field;
@@ -425,20 +428,30 @@ public class CommonPythonUtil {
         for (RexNode operand : pythonRexCall.getOperands()) {
             if (operand instanceof RexCall) {
                 RexCall childPythonRexCall = (RexCall) operand;
-                PythonFunctionInfo argPythonInfo =
-                        createPythonFunctionInfo(childPythonRexCall, inputNodes);
-                inputs.add(argPythonInfo);
+                if (childPythonRexCall.getOperator() instanceof SqlCastFunction
+                        && childPythonRexCall.getOperands().get(0) instanceof RexInputRef
+                        && childPythonRexCall.getOperands().get(0).getType()
+                                instanceof TimeIndicatorRelDataType) {
+                    operand = childPythonRexCall.getOperands().get(0);
+                } else {
+                    PythonFunctionInfo argPythonInfo =
+                            createPythonFunctionInfo(childPythonRexCall, inputNodes);
+                    inputs.add(argPythonInfo);
+                    continue;
+                }
             } else if (operand instanceof RexLiteral) {
                 RexLiteral literal = (RexLiteral) operand;
                 inputs.add(convertLiteralToPython(literal, literal.getType().getSqlTypeName()));
+                continue;
+            }
+
+            assert operand instanceof RexInputRef;
+            if (inputNodes.containsKey(operand)) {
+                inputs.add(inputNodes.get(operand));
             } else {
-                if (inputNodes.containsKey(operand)) {
-                    inputs.add(inputNodes.get(operand));
-                } else {
-                    Integer inputOffset = inputNodes.size();
-                    inputs.add(inputOffset);
-                    inputNodes.put(operand, inputOffset);
-                }
+                Integer inputOffset = inputNodes.size();
+                inputs.add(inputOffset);
+                inputNodes.put(operand, inputOffset);
             }
         }
         return new PythonFunctionInfo((PythonFunction) functionDefinition, inputs.toArray());