You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by di...@apache.org on 2020/02/26 06:50:27 UTC
[flink] branch master updated: [FLINK-16251][python] Optimize the
cost of function call in ScalarFunctionOperation
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 1ada299 [FLINK-16251][python] Optimize the cost of function call in ScalarFunctionOperation
1ada299 is described below
commit 1ada2997254d08b0baa15484d68b93250098289c
Author: HuangXingBo <hx...@gmail.com>
AuthorDate: Wed Feb 26 14:50:06 2020 +0800
[FLINK-16251][python] Optimize the cost of function call in ScalarFunctionOperation
---
flink-python/pyflink/fn_execution/operations.py | 295 ++++++++----------------
1 file changed, 95 insertions(+), 200 deletions(-)
diff --git a/flink-python/pyflink/fn_execution/operations.py b/flink-python/pyflink/fn_execution/operations.py
index a9bdc13..d545bc8 100644
--- a/flink-python/pyflink/fn_execution/operations.py
+++ b/flink-python/pyflink/fn_execution/operations.py
@@ -17,7 +17,6 @@
################################################################################
import datetime
-from abc import abstractmethod, ABCMeta
import cloudpickle
from apache_beam.runners.common import _OutputProcessor
@@ -31,200 +30,6 @@ from pyflink.serializers import PickleSerializer
SCALAR_FUNCTION_URN = "flink:transform:scalar_function:v1"
-class InputGetter(object):
- """
- Base class for get an input argument for a :class:`UserDefinedFunction`.
- """
- __metaclass__ = ABCMeta
-
- def open(self):
- pass
-
- def close(self):
- pass
-
- @abstractmethod
- def get(self, value):
- pass
-
-
-class OffsetInputGetter(InputGetter):
- """
- InputGetter for the input argument which is a column of the input row.
-
- :param input_offset: the offset of the column in the input row
- """
-
- def __init__(self, input_offset):
- self.input_offset = input_offset
-
- def get(self, value):
- return value[self.input_offset]
-
-
-class ScalarFunctionInputGetter(InputGetter):
- """
- InputGetter for the input argument which is a Python :class:`ScalarFunction`. This is used for
- chaining Python functions.
-
- :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
- """
-
- def __init__(self, scalar_function_proto):
- self.scalar_function_invoker = create_scalar_function_invoker(scalar_function_proto)
-
- def open(self):
- self.scalar_function_invoker.invoke_open()
-
- def close(self):
- self.scalar_function_invoker.invoke_close()
-
- def get(self, value):
- return self.scalar_function_invoker.invoke_eval(value)
-
-
-class ConstantInputGetter(InputGetter):
- """
- InputGetter for the input argument which is a constant value.
-
- :param constant_value: the constant value of the column
- """
-
- def __init__(self, constant_value):
- j_type = constant_value[0]
- serializer = PickleSerializer()
- pickled_data = serializer.loads(constant_value[1:])
- # the type set contains
- # TINYINT,SMALLINT,INTEGER,BIGINT,FLOAT,DOUBLE,DECIMAL,CHAR,VARCHAR,NULL,BOOLEAN
- # the pickled_data doesn't need to be transferred to another Python object
- if j_type == 0:
- self._constant_value = pickled_data
- # the type is DATE
- elif j_type == 1:
- self._constant_value = \
- datetime.date(year=1970, month=1, day=1) + datetime.timedelta(days=pickled_data)
- # the type is TIME
- elif j_type == 2:
- seconds, milliseconds = divmod(pickled_data, 1000)
- minutes, seconds = divmod(seconds, 60)
- hours, minutes = divmod(minutes, 60)
- self._constant_value = datetime.time(hours, minutes, seconds, milliseconds * 1000)
- # the type is TIMESTAMP
- elif j_type == 3:
- self._constant_value = \
- datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0) \
- + datetime.timedelta(milliseconds=pickled_data)
- else:
- raise Exception("Unknown type %s, should never happen" % str(j_type))
-
- def get(self, value):
- return self._constant_value
-
-
-class ScalarFunctionInvoker(object):
- """
- An abstraction that can be used to execute :class:`ScalarFunction` methods.
-
- A ScalarFunctionInvoker describes a particular way for invoking methods of a
- :class:`ScalarFunction`.
-
- :param scalar_function: the :class:`ScalarFunction` to execute
- :param inputs: the input arguments for the :class:`ScalarFunction`
- """
-
- def __init__(self, scalar_function, inputs):
- self.scalar_function = scalar_function
- self.input_getters = []
- for input in inputs:
- if input.HasField("udf"):
- # for chaining Python UDF input: the input argument is a Python ScalarFunction
- self.input_getters.append(ScalarFunctionInputGetter(input.udf))
- elif input.HasField("inputOffset"):
- # the input argument is a column of the input row
- self.input_getters.append(OffsetInputGetter(input.inputOffset))
- else:
- # the input argument is a constant value
- self.input_getters.append(ConstantInputGetter(input.inputConstant))
-
- def invoke_open(self):
- """
- Invokes the ScalarFunction.open() function.
- """
- for input_getter in self.input_getters:
- input_getter.open()
- # set the FunctionContext to None for now
- self.scalar_function.open(None)
-
- def invoke_close(self):
- """
- Invokes the ScalarFunction.close() function.
- """
- for input_getter in self.input_getters:
- input_getter.close()
- self.scalar_function.close()
-
- def invoke_eval(self, value):
- """
- Invokes the ScalarFunction.eval() function.
-
- :param value: the input element for which eval() method should be invoked
- """
- args = [input_getter.get(value) for input_getter in self.input_getters]
- return self.scalar_function.eval(*args)
-
-
-def create_scalar_function_invoker(scalar_function_proto):
- """
- Creates :class:`ScalarFunctionInvoker` from the proto representation of a
- :class:`ScalarFunction`.
-
- :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
- :return: :class:`ScalarFunctionInvoker`.
- """
- scalar_function = cloudpickle.loads(scalar_function_proto.payload)
- return ScalarFunctionInvoker(scalar_function, scalar_function_proto.inputs)
-
-
-class ScalarFunctionRunner(object):
- """
- The runner which is responsible for executing the scalar functions and send the
- execution results back to the remote Java operator.
-
- :param udfs_proto: protocol representation for the scalar functions to execute
- """
-
- def __init__(self, udfs_proto):
- self.scalar_function_invokers = [create_scalar_function_invoker(f) for f in
- udfs_proto]
-
- def setup(self, main_receivers):
- """
- Set up the ScalarFunctionRunner.
-
- :param main_receivers: Receiver objects which are responsible for sending the execution
- results back to the remote Java operator
- """
- self.output_processor = _OutputProcessor(
- window_fn=None,
- main_receivers=main_receivers,
- tagged_receivers=None,
- per_element_output_counter=None)
-
- def open(self):
- for invoker in self.scalar_function_invokers:
- invoker.invoke_open()
-
- def close(self):
- for invoker in self.scalar_function_invokers:
- invoker.invoke_close()
-
- def process(self, windowed_value):
- results = [invoker.invoke_eval(windowed_value.value) for invoker in
- self.scalar_function_invokers]
- # send the execution results back
- self.output_processor.process_outputs(windowed_value, [results])
-
-
class ScalarFunctionOperation(Operation):
"""
An operation that will execute ScalarFunctions for each input element.
@@ -236,19 +41,26 @@ class ScalarFunctionOperation(Operation):
for consumer in op_consumers:
self.add_receiver(consumer, 0)
- self.scalar_function_runner = ScalarFunctionRunner(self.spec.serialized_fn)
- self.scalar_function_runner.open()
+ self.variable_dict = {}
+ self.scalar_funcs = []
+ self.func = self._generate_func(self.spec.serialized_fn)
+ for scalar_func in self.scalar_funcs:
+ scalar_func.open(None)
def setup(self):
super(ScalarFunctionOperation, self).setup()
- self.scalar_function_runner.setup(self.receivers[0])
+ self.output_processor = _OutputProcessor(
+ window_fn=None,
+ main_receivers=self.receivers[0],
+ tagged_receivers=None,
+ per_element_output_counter=None)
def start(self):
with self.scoped_start_state:
super(ScalarFunctionOperation, self).start()
def process(self, o):
- self.scalar_function_runner.process(o)
+ self.output_processor.process_outputs(o, [self.func(o.value)])
def finish(self):
super(ScalarFunctionOperation, self).finish()
@@ -260,7 +72,8 @@ class ScalarFunctionOperation(Operation):
super(ScalarFunctionOperation, self).reset()
def teardown(self):
- self.scalar_function_runner.close()
+ for scalar_func in self.scalar_funcs:
+ scalar_func.close(None)
def progress_metrics(self):
metrics = super(ScalarFunctionOperation, self).progress_metrics()
@@ -271,6 +84,88 @@ class ScalarFunctionOperation(Operation):
str(tag)] = receiver.opcounter.element_counter.value()
return metrics
+ def _generate_func(self, udfs):
+ """
+ Generates a lambda function based on udfs.
+ :param udfs: a list of the proto representation of the Python :class:`ScalarFunction`
+ :return: the generated lambda function
+ """
+ scalar_functions = [self._extract_scalar_function(udf) for udf in udfs]
+ return eval('lambda value: [%s]' % ','.join(scalar_functions), self.variable_dict)
+
+ def _extract_scalar_function(self, scalar_function_proto):
+ """
+ Extracts scalar_function from the proto representation of a
+ :class:`ScalarFunction`.
+
+ :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
+ """
+ def _next_func_num():
+ if not hasattr(self, "_func_num"):
+ self._func_num = 0
+ else:
+ self._func_num += 1
+ return self._func_num
+
+ scalar_func = cloudpickle.loads(scalar_function_proto.payload)
+ func_name = 'f%s' % _next_func_num()
+ self.variable_dict[func_name] = scalar_func.eval
+ self.scalar_funcs.append(scalar_func)
+ func_args = self._extract_scalar_function_args(scalar_function_proto.inputs)
+ return "%s(%s)" % (func_name, func_args)
+
+ def _extract_scalar_function_args(self, args):
+ args_str = []
+ for arg in args:
+ if arg.HasField("udf"):
+ # for chaining Python UDF input: the input argument is a Python ScalarFunction
+ args_str.append(self._extract_scalar_function(arg.udf))
+ elif arg.HasField("inputOffset"):
+ # the input argument is a column of the input row
+ args_str.append("value[%s]" % arg.inputOffset)
+ else:
+ # the input argument is a constant value
+ args_str.append(self._parse_constant_value(arg.inputConstant))
+ return ",".join(args_str)
+
+ def _parse_constant_value(self, constant_value):
+ j_type = constant_value[0]
+ serializer = PickleSerializer()
+ pickled_data = serializer.loads(constant_value[1:])
+ # the type set contains
+ # TINYINT,SMALLINT,INTEGER,BIGINT,FLOAT,DOUBLE,DECIMAL,CHAR,VARCHAR,NULL,BOOLEAN
+ # the pickled_data doesn't need to be transferred to another Python object
+ if j_type == 0:
+ parsed_constant_value = pickled_data
+ # the type is DATE
+ elif j_type == 1:
+ parsed_constant_value = \
+ datetime.date(year=1970, month=1, day=1) + datetime.timedelta(days=pickled_data)
+ # the type is TIME
+ elif j_type == 2:
+ seconds, milliseconds = divmod(pickled_data, 1000)
+ minutes, seconds = divmod(seconds, 60)
+ hours, minutes = divmod(minutes, 60)
+ parsed_constant_value = datetime.time(hours, minutes, seconds, milliseconds * 1000)
+ # the type is TIMESTAMP
+ elif j_type == 3:
+ parsed_constant_value = \
+ datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0) \
+ + datetime.timedelta(milliseconds=pickled_data)
+ else:
+ raise Exception("Unknown type %s, should never happen" % str(j_type))
+
+ def _next_constant_num():
+ if not hasattr(self, "_constant_num"):
+ self._constant_num = 0
+ else:
+ self._constant_num += 1
+ return self._constant_num
+
+ constant_value_name = 'c%s' % _next_constant_num()
+ self.variable_dict[constant_value_name] = parsed_constant_value
+ return constant_value_name
+
@bundle_processor.BeamTransformFactory.register_urn(
SCALAR_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions)