You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/06/11 20:42:37 UTC
[airflow] branch v1-10-test updated: Validate only task commands are run by executors (#9178)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-test by this push:
new 12a7822 Validate only task commands are run by executors (#9178)
12a7822 is described below
commit 12a7822a779ba4cfd6311d235446d74da9e9e985
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Mon Jun 8 17:30:33 2020 +0100
Validate only task commands are run by executors (#9178)
(cherry-picked from 6943b171da6537ad6721cc7527b24236f901ee04)
---
airflow/executors/celery_executor.py | 2 ++
airflow/executors/dask_executor.py | 3 +++
airflow/executors/kubernetes_executor.py | 6 ++++++
airflow/executors/local_executor.py | 2 ++
airflow/executors/sequential_executor.py | 2 ++
tests/executors/test_celery_executor.py | 31 +++++++++++++++++++++++++++----
6 files changed, 42 insertions(+), 4 deletions(-)
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 42bf611..a8775c4 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -62,6 +62,8 @@ app = Celery(
@app.task
def execute_command(command_to_exec):
log = LoggingMixin().log
+ if command_to_exec[0:2] != ["airflow", "run"]:
+ raise ValueError('The command must start with ["airflow", "run"].')
log.info("Executing command in Celery: %s", command_to_exec)
env = os.environ.copy()
try:
diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py
index d322f34..7fb8f04 100644
--- a/airflow/executors/dask_executor.py
+++ b/airflow/executors/dask_executor.py
@@ -64,6 +64,9 @@ class DaskExecutor(BaseExecutor):
'All tasks will be run in the same cluster'
)
+ if command[0:2] != ["airflow", "run"]:
+ raise ValueError('The command must start with ["airflow", "run"].')
+
def airflow_run():
return subprocess.check_call(command, close_fds=True)
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index cf40e43..b62462f 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -395,6 +395,12 @@ class AirflowKubernetesScheduler(LoggingMixin):
key, command, kube_executor_config = next_job
dag_id, task_id, execution_date, try_number = key
+ if isinstance(command, str):
+ command = [command]
+
+ if command[0] != "airflow":
+ raise ValueError('The first element of command must be equal to "airflow".')
+
config_pod = self.worker_configuration.make_pod(
namespace=self.namespace,
worker_uuid=self.worker_uuid,
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index 2c3ba40..2e06f4f 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -225,6 +225,8 @@ class LocalExecutor(BaseExecutor):
self.impl.start()
def execute_async(self, key, command, queue=None, executor_config=None):
+ if command[0:2] != ["airflow", "run"]:
+ raise ValueError('The command must start with ["airflow", "run"].')
self.impl.execute_async(key=key, command=command)
def sync(self):
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 1542e33..a0013e6 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -38,6 +38,8 @@ class SequentialExecutor(BaseExecutor):
self.commands_to_run = []
def execute_async(self, key, command, queue=None, executor_config=None):
+ if command[0:2] != ["airflow", "run"]:
+ raise ValueError('The command must start with ["airflow", "run"].')
self.commands_to_run.append((key, command,))
def sync(self):
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index d9a15c7..0a2776e 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -34,6 +34,7 @@ from kombu.asynchronous import set_event_loop
from parameterized import parameterized
from airflow.configuration import conf
+from airflow.exceptions import AirflowException
from airflow.executors import celery_executor
from airflow.utils.state import State
@@ -84,14 +85,18 @@ class TestCeleryExecutor(unittest.TestCase):
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_celery_integration(self, broker_url):
- with self._prepare_app(broker_url) as app:
+ success_command = ['airflow', 'run', 'true', 'some_parameter']
+ fail_command = ['airflow', 'version']
+
+ def fake_execute_command(command):
+ if command != success_command:
+ raise AirflowException("fail")
+
+ with self._prepare_app(broker_url, execute=fake_execute_command) as app:
executor = celery_executor.CeleryExecutor()
executor.start()
with start_worker(app=app, logfile=sys.stdout, loglevel='debug'):
- success_command = ['true', 'some_parameter']
- fail_command = ['false', 'some_parameter']
-
cached_celery_backend = celery_executor.execute_command.backend
task_tuples_to_send = [('success', 'fake_simple_ti', success_command,
celery_executor.celery_configuration['task_default_queue'],
@@ -184,6 +189,24 @@ class TestCeleryExecutor(unittest.TestCase):
mock.call('executor.running_tasks', mock.ANY)]
mock_stats_gauge.assert_has_calls(calls)
+ @parameterized.expand((
+ [['true'], ValueError],
+ [['airflow', 'version'], ValueError],
+ [['airflow', 'run'], None]
+ ))
+ @mock.patch('subprocess.check_call')
+ def test_command_validation(self, command, expected_exception, mock_check_call):
+ # Check that we validate _on the receiving_ side, not just sending side
+ if expected_exception:
+ with pytest.raises(expected_exception):
+ celery_executor.execute_command(command)
+ mock_check_call.assert_not_called()
+ else:
+ celery_executor.execute_command(command)
+ mock_check_call.assert_called_once_with(
+ command, stderr=mock.ANY, close_fds=mock.ANY, env=mock.ANY,
+ )
+
def test_operation_timeout_config():
assert celery_executor.OPERATION_TIMEOUT == 2