You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2021/05/20 10:08:08 UTC

[flink] branch master updated: [FLINK-22451][table] Support (*) as argument of UDF in Table API (#15768)

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new a85d7f4  [FLINK-22451][table] Support (*) as argument of UDF in Table API (#15768)
a85d7f4 is described below

commit a85d7f4ee13523c05a467e12043c5e96d99333fb
Author: Yi Tang <ss...@gmail.com>
AuthorDate: Thu May 20 18:07:48 2021 +0800

    [FLINK-22451][table] Support (*) as argument of UDF in Table API (#15768)
---
 docs/content.zh/docs/dev/table/functions/udfs.md   | 56 ++++++++++++++++++++++
 docs/content/docs/dev/table/functions/udfs.md      | 56 ++++++++++++++++++++++
 .../resolver/rules/ExpandColumnFunctionsRule.java  |  6 +++
 .../rules/StarReferenceFlatteningRule.java         | 10 ++++
 .../resolver/ExpressionResolverTest.java           | 18 +++++++
 .../planner/runtime/stream/table/CalcITCase.scala  | 55 +++++++++++++++++++++
 6 files changed, 201 insertions(+)

diff --git a/docs/content.zh/docs/dev/table/functions/udfs.md b/docs/content.zh/docs/dev/table/functions/udfs.md
index b898ef3..108f53e 100644
--- a/docs/content.zh/docs/dev/table/functions/udfs.md
+++ b/docs/content.zh/docs/dev/table/functions/udfs.md
@@ -168,6 +168,62 @@ env.createTemporarySystemFunction("SubstringFunction", new SubstringFunction(tru
 {{< /tab >}}
 {{< /tabs >}}
 
+你可以在 Table API 中使用 `*` 表达式作为函数的一个参数,它将被扩展为该表所有的列作为函数对应位置的参数。
+
+{{< tabs "101c5f48-f5a3-4e9a-b8ef-2fdd21a9e007" >}}
+{{< tab "Java" >}}
+```java
+import org.apache.flink.table.api.*;
+import org.apache.flink.table.functions.ScalarFunction;
+import static org.apache.flink.table.api.Expressions.*;
+
+public static class MyConcatFunction extends ScalarFunction {
+  public String eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object... fields) {
+    return Arrays.stream(fields)
+        .map(Object::toString)
+        .collect(Collectors.joining(","));
+  }
+}
+
+TableEnvironment env = TableEnvironment.create(...);
+
+// 使用 $("*") 作为函数的参数,如果 MyTable 有 3 列 (a, b, c),
+// 它们都将会被传给 MyConcatFunction。
+env.from("MyTable").select(call(MyConcatFunction.class, $("*")));
+
+// 它等价于显式地将所有列传给 MyConcatFunction。
+env.from("MyTable").select(call(MyConcatFunction.class, $("a"), $("b"), $("c")));
+
+```
+{{< /tab >}}
+{{< tab "Scala" >}}
+```scala
+import org.apache.flink.table.api._
+import org.apache.flink.table.functions.ScalarFunction
+
+import scala.annotation.varargs
+
+class MyConcatFunction extends ScalarFunction {
+  @varargs
+  def eval(@DataTypeHint(inputGroup = InputGroup.ANY) row: AnyRef*): String = {
+    row.map(f => f.toString).mkString(",")
+  }
+}
+
+val env = TableEnvironment.create(...)
+
+// 使用 $"*" 作为函数的参数,如果 MyTable 有 3 个列 (a, b, c),
+// 它们都将会被传给 MyConcatFunction。
+env.from("MyTable").select(call(classOf[MyConcatFunction], $"*"));
+
+// 它等价于显式地将所有列传给 MyConcatFunction。
+env.from("MyTable").select(call(classOf[MyConcatFunction], $"a", $"b", $"c"));
+
+```
+
+{{< /tab >}}
+{{< /tabs >}}
+
 {{< top >}}
 
 开发指南
diff --git a/docs/content/docs/dev/table/functions/udfs.md b/docs/content/docs/dev/table/functions/udfs.md
index d93e678..c3feb24 100644
--- a/docs/content/docs/dev/table/functions/udfs.md
+++ b/docs/content/docs/dev/table/functions/udfs.md
@@ -171,6 +171,62 @@ env.createTemporarySystemFunction("SubstringFunction", new SubstringFunction(tru
 {{< /tab >}}
 {{< /tabs >}}
 
+You can use the star `*` expression as one argument of a function call to act as a wildcard in Table API:
+all columns of the table will be passed to the function at the corresponding position.
+
+{{< tabs "64dd4129-6313-4904-b7e7-a1a0535822e9" >}}
+{{< tab "Java" >}}
+```java
+import org.apache.flink.table.api.*;
+import org.apache.flink.table.functions.ScalarFunction;
+import static org.apache.flink.table.api.Expressions.*;
+
+public static class MyConcatFunction extends ScalarFunction {
+  public String eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object... fields) {
+    return Arrays.stream(fields)
+        .map(Object::toString)
+        .collect(Collectors.joining(","));
+  }
+}
+
+TableEnvironment env = TableEnvironment.create(...);
+
+// call the function with $("*"); if MyTable has 3 fields (a, b, c),
+// all of them will be passed to MyConcatFunction.
+env.from("MyTable").select(call(MyConcatFunction.class, $("*")));
+
+// it is equivalent to calling the function with all columns selected explicitly.
+env.from("MyTable").select(call(MyConcatFunction.class, $("a"), $("b"), $("c")));
+
+```
+{{< /tab >}}
+{{< tab "Scala" >}}
+```scala
+import org.apache.flink.table.api._
+import org.apache.flink.table.functions.ScalarFunction
+
+import scala.annotation.varargs
+
+class MyConcatFunction extends ScalarFunction {
+  @varargs
+  def eval(@DataTypeHint(inputGroup = InputGroup.ANY) row: AnyRef*): String = {
+    row.map(f => f.toString).mkString(",")
+  }
+}
+
+val env = TableEnvironment.create(...)
+
+// call the function with $"*"; if MyTable has 3 fields (a, b, c),
+// all of them will be passed to MyConcatFunction.
+env.from("MyTable").select(call(classOf[MyConcatFunction], $"*"));
+
+// it is equivalent to calling the function with all columns selected explicitly.
+env.from("MyTable").select(call(classOf[MyConcatFunction], $"a", $"b", $"c"));
+
+```
+{{< /tab >}}
+{{< /tabs >}}
+
 {{< top >}}
 
 Implementation Guide
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ExpandColumnFunctionsRule.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ExpandColumnFunctionsRule.java
index dc613c9..4348768 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ExpandColumnFunctionsRule.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ExpandColumnFunctionsRule.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.table.api.ValidationException;
 import org.apache.flink.table.expressions.Expression;
 import org.apache.flink.table.expressions.ExpressionUtils;
+import org.apache.flink.table.expressions.FieldReferenceExpression;
 import org.apache.flink.table.expressions.UnresolvedCallExpression;
 import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
 import org.apache.flink.table.expressions.ValueLiteralExpression;
@@ -146,6 +147,11 @@ final class ExpandColumnFunctionsRule implements ResolverRule {
         }
 
         @Override
+        public List<UnresolvedReferenceExpression> visit(FieldReferenceExpression fieldReference) {
+            return Collections.singletonList(unresolvedRef(fieldReference.getName()));
+        }
+
+        @Override
         public List<UnresolvedReferenceExpression> visit(ValueLiteralExpression valueLiteral) {
             return ExpressionUtils.extractValue(valueLiteral, Integer.class)
                     .map(i -> Collections.singletonList(inputFieldReferences.get(i - 1)))
diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/StarReferenceFlatteningRule.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/StarReferenceFlatteningRule.java
index 9cc3cf8..4770f50 100644
--- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/StarReferenceFlatteningRule.java
+++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/StarReferenceFlatteningRule.java
@@ -20,6 +20,7 @@ package org.apache.flink.table.expressions.resolver.rules;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.table.expressions.Expression;
+import org.apache.flink.table.expressions.UnresolvedCallExpression;
 import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
 
 import java.util.ArrayList;
@@ -58,6 +59,15 @@ final class StarReferenceFlatteningRule implements ResolverRule {
         }
 
         @Override
+        public List<Expression> visit(UnresolvedCallExpression unresolvedCall) {
+            final List<Expression> newArgs =
+                    unresolvedCall.getChildren().stream()
+                            .flatMap(e -> e.accept(this).stream())
+                            .collect(Collectors.toList());
+            return singletonList(unresolvedCall.replaceArgs(newArgs));
+        }
+
+        @Override
         protected List<Expression> defaultMethod(Expression expression) {
             return singletonList(expression);
         }
diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/resolver/ExpressionResolverTest.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/resolver/ExpressionResolverTest.java
index b89d09b..a2044df 100644
--- a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/resolver/ExpressionResolverTest.java
+++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/resolver/ExpressionResolverTest.java
@@ -230,6 +230,24 @@ public class ExpressionResolverTest {
                                                         DataTypes.INT()
                                                                 .notNull()
                                                                 .bridgedTo(int.class))),
+                                        DataTypes.INT().notNull().bridgedTo(int.class))),
+                TestSpec.test("Star expression as parameter of user-defined func")
+                        .inputSchemas(
+                                TableSchema.builder()
+                                        .field("f0", DataTypes.INT())
+                                        .field("f1", DataTypes.STRING())
+                                        .build())
+                        .lookupFunction("func", new ScalarFunc())
+                        .select(call("func", $("*")))
+                        .equalTo(
+                                new CallExpression(
+                                        FunctionIdentifier.of("func"),
+                                        new ScalarFunc(),
+                                        Arrays.asList(
+                                                new FieldReferenceExpression(
+                                                        "f0", DataTypes.INT(), 0, 0),
+                                                new FieldReferenceExpression(
+                                                        "f1", DataTypes.STRING(), 0, 1)),
                                         DataTypes.INT().notNull().bridgedTo(int.class))));
     }
 
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/CalcITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/CalcITCase.scala
index aab0214..76f4687 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/CalcITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/CalcITCase.scala
@@ -19,6 +19,7 @@
 package org.apache.flink.table.planner.runtime.stream.table
 
 import org.apache.flink.api.scala._
+import org.apache.flink.table.annotation.{DataTypeHint, InputGroup}
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.scala._
 import org.apache.flink.table.functions.ScalarFunction
@@ -34,6 +35,7 @@ import org.junit._
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
 
+import scala.annotation.varargs
 import scala.collection.{Seq, mutable}
 
 @RunWith(classOf[Parameterized])
@@ -373,6 +375,41 @@ class CalcITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode
   }
 
   @Test
+  def testCallFunctionWithStarArgument(): Unit = {
+    val table = tEnv.fromDataStream(env.fromCollection(Seq(
+      ("Foo", 0, 3),
+      ("Bar", 1, 4),
+      ("Error", -1, 2),
+      ("Example", 3, 6)
+    )), '_s, '_b, '_e).where(ValidSubStringFilter('*)).select(SubstringFunc('*))
+
+    val sink = new TestingAppendSink
+    table.toAppendStream[Row].addSink(sink)
+    env.execute()
+
+    val expected = List("Foo", "mpl")
+    assertEquals(expected.sorted, sink.getAppendResults.sorted)
+  }
+
+  @SerialVersionUID(1L)
+  object ValidSubStringFilter extends ScalarFunction {
+    @varargs
+    def eval(@DataTypeHint(inputGroup = InputGroup.ANY) row: AnyRef*): Boolean = {
+      val str = row(0).asInstanceOf[String]
+      val begin = row(1).asInstanceOf[Int]
+      val end = row(2).asInstanceOf[Int]
+      begin >= 0 && begin <= end && str.length() >= end
+    }
+  }
+
+  @SerialVersionUID(1L)
+  object SubstringFunc extends ScalarFunction {
+    def eval(str: String, begin: Int, end: Int): String = {
+      str.substring(begin, end)
+    }
+  }
+
+  @Test
   def testMapType(): Unit = {
     val ds = env.fromCollection(tupleData3).toTable(tEnv).select(map('_1, '_3))
 
@@ -462,6 +499,24 @@ class CalcITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode
     assertEquals(expected.sorted, sink.getAppendResults.sorted)
   }
 
+  @Test
+  def testMapWithStarArgument(): Unit = {
+    val ds = env.fromCollection(smallTupleData3).toTable(tEnv, 'a, 'b, 'c)
+      .map(Func23('*)).as("a", "b", "c", "d")
+      .map(Func24('*)).as("a", "b", "c", "d")
+      .map(Func1('b))
+
+    val sink = new TestingAppendSink
+    ds.toAppendStream[Row].addSink(sink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "3",
+      "4",
+      "5")
+    assertEquals(expected.sorted, sink.getAppendResults.sorted)
+  }
+
   @Ignore("Will be open when FLINK-10834 has been fixed.")
   @Test
   def testNonDeterministic(): Unit = {