You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/05/13 15:59:43 UTC

[airflow] branch master updated: Ensure that a task preceding a PythonVirtualenvOperator doesn't fail (#15822)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ab9c0c  Ensure that a task preceding a PythonVirtualenvOperator doesn't fail (#15822)
8ab9c0c is described below

commit 8ab9c0c969559318417b9e66454f7a95a34aeeeb
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Thu May 13 16:59:08 2021 +0100

    Ensure that a task preceding a PythonVirtualenvOperator doesn't fail (#15822)
    
    The addition in 2.0.0 of the "mini scheduler run" at the end of a task
    would cause any task preceding a PythonVirtualenvOperator to fail with
    an exception of `cannot pickle 'module' object`.
---
 airflow/models/dag.py          |  9 ++-------
 airflow/operators/python.py    |  5 +++++
 tests/operators/test_python.py | 13 +++++++++++++
 3 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index d1bf560..35fa842 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1463,13 +1463,8 @@ class DAG(LoggingMixin):
         """
         # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all
         # the tasks anyway, so we copy the tasks manually later
-        task_dict = self.task_dict
-        task_group = self._task_group
-        self.task_dict = {}
-        self._task_group = None  # type: ignore
-        dag = copy.deepcopy(self)
-        self.task_dict = task_dict
-        self._task_group = task_group
+        memo = {id(self.task_dict): None, id(self._task_group): None}
+        dag = copy.deepcopy(self, memo)  # type: ignore
 
         if isinstance(task_ids_or_regex, (str, RePatternType)):
             matched_tasks = [t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)]
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index e43425a..fa8020c 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -432,6 +432,11 @@ class PythonVirtualenvOperator(PythonOperator):
                 )
                 raise
 
+    def __deepcopy__(self, memo):
+        # module objects can't be copied _at all__
+        memo[id(self.pickling_library)] = self.pickling_library
+        return super().__deepcopy__(memo)
+
 
 def get_current_context() -> Dict[str, Any]:
     """
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 3853d3a..eae84b1 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -1025,6 +1025,19 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
 
         self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=None)
 
+    def test_deepcopy(self):
+        """Test that PythonVirtualenvOperator are deep-copyable."""
+
+        def f():
+            return 1
+
+        task = PythonVirtualenvOperator(
+            python_callable=f,
+            task_id='task',
+            dag=self.dag,
+        )
+        copy.deepcopy(task)
+
 
 DEFAULT_ARGS = {
     "owner": "test",