You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2020/10/12 09:55:14 UTC

[GitHub] [flink] dianfu commented on a change in pull request #13507: [FLINK-19231][python] Support ListState and ListView for Python UDAF.

dianfu commented on a change in pull request #13507:
URL: https://github.com/apache/flink/pull/13507#discussion_r503156160



##########
File path: flink-python/pyflink/fn_execution/aggregate.py
##########
@@ -16,13 +16,16 @@
 # limitations under the License.
 ################################################################################
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Dict

Review comment:
       Unused import

##########
File path: flink-python/pyflink/fn_execution/aggregate.py
##########
@@ -16,13 +16,16 @@
 # limitations under the License.
 ################################################################################
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Dict
 
-from apache_beam.coders import PickleCoder
+from apache_beam.coders import PickleCoder, Coder
 
 from pyflink.common import Row, RowKind
+from pyflink.common.state import ListState
+from pyflink.fn_execution.coders import from_proto
 from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend
 from pyflink.table import AggregateFunction, FunctionContext
+from pyflink.table.data_view import ListView, DataView

Review comment:
       Unused import

##########
File path: flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
##########
@@ -65,18 +69,77 @@ trait CommonPythonAggregate extends CommonPythonBase {
     * For streaming execution we extract the PythonFunctionInfo from AggregateInfo.
     */
   protected def extractPythonAggregateFunctionInfosFromAggregateInfo(
-      pythonAggregateInfo: AggregateInfo): PythonFunctionInfo = {
+      aggIndex: Int,
+      pythonAggregateInfo: AggregateInfo): (PythonFunctionInfo, Array[DataViewSpec]) = {
     pythonAggregateInfo.function match {
       case function: PythonFunction =>
-        new PythonFunctionInfo(
-          function,
-          pythonAggregateInfo.argIndexes.map(_.asInstanceOf[AnyRef]))
+        (
+          new PythonFunctionInfo(
+            function,
+            pythonAggregateInfo.argIndexes.map(_.asInstanceOf[AnyRef])),
+          extractDataViewSpecs(
+            aggIndex,
+            function.asInstanceOf[PythonAggregateFunction].getTypeInference(null)
+              .getAccumulatorTypeStrategy.get()
+              .inferType(null).get())
+        )
       case _: Count1AggFunction =>
         // The count star will be treated specially in Python UDF worker
-        PythonFunctionInfo.DUMMY_PLACEHOLDER
+        (PythonFunctionInfo.DUMMY_PLACEHOLDER, Array())
       case _ =>
         throw new TableException(
           "Unsupported python aggregate function: " + pythonAggregateInfo.function)
     }
   }
+
+  protected def extractDataViewSpecs(
+      index: Int,
+      accType: DataType): Array[DataViewSpec] = {
+    if (!accType.isInstanceOf[FieldsDataType]) {
+      return Array()
+    }
+
+    val compositeAccType = accType.asInstanceOf[FieldsDataType]
+
+    def includesDataView(fdt: FieldsDataType): Boolean = {
+      (0 until fdt.getChildren.size()).exists(i =>
+        fdt.getChildren.get(i).getLogicalType match {
+          case row: RowType =>

Review comment:
   ```suggestion
             case _: RowType =>
   ```

##########
File path: flink-python/pyflink/fn_execution/beam/beam_operations_slow.py
##########
@@ -304,9 +306,6 @@ def finish(self):
 
     def reset(self):
         super().reset()
-        if self.keyed_state_backend:

Review comment:
       OK. We could just remove the reset method as it does nothing right now.

##########
File path: flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
##########
@@ -65,18 +69,77 @@ trait CommonPythonAggregate extends CommonPythonBase {
     * For streaming execution we extract the PythonFunctionInfo from AggregateInfo.
     */
   protected def extractPythonAggregateFunctionInfosFromAggregateInfo(
-      pythonAggregateInfo: AggregateInfo): PythonFunctionInfo = {
+      aggIndex: Int,
+      pythonAggregateInfo: AggregateInfo): (PythonFunctionInfo, Array[DataViewSpec]) = {
     pythonAggregateInfo.function match {
       case function: PythonFunction =>
-        new PythonFunctionInfo(
-          function,
-          pythonAggregateInfo.argIndexes.map(_.asInstanceOf[AnyRef]))
+        (
+          new PythonFunctionInfo(
+            function,
+            pythonAggregateInfo.argIndexes.map(_.asInstanceOf[AnyRef])),
+          extractDataViewSpecs(
+            aggIndex,
+            function.asInstanceOf[PythonAggregateFunction].getTypeInference(null)
+              .getAccumulatorTypeStrategy.get()
+              .inferType(null).get())
+        )
       case _: Count1AggFunction =>
         // The count star will be treated specially in Python UDF worker
-        PythonFunctionInfo.DUMMY_PLACEHOLDER
+        (PythonFunctionInfo.DUMMY_PLACEHOLDER, Array())
       case _ =>
         throw new TableException(
           "Unsupported python aggregate function: " + pythonAggregateInfo.function)
     }
   }
+
+  protected def extractDataViewSpecs(
+      index: Int,
+      accType: DataType): Array[DataViewSpec] = {
+    if (!accType.isInstanceOf[FieldsDataType]) {
+      return Array()
+    }
+
+    val compositeAccType = accType.asInstanceOf[FieldsDataType]
+
+    def includesDataView(fdt: FieldsDataType): Boolean = {
+      (0 until fdt.getChildren.size()).exists(i =>
+        fdt.getChildren.get(i).getLogicalType match {
+          case row: RowType =>
+            includesDataView(fdt.getChildren.get(i).asInstanceOf[FieldsDataType])
+          case structed: StructuredType =>

Review comment:
   ```suggestion
             case structuredType: StructuredType =>
   ```

##########
File path: flink-python/pyflink/fn_execution/aggregate.py
##########
@@ -34,6 +37,75 @@ def join_row(left: Row, right: Row):
     return Row(*fields)
 
 
+def extract_data_view_specs(udfs):
+    extracted_udf_data_view_specs = []
+    for udf in udfs:
+        udf_data_view_specs_proto = udf.specs
+        if udf_data_view_specs_proto is None:
+            extracted_udf_data_view_specs.append([])
+        extracted_specs = []
+        for spec_proto in udf_data_view_specs_proto:
+            state_id = spec_proto.name
+            field_index = spec_proto.field_index
+            if spec_proto.list_view is not None:
+                element_coder = from_proto(spec_proto.list_view.element_type)
+                extracted_specs.append(ListViewSpec(state_id, field_index, element_coder))
+            elif spec_proto.map_view is not None:
+                key_coder = from_proto(spec_proto.map_view.key_type)
+                value_coder = from_proto(spec_proto.map_view.value_type)
+                extracted_specs.append(MapViewSpec(state_id, field_index, key_coder, value_coder))
+            else:
+                raise Exception("Unsupported data view spec type: " + spec_proto.type)
+        extracted_udf_data_view_specs.append(extracted_specs)
+    if all([len(i) == 0 for i in extracted_udf_data_view_specs]):
+        return []
+    return extracted_udf_data_view_specs
+
+
+class StateListView(ListView):
+
+    def __init__(self, list_state):
+        super().__init__()
+        self._list_state = list_state  # type: ListState

Review comment:
       Why not use a type hint directly instead of a `# type:` comment?

##########
File path: flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
##########
@@ -65,18 +69,77 @@ trait CommonPythonAggregate extends CommonPythonBase {
     * For streaming execution we extract the PythonFunctionInfo from AggregateInfo.
     */
   protected def extractPythonAggregateFunctionInfosFromAggregateInfo(
-      pythonAggregateInfo: AggregateInfo): PythonFunctionInfo = {
+      aggIndex: Int,
+      pythonAggregateInfo: AggregateInfo): (PythonFunctionInfo, Array[DataViewSpec]) = {
     pythonAggregateInfo.function match {
       case function: PythonFunction =>
-        new PythonFunctionInfo(
-          function,
-          pythonAggregateInfo.argIndexes.map(_.asInstanceOf[AnyRef]))
+        (
+          new PythonFunctionInfo(
+            function,
+            pythonAggregateInfo.argIndexes.map(_.asInstanceOf[AnyRef])),
+          extractDataViewSpecs(
+            aggIndex,
+            function.asInstanceOf[PythonAggregateFunction].getTypeInference(null)
+              .getAccumulatorTypeStrategy.get()
+              .inferType(null).get())
+        )
       case _: Count1AggFunction =>
         // The count star will be treated specially in Python UDF worker
-        PythonFunctionInfo.DUMMY_PLACEHOLDER
+        (PythonFunctionInfo.DUMMY_PLACEHOLDER, Array())
       case _ =>
         throw new TableException(
           "Unsupported python aggregate function: " + pythonAggregateInfo.function)
     }
   }
+
+  protected def extractDataViewSpecs(
+      index: Int,
+      accType: DataType): Array[DataViewSpec] = {
+    if (!accType.isInstanceOf[FieldsDataType]) {
+      return Array()
+    }
+
+    val compositeAccType = accType.asInstanceOf[FieldsDataType]
+
+    def includesDataView(fdt: FieldsDataType): Boolean = {
+      (0 until fdt.getChildren.size()).exists(i =>
+        fdt.getChildren.get(i).getLogicalType match {
+          case row: RowType =>
+            includesDataView(fdt.getChildren.get(i).asInstanceOf[FieldsDataType])
+          case structed: StructuredType =>
+            classOf[DataView].isAssignableFrom(structed.getImplementationClass.get())
+          case _ => false
+        }
+      )
+    }
+
+    if (includesDataView(compositeAccType)) {
+      compositeAccType.getLogicalType match {
+        case rowType: RowType =>
+            (0 until compositeAccType.getChildren.size()).flatMap(i => {
+              compositeAccType.getChildren.get(i).getLogicalType match {
+                case _: RowType if includesDataView(
+                  compositeAccType.getChildren.get(i).asInstanceOf[FieldsDataType]) =>
+                  throw new TableException(
+                    "For Python AggregateFunction DataView only supported at first " +

Review comment:
   ```suggestion
                       "For Python AggregateFunction, DataView cannot be used in the nested columns of the accumulator." +
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org