You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/06/11 20:42:37 UTC

[airflow] branch v1-10-test updated: Validate only task commands are run by executors (#9178)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-test by this push:
     new 12a7822  Validate only task commands are run by executors (#9178)
12a7822 is described below

commit 12a7822a779ba4cfd6311d235446d74da9e9e985
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Mon Jun 8 17:30:33 2020 +0100

    Validate only task commands are run by executors (#9178)
    
    (cherry-picked from 6943b171da6537ad6721cc7527b24236f901ee04)
---
 airflow/executors/celery_executor.py     |  2 ++
 airflow/executors/dask_executor.py       |  3 +++
 airflow/executors/kubernetes_executor.py |  6 ++++++
 airflow/executors/local_executor.py      |  2 ++
 airflow/executors/sequential_executor.py |  2 ++
 tests/executors/test_celery_executor.py  | 31 +++++++++++++++++++++++++++----
 6 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 42bf611..a8775c4 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -62,6 +62,8 @@ app = Celery(
 @app.task
 def execute_command(command_to_exec):
     log = LoggingMixin().log
+    if command_to_exec[0:2] != ["airflow", "run"]:
+        raise ValueError('The command must start with ["airflow", "run"].')
     log.info("Executing command in Celery: %s", command_to_exec)
     env = os.environ.copy()
     try:
diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py
index d322f34..7fb8f04 100644
--- a/airflow/executors/dask_executor.py
+++ b/airflow/executors/dask_executor.py
@@ -64,6 +64,9 @@ class DaskExecutor(BaseExecutor):
                 'All tasks will be run in the same cluster'
             )
 
+        if command[0:2] != ["airflow", "run"]:
+            raise ValueError('The command must start with ["airflow", "run"].')
+
         def airflow_run():
             return subprocess.check_call(command, close_fds=True)
 
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index cf40e43..b62462f 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -395,6 +395,12 @@ class AirflowKubernetesScheduler(LoggingMixin):
         key, command, kube_executor_config = next_job
         dag_id, task_id, execution_date, try_number = key
 
+        if isinstance(command, str):
+            command = [command]
+
+        if command[0] != "airflow":
+            raise ValueError('The first element of command must be equal to "airflow".')
+
         config_pod = self.worker_configuration.make_pod(
             namespace=self.namespace,
             worker_uuid=self.worker_uuid,
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index 2c3ba40..2e06f4f 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -225,6 +225,8 @@ class LocalExecutor(BaseExecutor):
         self.impl.start()
 
     def execute_async(self, key, command, queue=None, executor_config=None):
+        if command[0:2] != ["airflow", "run"]:
+            raise ValueError('The command must start with ["airflow", "run"].')
         self.impl.execute_async(key=key, command=command)
 
     def sync(self):
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 1542e33..a0013e6 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -38,6 +38,8 @@ class SequentialExecutor(BaseExecutor):
         self.commands_to_run = []
 
     def execute_async(self, key, command, queue=None, executor_config=None):
+        if command[0:2] != ["airflow", "run"]:
+            raise ValueError('The command must start with ["airflow", "run"].')
         self.commands_to_run.append((key, command,))
 
     def sync(self):
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index d9a15c7..0a2776e 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -34,6 +34,7 @@ from kombu.asynchronous import set_event_loop
 from parameterized import parameterized
 
 from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.executors import celery_executor
 from airflow.utils.state import State
 
@@ -84,14 +85,18 @@ class TestCeleryExecutor(unittest.TestCase):
     @pytest.mark.integration("rabbitmq")
     @pytest.mark.backend("mysql", "postgres")
     def test_celery_integration(self, broker_url):
-        with self._prepare_app(broker_url) as app:
+        success_command = ['airflow', 'run', 'true', 'some_parameter']
+        fail_command = ['airflow', 'version']
+
+        def fake_execute_command(command):
+            if command != success_command:
+                raise AirflowException("fail")
+
+        with self._prepare_app(broker_url, execute=fake_execute_command) as app:
             executor = celery_executor.CeleryExecutor()
             executor.start()
 
             with start_worker(app=app, logfile=sys.stdout, loglevel='debug'):
-                success_command = ['true', 'some_parameter']
-                fail_command = ['false', 'some_parameter']
-
                 cached_celery_backend = celery_executor.execute_command.backend
                 task_tuples_to_send = [('success', 'fake_simple_ti', success_command,
                                         celery_executor.celery_configuration['task_default_queue'],
@@ -184,6 +189,24 @@ class TestCeleryExecutor(unittest.TestCase):
                  mock.call('executor.running_tasks', mock.ANY)]
         mock_stats_gauge.assert_has_calls(calls)
 
+    @parameterized.expand((
+        [['true'], ValueError],
+        [['airflow', 'version'], ValueError],
+        [['airflow', 'run'], None]
+    ))
+    @mock.patch('subprocess.check_call')
+    def test_command_validation(self, command, expected_exception, mock_check_call):
+        # Check that we validate _on the receiving_ side, not just sending side
+        if expected_exception:
+            with pytest.raises(expected_exception):
+                celery_executor.execute_command(command)
+            mock_check_call.assert_not_called()
+        else:
+            celery_executor.execute_command(command)
+            mock_check_call.assert_called_once_with(
+                command, stderr=mock.ANY, close_fds=mock.ANY, env=mock.ANY,
+            )
+
 
 def test_operation_timeout_config():
     assert celery_executor.OPERATION_TIMEOUT == 2