Posted to commits@airflow.apache.org by ka...@apache.org on 2021/01/30 20:33:01 UTC
[airflow] branch master updated: Stop loading Extra Operator links in Scheduler (#13932)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 7034529 Stop loading Extra Operator links in Scheduler (#13932)
7034529 is described below
commit 70345293031b56a6ce4019efe66ea9762d96c316
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Sat Jan 30 20:32:50 2021 +0000
Stop loading Extra Operator links in Scheduler (#13932)
closes #13099
---
airflow/jobs/scheduler_job.py | 2 +-
airflow/models/dagbag.py | 9 ++++
airflow/models/serialized_dag.py | 4 ++
airflow/serialization/serialized_objects.py | 66 +++++++++++++++++------------
tests/jobs/test_scheduler_job.py | 37 ++++++++++++++++
5 files changed, 91 insertions(+), 27 deletions(-)
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index becfa60..49ebcd0 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -731,7 +731,7 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query')
self.processor_agent: Optional[DagFileProcessorAgent] = None
- self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True)
+ self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False)
def register_signals(self) -> None:
"""Register signals that stop child processes"""
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index f1ae55c..b493334 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -79,6 +79,10 @@ class DagBag(LoggingMixin):
:param read_dags_from_db: Read DAGs from DB if ``True`` is passed.
If ``False`` DAGs are read from python files.
:type read_dags_from_db: bool
+ :param load_op_links: Should the extra operator link be loaded via plugins when
+ de-serializing the DAG? This flag is set to False in the Scheduler so that Extra Operator
+ Links are not loaded, to avoid running user code in the Scheduler.
+ :type load_op_links: bool
"""
DAGBAG_IMPORT_TIMEOUT = conf.getfloat('core', 'DAGBAG_IMPORT_TIMEOUT')
@@ -92,6 +96,7 @@ class DagBag(LoggingMixin):
safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
read_dags_from_db: bool = False,
store_serialized_dags: Optional[bool] = None,
+ load_op_links: bool = True,
):
# Avoid circular import
from airflow.models.dag import DAG
@@ -128,6 +133,9 @@ class DagBag(LoggingMixin):
include_smart_sensor=include_smart_sensor,
safe_mode=safe_mode,
)
+ # Should the extra operator link be loaded via plugins?
+ # This flag is set to False in the Scheduler so that Extra Operator Links are not loaded
+ self.load_op_links = load_op_links
def size(self) -> int:
""":return: the amount of dags contained in this dagbag"""
@@ -226,6 +234,7 @@ class DagBag(LoggingMixin):
if not row:
raise SerializedDagNotFound(f"DAG '{dag_id}' not found in serialized_dag table")
+ row.load_op_links = self.load_op_links
dag = row.dag
for subdag in dag.subdags:
self.dags[subdag.dag_id] = subdag
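The DagBag change above hands its load_op_links value to the SerializedDagModel row before the row's dag property triggers deserialization. A hedged sketch of the same hand-off done directly against the serialized_dag table (dag_id is a placeholder):

    from airflow.models.serialized_dag import SerializedDagModel
    from airflow.utils.session import create_session

    with create_session() as session:
        row = (
            session.query(SerializedDagModel)
            .filter(SerializedDagModel.dag_id == "example_dag_id")  # hypothetical dag_id
            .one_or_none()
        )
        if row is not None:
            # Same hand-off as the DagBag change above: disable link loading
            # before the .dag property deserializes the payload.
            row.load_op_links = False
            dag = row.dag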
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index e12b29f..0184307 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -86,6 +86,8 @@ class SerializedDagModel(Base):
backref=backref('serialized_dag', uselist=False, innerjoin=True),
)
+ load_op_links = True
+
def __init__(self, dag: DAG):
self.dag_id = dag.dag_id
self.fileloc = dag.full_filepath
@@ -163,6 +165,8 @@ class SerializedDagModel(Base):
@property
def dag(self):
"""The DAG deserialized from the ``data`` column"""
+ SerializedDAG._load_operator_extra_links = self.load_op_links # pylint: disable=protected-access
+
if isinstance(self.data, dict):
dag = SerializedDAG.from_dict(self.data) # type: Any
else:
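SerializedDagModel.dag flips a class-level switch on SerializedDAG before deserializing. A self-contained, hedged round trip showing that switch directly (DummyOperator is only a stand-in operator):

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.dummy import DummyOperator
    from airflow.serialization.serialized_objects import SerializedDAG

    dag = DAG(dag_id="example", start_date=datetime(2021, 1, 1))
    DummyOperator(task_id="noop", dag=dag)

    data = SerializedDAG.to_dict(dag)
    # The same flag the property above sets from SerializedDagModel.load_op_links.
    SerializedDAG._load_operator_extra_links = False  # pylint: disable=protected-access
    restored = SerializedDAG.from_dict(data)  # deserializes without importing plugin links
    SerializedDAG._load_operator_extra_links = True  # restore the default afterwards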
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d0cb1a6..f59ce6a 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -100,6 +100,11 @@ class BaseSerialization:
_json_schema: Optional[Validator] = None
+ # Should the extra operator link be loaded via plugins when
+ # de-serializing the DAG? This flag is set to False in the Scheduler so that Extra Operator
+ # Links are not loaded, to avoid running user code in the Scheduler.
+ _load_operator_extra_links = True
+
_CONSTRUCTOR_PARAMS: Dict[str, Parameter] = {}
SERIALIZER_VERSION = 1
@@ -407,35 +412,38 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
@classmethod
def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
"""Deserializes an operator from a JSON object."""
- from airflow import plugins_manager
-
- plugins_manager.initialize_extra_operators_links_plugins()
-
- if plugins_manager.operator_extra_links is None:
- raise AirflowException("Can not load plugins")
op = SerializedBaseOperator(task_id=encoded_op['task_id'])
- # Extra Operator Links defined in Plugins
- op_extra_links_from_plugin = {}
-
if "label" not in encoded_op:
# Handle deserialization of old data before the introduction of TaskGroup
encoded_op["label"] = encoded_op["task_id"]
- for ope in plugins_manager.operator_extra_links:
- for operator in ope.operators:
- if (
- operator.__name__ == encoded_op["_task_type"]
- and operator.__module__ == encoded_op["_task_module"]
- ):
- op_extra_links_from_plugin.update({ope.name: ope})
-
- # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
- # set the Operator links attribute
- # The case for "If OperatorLinks are defined in the operator that is being Serialized"
- # is handled in the deserialization loop where it matches k == "_operator_extra_links"
- if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
- setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
+ # Extra Operator Links defined in Plugins
+ op_extra_links_from_plugin = {}
+
+ # We don't want to load Extra Operator links in Scheduler
+ if cls._load_operator_extra_links: # pylint: disable=too-many-nested-blocks
+ from airflow import plugins_manager
+
+ plugins_manager.initialize_extra_operators_links_plugins()
+
+ if plugins_manager.operator_extra_links is None:
+ raise AirflowException("Can not load plugins")
+
+ for ope in plugins_manager.operator_extra_links:
+ for operator in ope.operators:
+ if (
+ operator.__name__ == encoded_op["_task_type"]
+ and operator.__module__ == encoded_op["_task_module"]
+ ):
+ op_extra_links_from_plugin.update({ope.name: ope})
+
+ # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
+ # set the Operator links attribute
+ # The case for "If OperatorLinks are defined in the operator that is being Serialized"
+ # is handled in the deserialization loop where it matches k == "_operator_extra_links"
+ if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
+ setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
for k, v in encoded_op.items():
@@ -450,10 +458,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
elif k.endswith("_date"):
v = cls._deserialize_datetime(v)
elif k == "_operator_extra_links":
- op_predefined_extra_links = cls._deserialize_operator_extra_links(v)
+ if cls._load_operator_extra_links:
+ op_predefined_extra_links = cls._deserialize_operator_extra_links(v)
- # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
- op_predefined_extra_links.update(op_extra_links_from_plugin)
+ # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
+ op_predefined_extra_links.update(op_extra_links_from_plugin)
+ else:
+ op_predefined_extra_links = {}
v = list(op_predefined_extra_links.values())
k = "operator_extra_links"
@@ -655,6 +666,9 @@ class SerializedDAG(DAG, BaseSerialization):
if k == "_downstream_task_ids":
v = set(v)
elif k == "tasks":
+ # pylint: disable=protected-access
+ SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
+ # pylint: enable=protected-access
v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v}
k = "task_dict"
elif k == "timezone":
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 595da25..884280c 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -68,6 +68,7 @@ from tests.test_utils.db import (
set_default_pool_slots,
)
from tests.test_utils.mock_executor import MockExecutor
+from tests.test_utils.mock_operators import CustomOperator
ROOT_FOLDER = os.path.realpath(
os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
@@ -3560,6 +3561,42 @@ class TestSchedulerJob(unittest.TestCase):
assert dag.get_last_dagrun().creating_job_id == scheduler.id
+ def test_extra_operator_links_not_loaded_in_scheduler_loop(self):
+ """
+ Test that Operator Links are not loaded inside the scheduling loop (which does not include
+ DagFileProcessorProcess), especially the critical loop of the Scheduler.
+
+ This avoids running user code in the Scheduler and prevents deadlocks.
+ """
+ dag = DAG(dag_id='test_extra_operator_links_not_loaded_in_scheduler', start_date=DEFAULT_DATE)
+
+ # This CustomOperator has Extra Operator Links registered via plugins
+ _ = CustomOperator(task_id='custom_task', dag=dag)
+
+ dagbag = DagBag(
+ dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
+ include_examples=False,
+ read_dags_from_db=True,
+ )
+ dagbag.bag_dag(dag=dag, root_dag=dag)
+ dagbag.sync_to_db()
+
+ # Get serialized dag
+ s_dag_1 = dagbag.get_dag(dag.dag_id)
+ custom_task = s_dag_1.task_dict['custom_task']
+ # Test that custom_task has >= 1 Operator Links (after de-serialization)
+ assert custom_task.operator_extra_links
+
+ scheduler = SchedulerJob(executor=self.null_exec)
+ scheduler.processor_agent = mock.MagicMock()
+ scheduler._run_scheduler_loop()
+
+ # Get serialized dag
+ s_dag_2 = scheduler.dagbag.get_dag(dag.dag_id)
+ custom_task = s_dag_2.task_dict['custom_task']
+ # Test that custom_task has no Operator Links (after de-serialization) in the Scheduling Loop
+ assert not custom_task.operator_extra_links
+
def test_scheduler_create_dag_runs_does_not_raise_error(self):
"""
Test that scheduler._create_dag_runs does not raise an error when the DAG does not exist