You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ta...@apache.org on 2023/02/18 17:42:19 UTC
[airflow] branch main updated: Fix @task.kubernetes to receive input and send output (#28942)
This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9a5c3e0ac0 Fix @task.kubernetes to receive input and send output (#28942)
9a5c3e0ac0 is described below
commit 9a5c3e0ac0b682d7f2c51727a56e06d68bc9f6be
Author: Victor Chiapaikeo <vc...@gmail.com>
AuthorDate: Sat Feb 18 12:42:11 2023 -0500
Fix @task.kubernetes to receive input and send output (#28942)
* Fix @task.kubernetes to receive input and send output
* Pickle input and rm unnecessary env vars
* Back to env vars and make cmds easier to read
* Remove check for op_args and op_kwargs on input write
---
.../cncf/kubernetes/decorators/kubernetes.py | 62 ++++++++++++----
.../kubernetes/python_kubernetes_script.jinja2 | 6 ++
.../cncf/kubernetes/decorators/test_kubernetes.py | 82 ++++++++++++++++++++--
3 files changed, 130 insertions(+), 20 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
index f68927c676..844d5300f5 100644
--- a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
@@ -16,14 +16,17 @@
# under the License.
from __future__ import annotations
+import base64
import inspect
import os
import pickle
import uuid
+from shlex import quote
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Sequence
+import dill
from kubernetes.client import models as k8s
from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
@@ -37,21 +40,20 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
+_PYTHON_INPUT_ENV = "__PYTHON_INPUT"
-_FILENAME_IN_CONTAINER = "/tmp/script.py"
-
-def _generate_decode_command() -> str:
+def _generate_decoded_command(env_var: str, file: str) -> str:
return (
f'python -c "import base64, os;'
- rf"x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];"
- rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"'
+ rf"x = base64.b64decode(os.environ[\"{env_var}\"]);"
+ rf'f = open(\"{file}\", \"wb\"); f.write(x); f.close()"'
)
-def _read_file_contents(filename):
- with open(filename) as script_file:
- return script_file.read()
+def _read_file_contents(filename: str) -> str:
+ with open(filename, "rb") as script_file:
+ return base64.b64encode(script_file.read()).decode("utf-8")
class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
@@ -62,17 +64,16 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
{"op_args", "op_kwargs", *KubernetesPodOperator.template_fields} - {"cmds", "arguments"}
)
- # since we won't mutate the arguments, we should just do the shallow copy
+ # Since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
shallow_copy_attrs: Sequence[str] = ("python_callable",)
- def __init__(self, namespace: str = "default", **kwargs) -> None:
- self.pickling_library = pickle
+ def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None:
+ self.pickling_library = dill if use_dill else pickle
super().__init__(
namespace=namespace,
name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
- cmds=["bash"],
- arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"],
+ cmds=["dummy-command"],
**kwargs,
)
@@ -82,11 +83,41 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
res = remove_task_decorator(res, "@task.kubernetes")
return res
+ def _generate_cmds(self) -> list[str]:
+ script_filename = "/tmp/script.py"
+ input_filename = "/tmp/script.in"
+ output_filename = "/airflow/xcom/return.json"
+
+ write_local_script_file_cmd = (
+ f"{_generate_decoded_command(quote(_PYTHON_SCRIPT_ENV), quote(script_filename))}"
+ )
+ write_local_input_file_cmd = (
+ f"{_generate_decoded_command(quote(_PYTHON_INPUT_ENV), quote(input_filename))}"
+ )
+ make_xcom_dir_cmd = "mkdir -p /airflow/xcom"
+ exec_python_cmd = f"python {script_filename} {input_filename} {output_filename}"
+ return [
+ "bash",
+ "-cx",
+ " && ".join(
+ [
+ write_local_script_file_cmd,
+ write_local_input_file_cmd,
+ make_xcom_dir_cmd,
+ exec_python_cmd,
+ ]
+ ),
+ ]
+
def execute(self, context: Context):
with TemporaryDirectory(prefix="venv") as tmp_dir:
script_filename = os.path.join(tmp_dir, "script.py")
- py_source = self._get_python_source()
+ input_filename = os.path.join(tmp_dir, "script.in")
+
+ with open(input_filename, "wb") as file:
+ self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file)
+ py_source = self._get_python_source()
jinja_context = {
"op_args": self.op_args,
"op_kwargs": self.op_kwargs,
@@ -100,7 +131,10 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
self.env_vars = [
*self.env_vars,
k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)),
+ k8s.V1EnvVar(name=_PYTHON_INPUT_ENV, value=_read_file_contents(input_filename)),
]
+
+ self.cmds = self._generate_cmds()
return super().execute(context)
diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
index c961f10de4..4042c07fc4 100644
--- a/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
+++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
@@ -17,6 +17,7 @@
under the License.
-#}
+import json
import {{ pickling_library }}
import sys
@@ -42,3 +43,8 @@ arg_dict = {"args": [], "kwargs": {}}
# Script
{{ python_callable_source }}
res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
+
+# Write output
+with open(sys.argv[2], "w") as file:
+ if res is not None:
+ file.write(json.dumps(res))
diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
index 46b087688c..584df8de03 100644
--- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations
+import base64
+import pickle
from unittest import mock
import pytest
@@ -29,6 +31,8 @@ KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod"
POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"
+XCOM_IMAGE = "XCOM_IMAGE"
+
@pytest.fixture(autouse=True)
def mock_create_pod() -> mock.Mock:
@@ -40,6 +44,18 @@ def mock_await_pod_start() -> mock.Mock:
return mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()
+@pytest.fixture(autouse=True)
+def await_xcom_sidecar_container_start() -> mock.Mock:
+ return mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start").start()
+
+
+@pytest.fixture(autouse=True)
+def extract_xcom() -> mock.Mock:
+ f = mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom").start()
+ f.return_value = '{"key1": "value1", "key2": "value2"}'
+ return f
+
+
@pytest.fixture(autouse=True)
def mock_await_pod_completion() -> mock.Mock:
f = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion").start()
@@ -81,11 +97,65 @@ def test_basic_kubernetes(dag_maker, session, mock_create_pod: mock.Mock, mock_h
containers = mock_create_pod.call_args[1]["pod"].spec.containers
assert len(containers) == 1
- assert containers[0].command == ["bash"]
+ assert containers[0].command[0] == "bash"
+ assert len(containers[0].args) == 0
+ assert containers[0].env[0].name == "__PYTHON_SCRIPT"
+ assert containers[0].env[0].value
+ assert containers[0].env[1].name == "__PYTHON_INPUT"
+
+ # Ensure we pass input through a b64 encoded env var
+ decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
+ assert decoded_input == {"args": [], "kwargs": {}}
+
+
+def test_kubernetes_with_input_output(
+ dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
+) -> None:
+ with dag_maker(session=session) as dag:
+
+ @task.kubernetes(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ )
+ def f(arg1, arg2, kwarg1=None, kwarg2=None):
+ return {"key1": "value1", "key2": "value2"}
+
+ f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
+
+ dr = dag_maker.create_dagrun()
+ (ti,) = dr.task_instances
+
+ mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE
+
+ dag.get_task("my_task_id").execute(context=ti.get_template_context(session=session))
+
+ mock_hook.assert_called_once_with(
+ conn_id=None,
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ )
+ assert mock_create_pod.call_count == 1
+ assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
+
+ containers = mock_create_pod.call_args[1]["pod"].spec.containers
+
+ # First container is Python script
+ assert len(containers) == 2
+ assert containers[0].command[0] == "bash"
+ assert len(containers[0].args) == 0
+
+ assert containers[0].env[0].name == "__PYTHON_SCRIPT"
+ assert containers[0].env[0].value
+ assert containers[0].env[1].name == "__PYTHON_INPUT"
+ assert containers[0].env[1].value
- assert len(containers[0].args) == 2
- assert containers[0].args[0] == "-cx"
- assert containers[0].args[1].endswith("/tmp/script.py")
+ # Ensure we pass input through a b64 encoded env var
+ decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
+ assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1": "kwarg1"}}
- assert containers[0].env[-1].name == "__PYTHON_SCRIPT"
- assert containers[0].env[-1].value
+ # Second container is xcom image
+ assert containers[1].image == XCOM_IMAGE
+ assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"