You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by di...@apache.org on 2020/02/26 06:50:27 UTC

[flink] branch master updated: [FLINK-16251][python] Optimize the cost of function call in ScalarFunctionOperation

This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ada299  [FLINK-16251][python] Optimize the cost of function call in ScalarFunctionOperation
1ada299 is described below

commit 1ada2997254d08b0baa15484d68b93250098289c
Author: HuangXingBo <hx...@gmail.com>
AuthorDate: Wed Feb 26 14:50:06 2020 +0800

    [FLINK-16251][python] Optimize the cost of function call in ScalarFunctionOperation
---
 flink-python/pyflink/fn_execution/operations.py | 295 ++++++++----------------
 1 file changed, 95 insertions(+), 200 deletions(-)

diff --git a/flink-python/pyflink/fn_execution/operations.py b/flink-python/pyflink/fn_execution/operations.py
index a9bdc13..d545bc8 100644
--- a/flink-python/pyflink/fn_execution/operations.py
+++ b/flink-python/pyflink/fn_execution/operations.py
@@ -17,7 +17,6 @@
 ################################################################################
 
 import datetime
-from abc import abstractmethod, ABCMeta
 
 import cloudpickle
 from apache_beam.runners.common import _OutputProcessor
@@ -31,200 +30,6 @@ from pyflink.serializers import PickleSerializer
 SCALAR_FUNCTION_URN = "flink:transform:scalar_function:v1"
 
 
-class InputGetter(object):
-    """
-    Base class for get an input argument for a :class:`UserDefinedFunction`.
-    """
-    __metaclass__ = ABCMeta
-
-    def open(self):
-        pass
-
-    def close(self):
-        pass
-
-    @abstractmethod
-    def get(self, value):
-        pass
-
-
-class OffsetInputGetter(InputGetter):
-    """
-    InputGetter for the input argument which is a column of the input row.
-
-    :param input_offset: the offset of the column in the input row
-    """
-
-    def __init__(self, input_offset):
-        self.input_offset = input_offset
-
-    def get(self, value):
-        return value[self.input_offset]
-
-
-class ScalarFunctionInputGetter(InputGetter):
-    """
-    InputGetter for the input argument which is a Python :class:`ScalarFunction`. This is used for
-    chaining Python functions.
-
-    :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
-    """
-
-    def __init__(self, scalar_function_proto):
-        self.scalar_function_invoker = create_scalar_function_invoker(scalar_function_proto)
-
-    def open(self):
-        self.scalar_function_invoker.invoke_open()
-
-    def close(self):
-        self.scalar_function_invoker.invoke_close()
-
-    def get(self, value):
-        return self.scalar_function_invoker.invoke_eval(value)
-
-
-class ConstantInputGetter(InputGetter):
-    """
-    InputGetter for the input argument which is a constant value.
-
-    :param constant_value: the constant value of the column
-    """
-
-    def __init__(self, constant_value):
-        j_type = constant_value[0]
-        serializer = PickleSerializer()
-        pickled_data = serializer.loads(constant_value[1:])
-        # the type set contains
-        # TINYINT,SMALLINT,INTEGER,BIGINT,FLOAT,DOUBLE,DECIMAL,CHAR,VARCHAR,NULL,BOOLEAN
-        # the pickled_data doesn't need to transfer to anther python object
-        if j_type == 0:
-            self._constant_value = pickled_data
-        # the type is DATE
-        elif j_type == 1:
-            self._constant_value = \
-                datetime.date(year=1970, month=1, day=1) + datetime.timedelta(days=pickled_data)
-        # the type is TIME
-        elif j_type == 2:
-            seconds, milliseconds = divmod(pickled_data, 1000)
-            minutes, seconds = divmod(seconds, 60)
-            hours, minutes = divmod(minutes, 60)
-            self._constant_value = datetime.time(hours, minutes, seconds, milliseconds * 1000)
-        # the type is TIMESTAMP
-        elif j_type == 3:
-            self._constant_value = \
-                datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0) \
-                + datetime.timedelta(milliseconds=pickled_data)
-        else:
-            raise Exception("Unknown type %s, should never happen" % str(j_type))
-
-    def get(self, value):
-        return self._constant_value
-
-
-class ScalarFunctionInvoker(object):
-    """
-    An abstraction that can be used to execute :class:`ScalarFunction` methods.
-
-    A ScalarFunctionInvoker describes a particular way for invoking methods of a
-    :class:`ScalarFunction`.
-
-    :param scalar_function: the :class:`ScalarFunction` to execute
-    :param inputs: the input arguments for the :class:`ScalarFunction`
-    """
-
-    def __init__(self, scalar_function, inputs):
-        self.scalar_function = scalar_function
-        self.input_getters = []
-        for input in inputs:
-            if input.HasField("udf"):
-                # for chaining Python UDF input: the input argument is a Python ScalarFunction
-                self.input_getters.append(ScalarFunctionInputGetter(input.udf))
-            elif input.HasField("inputOffset"):
-                # the input argument is a column of the input row
-                self.input_getters.append(OffsetInputGetter(input.inputOffset))
-            else:
-                # the input argument is a constant value
-                self.input_getters.append(ConstantInputGetter(input.inputConstant))
-
-    def invoke_open(self):
-        """
-        Invokes the ScalarFunction.open() function.
-        """
-        for input_getter in self.input_getters:
-            input_getter.open()
-        # set the FunctionContext to None for now
-        self.scalar_function.open(None)
-
-    def invoke_close(self):
-        """
-        Invokes the ScalarFunction.close() function.
-        """
-        for input_getter in self.input_getters:
-            input_getter.close()
-        self.scalar_function.close()
-
-    def invoke_eval(self, value):
-        """
-        Invokes the ScalarFunction.eval() function.
-
-        :param value: the input element for which eval() method should be invoked
-        """
-        args = [input_getter.get(value) for input_getter in self.input_getters]
-        return self.scalar_function.eval(*args)
-
-
-def create_scalar_function_invoker(scalar_function_proto):
-    """
-    Creates :class:`ScalarFunctionInvoker` from the proto representation of a
-    :class:`ScalarFunction`.
-
-    :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
-    :return: :class:`ScalarFunctionInvoker`.
-    """
-    scalar_function = cloudpickle.loads(scalar_function_proto.payload)
-    return ScalarFunctionInvoker(scalar_function, scalar_function_proto.inputs)
-
-
-class ScalarFunctionRunner(object):
-    """
-    The runner which is responsible for executing the scalar functions and send the
-    execution results back to the remote Java operator.
-
-    :param udfs_proto: protocol representation for the scalar functions to execute
-    """
-
-    def __init__(self, udfs_proto):
-        self.scalar_function_invokers = [create_scalar_function_invoker(f) for f in
-                                         udfs_proto]
-
-    def setup(self, main_receivers):
-        """
-        Set up the ScalarFunctionRunner.
-
-        :param main_receivers: Receiver objects which is responsible for sending the execution
-                               results back the the remote Java operator
-        """
-        self.output_processor = _OutputProcessor(
-            window_fn=None,
-            main_receivers=main_receivers,
-            tagged_receivers=None,
-            per_element_output_counter=None)
-
-    def open(self):
-        for invoker in self.scalar_function_invokers:
-            invoker.invoke_open()
-
-    def close(self):
-        for invoker in self.scalar_function_invokers:
-            invoker.invoke_close()
-
-    def process(self, windowed_value):
-        results = [invoker.invoke_eval(windowed_value.value) for invoker in
-                   self.scalar_function_invokers]
-        # send the execution results back
-        self.output_processor.process_outputs(windowed_value, [results])
-
-
 class ScalarFunctionOperation(Operation):
     """
     An operation that will execute ScalarFunctions for each input element.
@@ -236,19 +41,26 @@ class ScalarFunctionOperation(Operation):
             for consumer in op_consumers:
                 self.add_receiver(consumer, 0)
 
-        self.scalar_function_runner = ScalarFunctionRunner(self.spec.serialized_fn)
-        self.scalar_function_runner.open()
+        self.variable_dict = {}
+        self.scalar_funcs = []
+        self.func = self._generate_func(self.spec.serialized_fn)
+        for scalar_func in self.scalar_funcs:
+            scalar_func.open(None)
 
     def setup(self):
         super(ScalarFunctionOperation, self).setup()
-        self.scalar_function_runner.setup(self.receivers[0])
+        self.output_processor = _OutputProcessor(
+            window_fn=None,
+            main_receivers=self.receivers[0],
+            tagged_receivers=None,
+            per_element_output_counter=None)
 
     def start(self):
         with self.scoped_start_state:
             super(ScalarFunctionOperation, self).start()
 
     def process(self, o):
-        self.scalar_function_runner.process(o)
+        self.output_processor.process_outputs(o, [self.func(o.value)])
 
     def finish(self):
         super(ScalarFunctionOperation, self).finish()
@@ -260,7 +72,8 @@ class ScalarFunctionOperation(Operation):
         super(ScalarFunctionOperation, self).reset()
 
     def teardown(self):
-        self.scalar_function_runner.close()
+        for scalar_func in self.scalar_funcs:
+            scalar_func.close(None)
 
     def progress_metrics(self):
         metrics = super(ScalarFunctionOperation, self).progress_metrics()
@@ -271,6 +84,88 @@ class ScalarFunctionOperation(Operation):
             str(tag)] = receiver.opcounter.element_counter.value()
         return metrics
 
+    def _generate_func(self, udfs):
+        """
+        Generates a lambda function based on udfs.
+        :param udfs: a list of the proto representation of the Python :class:`ScalarFunction`
+        :return: the generated lambda function
+        """
+        scalar_functions = [self._extract_scalar_function(udf) for udf in udfs]
+        return eval('lambda value: [%s]' % ','.join(scalar_functions), self.variable_dict)
+
+    def _extract_scalar_function(self, scalar_function_proto):
+        """
+        Extracts scalar_function from the proto representation of a
+        :class:`ScalarFunction`.
+
+        :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
+        """
+        def _next_func_num():
+            if not hasattr(self, "_func_num"):
+                self._func_num = 0
+            else:
+                self._func_num += 1
+            return self._func_num
+
+        scalar_func = cloudpickle.loads(scalar_function_proto.payload)
+        func_name = 'f%s' % _next_func_num()
+        self.variable_dict[func_name] = scalar_func.eval
+        self.scalar_funcs.append(scalar_func)
+        func_args = self._extract_scalar_function_args(scalar_function_proto.inputs)
+        return "%s(%s)" % (func_name, func_args)
+
+    def _extract_scalar_function_args(self, args):
+        args_str = []
+        for arg in args:
+            if arg.HasField("udf"):
+                # for chaining Python UDF input: the input argument is a Python ScalarFunction
+                args_str.append(self._extract_scalar_function(arg.udf))
+            elif arg.HasField("inputOffset"):
+                # the input argument is a column of the input row
+                args_str.append("value[%s]" % arg.inputOffset)
+            else:
+                # the input argument is a constant value
+                args_str.append(self._parse_constant_value(arg.inputConstant))
+        return ",".join(args_str)
+
+    def _parse_constant_value(self, constant_value):
+        j_type = constant_value[0]
+        serializer = PickleSerializer()
+        pickled_data = serializer.loads(constant_value[1:])
+        # the type set contains
+        # TINYINT,SMALLINT,INTEGER,BIGINT,FLOAT,DOUBLE,DECIMAL,CHAR,VARCHAR,NULL,BOOLEAN
+        # the pickled_data doesn't need to be converted to another Python object
+        if j_type == 0:
+            parsed_constant_value = pickled_data
+        # the type is DATE
+        elif j_type == 1:
+            parsed_constant_value = \
+                datetime.date(year=1970, month=1, day=1) + datetime.timedelta(days=pickled_data)
+        # the type is TIME
+        elif j_type == 2:
+            seconds, milliseconds = divmod(pickled_data, 1000)
+            minutes, seconds = divmod(seconds, 60)
+            hours, minutes = divmod(minutes, 60)
+            parsed_constant_value = datetime.time(hours, minutes, seconds, milliseconds * 1000)
+        # the type is TIMESTAMP
+        elif j_type == 3:
+            parsed_constant_value = \
+                datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0) \
+                + datetime.timedelta(milliseconds=pickled_data)
+        else:
+            raise Exception("Unknown type %s, should never happen" % str(j_type))
+
+        def _next_constant_num():
+            if not hasattr(self, "_constant_num"):
+                self._constant_num = 0
+            else:
+                self._constant_num += 1
+            return self._constant_num
+
+        constant_value_name = 'c%s' % _next_constant_num()
+        self.variable_dict[constant_value_name] = parsed_constant_value
+        return constant_value_name
+
 
 @bundle_processor.BeamTransformFactory.register_urn(
     SCALAR_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions)