Posted to commits@airflow.apache.org by ka...@apache.org on 2021/01/30 20:33:01 UTC

[airflow] branch master updated: Stop loading Extra Operator links in Scheduler (#13932)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 7034529  Stop loading Extra Operator links in Scheduler (#13932)
7034529 is described below

commit 70345293031b56a6ce4019efe66ea9762d96c316
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Sat Jan 30 20:32:50 2021 +0000

    Stop loading Extra Operator links in Scheduler (#13932)
    
    closes #13099
---
 airflow/jobs/scheduler_job.py               |  2 +-
 airflow/models/dagbag.py                    |  9 ++++
 airflow/models/serialized_dag.py            |  4 ++
 airflow/serialization/serialized_objects.py | 66 +++++++++++++++++------------
 tests/jobs/test_scheduler_job.py            | 37 ++++++++++++++++
 5 files changed, 91 insertions(+), 27 deletions(-)
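
In short, this change threads a new load_op_links flag from the DagBag down to the
serialization layer: DagBag(load_op_links=False) is copied onto each SerializedDagModel
row, which sets SerializedDAG._load_operator_extra_links before deserializing, so that
SerializedBaseOperator.deserialize_operator skips importing plugin-provided links.
A minimal sketch of the Scheduler-side usage (the dag folder path and dag_id are
illustrative, not from this commit):

    from airflow.models.dagbag import DagBag

    # The Scheduler's DagBag now opts out of Extra Operator Links, so
    # deserializing DAGs from the DB never imports user plugin code.
    dagbag = DagBag(dag_folder="/path/to/dags", read_dags_from_db=True, load_op_links=False)
    dag = dagbag.get_dag("some_dag_id")  # tasks come back without plugin-provided links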

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index becfa60..49ebcd0 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -731,7 +731,7 @@ class SchedulerJob(BaseJob):  # pylint: disable=too-many-instance-attributes
         self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query')
         self.processor_agent: Optional[DagFileProcessorAgent] = None
 
-        self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True)
+        self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False)
 
     def register_signals(self) -> None:
         """Register signals that stop child processes"""
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index f1ae55c..b493334 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -79,6 +79,10 @@ class DagBag(LoggingMixin):
     :param read_dags_from_db: Read DAGs from DB if ``True`` is passed.
         If ``False`` DAGs are read from python files.
     :type read_dags_from_db: bool
+    :param load_op_links: Should the extra operator links be loaded via plugins when
+        de-serializing the DAG? This flag is set to False in the Scheduler so that Extra
+        Operator links are not loaded, avoiding the execution of user code in the Scheduler.
+    :type load_op_links: bool
     """
 
     DAGBAG_IMPORT_TIMEOUT = conf.getfloat('core', 'DAGBAG_IMPORT_TIMEOUT')
@@ -92,6 +96,7 @@ class DagBag(LoggingMixin):
         safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
         read_dags_from_db: bool = False,
         store_serialized_dags: Optional[bool] = None,
+        load_op_links: bool = True,
     ):
         # Avoid circular import
         from airflow.models.dag import DAG
@@ -128,6 +133,9 @@ class DagBag(LoggingMixin):
             include_smart_sensor=include_smart_sensor,
             safe_mode=safe_mode,
         )
+        # Should the extra operator links be loaded via plugins?
+        # This flag is set to False in the Scheduler so that Extra Operator links are not loaded
+        self.load_op_links = load_op_links
 
     def size(self) -> int:
         """:return: the amount of dags contained in this dagbag"""
@@ -226,6 +234,7 @@ class DagBag(LoggingMixin):
         if not row:
             raise SerializedDagNotFound(f"DAG '{dag_id}' not found in serialized_dag table")
 
+        row.load_op_links = self.load_op_links
         dag = row.dag
         for subdag in dag.subdags:
             self.dags[subdag.dag_id] = subdag
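
The row.load_op_links = self.load_op_links hand-off above works because
SerializedDagModel declares load_op_links = True as a plain class attribute (next hunk):
assignment on a row creates an instance attribute that shadows that default. A standalone
sketch of the pattern, under illustrative names:

    class Model:
        load_op_links = True  # class-level default, shared by all instances

    row = Model()
    assert row.load_op_links       # True: falls back to the class attribute
    row.load_op_links = False      # instance attribute shadows the default
    assert not row.load_op_links   # this row opts out; other rows still default to True
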
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index e12b29f..0184307 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -86,6 +86,8 @@ class SerializedDagModel(Base):
         backref=backref('serialized_dag', uselist=False, innerjoin=True),
     )
 
+    load_op_links = True
+
     def __init__(self, dag: DAG):
         self.dag_id = dag.dag_id
         self.fileloc = dag.full_filepath
@@ -163,6 +165,8 @@ class SerializedDagModel(Base):
     @property
     def dag(self):
         """The DAG deserialized from the ``data`` column"""
+        SerializedDAG._load_operator_extra_links = self.load_op_links  # pylint: disable=protected-access
+
         if isinstance(self.data, dict):
             dag = SerializedDAG.from_dict(self.data)  # type: Any
         else:
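
Because the dag property copies the flag onto a class attribute of SerializedDAG, the
setting applies to the whole deserialization that follows. A sketch of the access pattern
(the dag_id is illustrative; SerializedDagModel.get resolves its own session via
provide_session):

    from airflow.models.serialized_dag import SerializedDagModel

    row = SerializedDagModel.get("example_dag_id")  # fetch the serialized row from the DB
    row.load_op_links = False   # e.g. set by the Scheduler's DagBag
    dag = row.dag               # deserializes with _load_operator_extra_links = False
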
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d0cb1a6..f59ce6a 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -100,6 +100,11 @@ class BaseSerialization:
 
     _json_schema: Optional[Validator] = None
 
+    # Should the extra operator links be loaded via plugins when
+    # de-serializing the DAG? This flag is set to False in the Scheduler so that Extra Operator
+    # links are not loaded, avoiding the execution of user code in the Scheduler.
+    _load_operator_extra_links = True
+
     _CONSTRUCTOR_PARAMS: Dict[str, Parameter] = {}
 
     SERIALIZER_VERSION = 1
@@ -407,35 +412,38 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
     @classmethod
     def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
         """Deserializes an operator from a JSON object."""
-        from airflow import plugins_manager
-
-        plugins_manager.initialize_extra_operators_links_plugins()
-
-        if plugins_manager.operator_extra_links is None:
-            raise AirflowException("Can not load plugins")
         op = SerializedBaseOperator(task_id=encoded_op['task_id'])
 
-        # Extra Operator Links defined in Plugins
-        op_extra_links_from_plugin = {}
-
         if "label" not in encoded_op:
             # Handle deserialization of old data before the introduction of TaskGroup
             encoded_op["label"] = encoded_op["task_id"]
 
-        for ope in plugins_manager.operator_extra_links:
-            for operator in ope.operators:
-                if (
-                    operator.__name__ == encoded_op["_task_type"]
-                    and operator.__module__ == encoded_op["_task_module"]
-                ):
-                    op_extra_links_from_plugin.update({ope.name: ope})
-
-        # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
-        # set the Operator links attribute
-        # The case for "If OperatorLinks are defined in the operator that is being Serialized"
-        # is handled in the deserialization loop where it matches k == "_operator_extra_links"
-        if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
-            setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
+        # Extra Operator Links defined in Plugins
+        op_extra_links_from_plugin = {}
+
+            # We don't want to load Extra Operator links in the Scheduler
+        if cls._load_operator_extra_links:  # pylint: disable=too-many-nested-blocks
+            from airflow import plugins_manager
+
+            plugins_manager.initialize_extra_operators_links_plugins()
+
+            if plugins_manager.operator_extra_links is None:
+                raise AirflowException("Can not load plugins")
+
+            for ope in plugins_manager.operator_extra_links:
+                for operator in ope.operators:
+                    if (
+                        operator.__name__ == encoded_op["_task_type"]
+                        and operator.__module__ == encoded_op["_task_module"]
+                    ):
+                        op_extra_links_from_plugin.update({ope.name: ope})
+
+            # If OperatorLinks are defined in Plugins but not in the Operator that is being
+            # serialized, set the Operator links attribute. The case where OperatorLinks are
+            # defined in the operator being serialized is handled in the deserialization loop,
+            # where it matches k == "_operator_extra_links".
+            if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
+                setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
 
         for k, v in encoded_op.items():
 
@@ -450,10 +458,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
             elif k.endswith("_date"):
                 v = cls._deserialize_datetime(v)
             elif k == "_operator_extra_links":
-                op_predefined_extra_links = cls._deserialize_operator_extra_links(v)
+                if cls._load_operator_extra_links:
+                    op_predefined_extra_links = cls._deserialize_operator_extra_links(v)
 
-                # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
-                op_predefined_extra_links.update(op_extra_links_from_plugin)
+                    # If OperatorLinks with the same name exist, Links via Plugin have higher precedence
+                    op_predefined_extra_links.update(op_extra_links_from_plugin)
+                else:
+                    op_predefined_extra_links = {}
 
                 v = list(op_predefined_extra_links.values())
                 k = "operator_extra_links"
@@ -655,6 +666,9 @@ class SerializedDAG(DAG, BaseSerialization):
             if k == "_downstream_task_ids":
                 v = set(v)
             elif k == "tasks":
+                # pylint: disable=protected-access
+                SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
+                # pylint: enable=protected-access
                 v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v}
                 k = "task_dict"
             elif k == "timezone":
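
For context, the guarded branch above is skipping exactly this kind of user-supplied
plugin code. A hypothetical Extra Operator Link plugin (names are illustrative, not part
of this commit):

    from airflow.models.baseoperator import BaseOperatorLink
    from airflow.operators.bash import BashOperator
    from airflow.plugins_manager import AirflowPlugin

    class GoogleLink(BaseOperatorLink):
        """Importing this module already runs user code, which the Scheduler now avoids."""

        name = "Google"
        operators = [BashOperator]  # matched against _task_type/_task_module above

        def get_link(self, operator, dttm):
            return "https://www.google.com"

    class ExtraLinkPlugin(AirflowPlugin):
        name = "extra_link_plugin"
        operator_extra_links = [GoogleLink()]
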
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 595da25..884280c 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -68,6 +68,7 @@ from tests.test_utils.db import (
     set_default_pool_slots,
 )
 from tests.test_utils.mock_executor import MockExecutor
+from tests.test_utils.mock_operators import CustomOperator
 
 ROOT_FOLDER = os.path.realpath(
     os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
@@ -3560,6 +3561,42 @@ class TestSchedulerJob(unittest.TestCase):
 
         assert dag.get_last_dagrun().creating_job_id == scheduler.id
 
+    def test_extra_operator_links_not_loaded_in_scheduler_loop(self):
+        """
+        Test that Operator links are not loaded inside the scheduling loop (which does not include
+        the DagFileProcessorProcess), especially the critical loop of the Scheduler.
+
+        This avoids running user code in the Scheduler and prevents potential deadlocks.
+        """
+        dag = DAG(dag_id='test_extra_operator_links_not_loaded_in_scheduler', start_date=DEFAULT_DATE)
+
+        # This CustomOperator has Extra Operator Links registered via plugins
+        _ = CustomOperator(task_id='custom_task', dag=dag)
+
+        dagbag = DagBag(
+            dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
+            include_examples=False,
+            read_dags_from_db=True,
+        )
+        dagbag.bag_dag(dag=dag, root_dag=dag)
+        dagbag.sync_to_db()
+
+        # Get serialized dag
+        s_dag_1 = dagbag.get_dag(dag.dag_id)
+        custom_task = s_dag_1.task_dict['custom_task']
+        # Test that custom_task has >= 1 Operator Links (after de-serialization)
+        assert custom_task.operator_extra_links
+
+        scheduler = SchedulerJob(executor=self.null_exec)
+        scheduler.processor_agent = mock.MagicMock()
+        scheduler._run_scheduler_loop()
+
+        # Get serialized dag
+        s_dag_2 = scheduler.dagbag.get_dag(dag.dag_id)
+        custom_task = s_dag_2.task_dict['custom_task']
+        # Test that custom_task has no Operator Links (after de-serialization) in the Scheduling Loop
+        assert not custom_task.operator_extra_links
+
     def test_scheduler_create_dag_runs_does_not_raise_error(self):
         """
         Test that scheduler._create_dag_runs does not raise an error when the DAG does not exist