You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by sa...@apache.org on 2017/09/18 21:17:26 UTC

incubator-airflow git commit: [AIRFLOW-1512] Add PythonVirtualenvOperator

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 8f9bf94d8 -> 8e253c750


[AIRFLOW-1512] Add PythonVirtualenvOperator

Closes #2446 from saguziel/aguziel-virtualenv


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/8e253c75
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/8e253c75
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/8e253c75

Branch: refs/heads/master
Commit: 8e253c750d81da4e472049473767aaea0c504465
Parents: 8f9bf94
Author: Alex Guziel <al...@airbnb.com>
Authored: Mon Sep 18 14:17:21 2017 -0700
Committer: Alex Guziel <al...@airbnb.com>
Committed: Mon Sep 18 14:17:21 2017 -0700

----------------------------------------------------------------------
 airflow/operators/python_operator.py        | 215 ++++++++++++++++++++++-
 docs/code.rst                               |   1 +
 tests/operators/test_virtualenv_operator.py | 188 ++++++++++++++++++++
 3 files changed, 403 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/8e253c75/airflow/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index 552996f..56837ec 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -11,9 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from builtins import str
+import dill
+import inspect
+import os
+import pickle
+import subprocess
+import sys
+import types
+
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, SkipMixin
 from airflow.utils.decorators import apply_defaults
+from airflow.utils.file import TemporaryDirectory
+
+from textwrap import dedent
 
 
 class PythonOperator(BaseOperator):
@@ -73,10 +86,13 @@ class PythonOperator(BaseOperator):
             context['templates_dict'] = self.templates_dict
             self.op_kwargs = context
 
-        return_value = self.python_callable(*self.op_args, **self.op_kwargs)
+        return_value = self.execute_callable()
         self.logger.info("Done. Returned value was: %s", return_value)
         return return_value
 
+    def execute_callable(self):
+        return self.python_callable(*self.op_args, **self.op_kwargs)
+
 
 class BranchPythonOperator(PythonOperator, SkipMixin):
     """
@@ -141,3 +157,200 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
             self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks)
 
         self.logger.info("Done.")
+
+class PythonVirtualenvOperator(PythonOperator):
+    """
+    Allows one to run a function in a virtualenv that is created and destroyed
+    automatically (with certain caveats).
+
+    The function must be defined using def, and not be part of a class. All imports
+    must happen inside the function and no variables outside of the scope may be referenced.
+    A global scope variable named virtualenv_string_args will be available (populated by
+    string_args). In addition, one can pass stuff through op_args and op_kwargs, and one
+    can use a return value.
+
+    Note that if your virtualenv runs in a different Python major version than Airflow,
+    you cannot use return values, op_args, or op_kwargs. You can use string_args though.
+
+    :param python_callable: A python function with no references to outside variables,
+        defined with def, which will be run in a virtualenv
+    :type python_callable: function
+    :param requirements: A list of requirements as specified in a pip install command
+    :type requirements: list(str)
+    :param python_version: The Python version to run the virtualenv with. Note that
+        both 2 and 2.7 are acceptable forms.
+    :type python_version: str
+    :param use_dill: Whether to use dill to serialize the args and result (pickle is default).
+        This allows more complex types but requires you to include dill in your requirements.
+    :type use_dill: bool
+    :param system_site_packages: Whether to include system_site_packages in your virtualenv.
+        See virtualenv documentation for more information.
+    :type system_site_packages: bool
+    :param op_args: A list of positional arguments to pass to python_callable.
+    :type op_args: list
+    :param op_kwargs: A dict of keyword arguments to pass to python_callable.
+    :type op_kwargs: dict
+    :param string_args: Strings that are present in the global var virtualenv_string_args,
+        available to python_callable at runtime as a list(str). Note that args are split
+        by newline.
+    :type string_args: list(str)
+
+    """
+    def __init__(self, python_callable, requirements=None, python_version=None, use_dill=False,
+                 system_site_packages=True, op_args=None, op_kwargs=None, string_args=None,
+                 *args, **kwargs):
+        super(PythonVirtualenvOperator, self).__init__(
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            *args,
+            **kwargs)
+        self.requirements = requirements or []
+        self.string_args = string_args or []
+        self.python_version = python_version
+        self.use_dill = use_dill
+        self.system_site_packages = system_site_packages
+        # check that dill is present if needed
+        dill_in_requirements = map(lambda x: x.lower().startswith('dill'), self.requirements)
+        if (not system_site_packages) and use_dill and not any(dill_in_requirements):
+            raise AirflowException('If using dill, dill must be in the environment ' +
+                                   'either via system_site_packages or requirements')
+        # check that a function is passed, and that it is not a lambda
+        if (not isinstance(self.python_callable, types.FunctionType)
+                or self.python_callable.__name__ == (lambda x: 0).__name__):
+            raise AirflowException('{} only supports functions for python_callable arg',
+                                   self.__class__.__name__)
+        # check that args are passed iff python major version matches
+        if (python_version is not None
+                and str(python_version)[0] != str(sys.version_info[0])
+                and self._pass_op_args()):
+            raise AirflowException("Passing op_args or op_kwargs is not supported across "
+                                   "different Python major versions "
+                                   "for PythonVirtualenvOperator. Please use string_args.")
+
+    def execute_callable(self):
+        with TemporaryDirectory(prefix='venv') as tmp_dir:
+            # generate filenames
+            input_filename = os.path.join(tmp_dir, 'script.in')
+            output_filename = os.path.join(tmp_dir, 'script.out')
+            string_args_filename = os.path.join(tmp_dir, 'string_args.txt') 
+            script_filename = os.path.join(tmp_dir, 'script.py')
+
+            # set up virtualenv
+            self._execute_in_subprocess(self._generate_virtualenv_cmd(tmp_dir))
+            cmd = self._generate_pip_install_cmd(tmp_dir)
+            if cmd:
+                self._execute_in_subprocess(cmd)
+
+            self._write_args(input_filename)
+            self._write_script(script_filename)
+            self._write_string_args(string_args_filename)
+
+            # execute command in virtualenv
+            self._execute_in_subprocess(
+                self._generate_python_cmd(tmp_dir,
+                                          script_filename,
+                                          input_filename,
+                                          output_filename,
+                                          string_args_filename))
+            return self._read_result(output_filename)
+
+    def _pass_op_args(self):
+        # we should only pass op_args if any are given to us
+        return len(self.op_args) + len(self.op_kwargs) > 0
+
+    def _execute_in_subprocess(self, cmd):
+        try:
+            self.logger.info("Executing cmd\n{}".format(cmd))
+            output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
+            if output:
+                self.logger.info("Got output\n{}".format(output))
+        except subprocess.CalledProcessError as e:
+            self.logger.info("Got error output\n{}".format(e.output))
+            raise
+
+    def _write_string_args(self, filename):
+        # writes string_args to a file, which are read line by line
+        with open(filename, 'w') as f:
+            f.write('\n'.join(map(str, self.string_args)))
+
+    def _write_args(self, input_filename):
+        # serialize args to file
+        if self._pass_op_args():
+            with open(input_filename, 'wb') as f:
+                arg_dict = ({'args': self.op_args, 'kwargs': self.op_kwargs})
+                if self.use_dill:
+                    dill.dump(arg_dict, f)
+                else:
+                    pickle.dump(arg_dict, f)
+
+    def _read_result(self, output_filename):
+        if os.stat(output_filename).st_size == 0:
+            return None
+        with open(output_filename, 'rb') as f:
+            try:
+                if self.use_dill:
+                    return dill.load(f)
+                else:
+                    return pickle.load(f)
+            except ValueError:
+                self.logger.error("Error deserializing result. Note that result deserialization "
+                              "is not supported across major Python versions.")
+                raise
+
+    def _write_script(self, script_filename):
+        with open(script_filename, 'w') as f:
+            python_code = self._generate_python_code()
+            self.logger.debug('Writing code to file\n{}'.format(python_code))
+            f.write(python_code)
+
+    def _generate_virtualenv_cmd(self, tmp_dir):
+        cmd = ['virtualenv', tmp_dir]
+        if self.system_site_packages:
+            cmd.append('--system-site-packages')
+        if self.python_version is not None:
+            cmd.append('--python=python{}'.format(self.python_version))
+        return cmd
+
+    def _generate_pip_install_cmd(self, tmp_dir):
+        if len(self.requirements) == 0:
+            return []
+        else:
+            # direct path alleviates need to activate
+            cmd = ['{}/bin/pip'.format(tmp_dir), 'install']
+            return cmd + self.requirements
+
+    def _generate_python_cmd(self, tmp_dir, script_filename, input_filename, output_filename, string_args_filename):
+        # direct path alleviates need to activate
+        return ['{}/bin/python'.format(tmp_dir), script_filename, input_filename, output_filename, string_args_filename]
+            
+    def _generate_python_code(self):
+        if self.use_dill:
+            pickling_library = 'dill'
+        else:
+            pickling_library = 'pickle'
+        fn = self.python_callable
+        # don't try to read the pickle if we didn't pass anything
+        if self._pass_op_args():
+            load_args_line = 'with open(sys.argv[1], "rb") as f: arg_dict = {}.load(f)'.format(pickling_library)
+        else:
+            load_args_line = 'arg_dict = {"args": [], "kwargs": {}}'
+
+        # no indents in original code so we can accept any type of indents in the original function
+        # we deserialize args, call function, serialize result if necessary
+        return dedent("""\
+        import {pickling_library}
+        import sys
+        {load_args_code}
+        args = arg_dict["args"]
+        kwargs = arg_dict["kwargs"]
+        with open(sys.argv[3], 'r') as f: virtualenv_string_args = list(map(lambda x: x.strip(), list(f)))
+        {python_callable_lines}
+        res = {python_callable_name}(*args, **kwargs)
+        with open(sys.argv[2], 'wb') as f: res is not None and {pickling_library}.dump(res, f)
+        """).format(
+                load_args_code=load_args_line,
+                python_callable_lines=dedent(inspect.getsource(fn)),
+                python_callable_name=fn.__name__,
+                pickling_library=pickling_library)
+

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/8e253c75/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index a1980f2..4a6718f 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -72,6 +72,7 @@ Operator API
         PrestoIntervalCheckOperator,
         PrestoValueCheckOperator,
         PythonOperator,
+        PythonVirtualenvOperator,
         S3KeySensor,
         S3ToHiveTransfer,
         ShortCircuitOperator,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/8e253c75/tests/operators/test_virtualenv_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/test_virtualenv_operator.py b/tests/operators/test_virtualenv_operator.py
new file mode 100644
index 0000000..9231d39
--- /dev/null
+++ b/tests/operators/test_virtualenv_operator.py
@@ -0,0 +1,188 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function, unicode_literals
+
+import datetime
+import funcsigs
+import sys
+import unittest
+
+from subprocess import CalledProcessError
+
+from airflow import configuration, DAG
+from airflow.models import TaskInstance
+from airflow.operators.python_operator import PythonVirtualenvOperator
+from airflow.settings import Session
+from airflow.utils.state import State
+
+from airflow.exceptions import AirflowException
+import logging
+
+DEFAULT_DATE = datetime.datetime(2016, 1, 1)
+END_DATE = datetime.datetime(2016, 1, 2)
+INTERVAL = datetime.timedelta(hours=12)
+FROZEN_NOW = datetime.datetime(2016, 1, 2, 12, 1, 1)
+
+
+class TestPythonVirtualenvOperator(unittest.TestCase):
+
+    def setUp(self):
+        super(TestPythonVirtualenvOperator, self).setUp()
+        configuration.load_test_config()
+        self.dag = DAG(
+            'test_dag',
+            default_args={
+                'owner': 'airflow',
+                'start_date': DEFAULT_DATE},
+            schedule_interval=INTERVAL)
+        self.addCleanup(self.dag.clear)
+
+    def _run_as_operator(self, fn, **kwargs):
+        task = PythonVirtualenvOperator(
+            python_callable=fn,
+            task_id='task',
+            dag=self.dag,
+            **kwargs)
+        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    def test_dill_warning(self):
+        def f():
+            pass
+        with self.assertRaises(AirflowException):
+            PythonVirtualenvOperator(
+                python_callable=f,
+                task_id='task',
+                dag=self.dag,
+                use_dill=True,
+                system_site_packages=False)
+
+    def test_no_requirements(self):
+        """Tests that the python callable is invoked on task run."""
+        def f():
+            pass
+        self._run_as_operator(f)
+
+    def test_no_system_site_packages(self):
+        def f():
+            try:
+                import funcsigs
+            except ImportError:
+                return True
+            raise Exception
+        self._run_as_operator(f, system_site_packages=False, requirements=['dill'])
+
+    def test_system_site_packages(self):
+        def f():
+            import funcsigs
+        self._run_as_operator(f, requirements=['funcsigs'], system_site_packages=True)
+
+    def test_with_requirements_pinned(self):
+        self.assertNotEqual('0.4', funcsigs.__version__, 'Please update this string if this fails')
+        def f():
+            import funcsigs
+            if funcsigs.__version__ != '0.4':
+                raise Exception
+        self._run_as_operator(f, requirements=['funcsigs==0.4'])
+
+    def test_unpinned_requirements(self):
+        def f():
+            import funcsigs
+        self._run_as_operator(f, requirements=['funcsigs', 'dill'], system_site_packages=False)
+
+    def test_range_requirements(self):
+        def f():
+            import funcsigs
+        self._run_as_operator(f, requirements=['funcsigs>1.0', 'dill'], system_site_packages=False)
+
+    def test_fail(self):
+        def f():
+            raise Exception
+        with self.assertRaises(CalledProcessError):
+            self._run_as_operator(f)
+
+    def test_python_2(self):
+        def f():
+            {}.iteritems()
+        self._run_as_operator(f, python_version=2, requirements=['dill'])
+
+    def test_python_2_7(self):
+        def f():
+            {}.iteritems()
+            return True
+        self._run_as_operator(f, python_version='2.7', requirements=['dill'])
+
+    def test_python_3(self):
+        def f():
+            import sys
+            print(sys.version)
+            try:
+                {}.iteritems()
+            except AttributeError:
+                return
+            raise Exception
+        self._run_as_operator(f, python_version=3, use_dill=False, requirements=['dill'])
+
+    def _invert_python_major_version(self):
+        if sys.version_info[0] == 2:
+            return 3
+        else:
+            return 2
+
+    def test_wrong_python_op_args(self):
+        if sys.version_info[0] == 2:
+            version = 3
+        else:
+            version = 2
+        def f():
+            pass
+        with self.assertRaises(AirflowException):
+            self._run_as_operator(f, python_version=version, op_args=[1])
+
+    def test_without_dill(self):
+        def f(a):
+            return a
+        self._run_as_operator(f, system_site_packages=False, use_dill=False, op_args=[4])
+
+    def test_string_args(self):
+        def f():
+            print(virtualenv_string_args)
+            if virtualenv_string_args[0] != virtualenv_string_args[2]:
+                raise Exception
+        self._run_as_operator(f, python_version=self._invert_python_major_version(), string_args=[1,2,1])
+
+    def test_with_args(self):
+        def f(a, b, c=False, d=False):
+            if a==0 and b==1 and c and not d:
+                return True
+            else:
+                raise Exception
+        self._run_as_operator(f, op_args=[0, 1], op_kwargs={'c': True})
+
+    def test_return_none(self):
+        def f():
+            return None
+        self._run_as_operator(f)
+
+    def test_lambda(self):
+        with self.assertRaises(AirflowException):
+            PythonVirtualenvOperator(
+                python_callable=lambda x: 4,
+                task_id='task',
+                dag=self.dag)
+
+    def test_nonimported_as_arg(self):
+        def f(a):
+            return None
+        self._run_as_operator(f, op_args=[datetime.datetime.now()])