You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ta...@apache.org on 2023/02/18 17:42:19 UTC

[airflow] branch main updated: Fix @task.kubernetes to receive input and send output (#28942)

This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a5c3e0ac0 Fix @task.kubernetes to receive input and send output (#28942)
9a5c3e0ac0 is described below

commit 9a5c3e0ac0b682d7f2c51727a56e06d68bc9f6be
Author: Victor Chiapaikeo <vc...@gmail.com>
AuthorDate: Sat Feb 18 12:42:11 2023 -0500

    Fix @task.kubernetes to receive input and send output (#28942)
    
    * Fix @task.kubernetes to receive input and send output
    
    * Pickle input and rm unnecessary env vars
    
    * Back to env vars and make cmds easier to read
    
    * Remove check for op_args and op_kwargs on input write
---
 .../cncf/kubernetes/decorators/kubernetes.py       | 62 ++++++++++++----
 .../kubernetes/python_kubernetes_script.jinja2     |  6 ++
 .../cncf/kubernetes/decorators/test_kubernetes.py  | 82 ++++++++++++++++++++--
 3 files changed, 130 insertions(+), 20 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
index f68927c676..844d5300f5 100644
--- a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
@@ -16,14 +16,17 @@
 # under the License.
 from __future__ import annotations
 
+import base64
 import inspect
 import os
 import pickle
 import uuid
+from shlex import quote
 from tempfile import TemporaryDirectory
 from textwrap import dedent
 from typing import TYPE_CHECKING, Callable, Sequence
 
+import dill
 from kubernetes.client import models as k8s
 
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
@@ -37,21 +40,20 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 _PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
+_PYTHON_INPUT_ENV = "__PYTHON_INPUT"
 
-_FILENAME_IN_CONTAINER = "/tmp/script.py"
 
-
-def _generate_decode_command() -> str:
+def _generate_decoded_command(env_var: str, file: str) -> str:
     return (
         f'python -c "import base64, os;'
-        rf"x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];"
-        rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"'
+        rf"x = base64.b64decode(os.environ[\"{env_var}\"]);"
+        rf'f = open(\"{file}\", \"wb\"); f.write(x); f.close()"'
     )
 
 
-def _read_file_contents(filename):
-    with open(filename) as script_file:
-        return script_file.read()
+def _read_file_contents(filename: str) -> str:
+    with open(filename, "rb") as script_file:
+        return base64.b64encode(script_file.read()).decode("utf-8")
 
 
 class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
@@ -62,17 +64,16 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
         {"op_args", "op_kwargs", *KubernetesPodOperator.template_fields} - {"cmds", "arguments"}
     )
 
-    # since we won't mutate the arguments, we should just do the shallow copy
+    # Since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
     shallow_copy_attrs: Sequence[str] = ("python_callable",)
 
-    def __init__(self, namespace: str = "default", **kwargs) -> None:
-        self.pickling_library = pickle
+    def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None:
+        self.pickling_library = dill if use_dill else pickle
         super().__init__(
             namespace=namespace,
             name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
-            cmds=["bash"],
-            arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"],
+            cmds=["dummy-command"],
             **kwargs,
         )
 
@@ -82,11 +83,41 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
         res = remove_task_decorator(res, "@task.kubernetes")
         return res
 
+    def _generate_cmds(self) -> list[str]:
+        script_filename = "/tmp/script.py"
+        input_filename = "/tmp/script.in"
+        output_filename = "/airflow/xcom/return.json"
+
+        write_local_script_file_cmd = (
+            f"{_generate_decoded_command(quote(_PYTHON_SCRIPT_ENV), quote(script_filename))}"
+        )
+        write_local_input_file_cmd = (
+            f"{_generate_decoded_command(quote(_PYTHON_INPUT_ENV), quote(input_filename))}"
+        )
+        make_xcom_dir_cmd = "mkdir -p /airflow/xcom"
+        exec_python_cmd = f"python {script_filename} {input_filename} {output_filename}"
+        return [
+            "bash",
+            "-cx",
+            " && ".join(
+                [
+                    write_local_script_file_cmd,
+                    write_local_input_file_cmd,
+                    make_xcom_dir_cmd,
+                    exec_python_cmd,
+                ]
+            ),
+        ]
+
     def execute(self, context: Context):
         with TemporaryDirectory(prefix="venv") as tmp_dir:
             script_filename = os.path.join(tmp_dir, "script.py")
-            py_source = self._get_python_source()
+            input_filename = os.path.join(tmp_dir, "script.in")
+
+            with open(input_filename, "wb") as file:
+                self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file)
 
+            py_source = self._get_python_source()
             jinja_context = {
                 "op_args": self.op_args,
                 "op_kwargs": self.op_kwargs,
@@ -100,7 +131,10 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
             self.env_vars = [
                 *self.env_vars,
                 k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)),
+                k8s.V1EnvVar(name=_PYTHON_INPUT_ENV, value=_read_file_contents(input_filename)),
             ]
+
+            self.cmds = self._generate_cmds()
             return super().execute(context)
 
 
diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
index c961f10de4..4042c07fc4 100644
--- a/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
+++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
@@ -17,6 +17,7 @@
  under the License.
 -#}
 
+import json
 import {{ pickling_library }}
 import sys
 
@@ -42,3 +43,8 @@ arg_dict = {"args": [], "kwargs": {}}
 # Script
 {{ python_callable_source }}
 res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
+
+# Write output
+with open(sys.argv[2], "w") as file:
+    if res is not None:
+        file.write(json.dumps(res))
diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
index 46b087688c..584df8de03 100644
--- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations
 
+import base64
+import pickle
 from unittest import mock
 
 import pytest
@@ -29,6 +31,8 @@ KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod"
 POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
 HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"
 
+XCOM_IMAGE = "XCOM_IMAGE"
+
 
 @pytest.fixture(autouse=True)
 def mock_create_pod() -> mock.Mock:
@@ -40,6 +44,18 @@ def mock_await_pod_start() -> mock.Mock:
     return mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()
 
 
+@pytest.fixture(autouse=True)
+def await_xcom_sidecar_container_start() -> mock.Mock:
+    return mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start").start()
+
+
+@pytest.fixture(autouse=True)
+def extract_xcom() -> mock.Mock:
+    f = mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom").start()
+    f.return_value = '{"key1": "value1", "key2": "value2"}'
+    return f
+
+
 @pytest.fixture(autouse=True)
 def mock_await_pod_completion() -> mock.Mock:
     f = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion").start()
@@ -81,11 +97,65 @@ def test_basic_kubernetes(dag_maker, session, mock_create_pod: mock.Mock, mock_h
 
     containers = mock_create_pod.call_args[1]["pod"].spec.containers
     assert len(containers) == 1
-    assert containers[0].command == ["bash"]
+    assert containers[0].command[0] == "bash"
+    assert len(containers[0].args) == 0
+    assert containers[0].env[0].name == "__PYTHON_SCRIPT"
+    assert containers[0].env[0].value
+    assert containers[0].env[1].name == "__PYTHON_INPUT"
+
+    # Ensure we pass input through a b64 encoded env var
+    decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
+    assert decoded_input == {"args": [], "kwargs": {}}
+
+
+def test_kubernetes_with_input_output(
+    dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
+) -> None:
+    with dag_maker(session=session) as dag:
+
+        @task.kubernetes(
+            image="python:3.10-slim-buster",
+            in_cluster=False,
+            cluster_context="default",
+            config_file="/tmp/fake_file",
+        )
+        def f(arg1, arg2, kwarg1=None, kwarg2=None):
+            return {"key1": "value1", "key2": "value2"}
+
+        f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
+
+    dr = dag_maker.create_dagrun()
+    (ti,) = dr.task_instances
+
+    mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE
+
+    dag.get_task("my_task_id").execute(context=ti.get_template_context(session=session))
+
+    mock_hook.assert_called_once_with(
+        conn_id=None,
+        in_cluster=False,
+        cluster_context="default",
+        config_file="/tmp/fake_file",
+    )
+    assert mock_create_pod.call_count == 1
+    assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
+
+    containers = mock_create_pod.call_args[1]["pod"].spec.containers
+
+    # First container is Python script
+    assert len(containers) == 2
+    assert containers[0].command[0] == "bash"
+    assert len(containers[0].args) == 0
+
+    assert containers[0].env[0].name == "__PYTHON_SCRIPT"
+    assert containers[0].env[0].value
+    assert containers[0].env[1].name == "__PYTHON_INPUT"
+    assert containers[0].env[1].value
 
-    assert len(containers[0].args) == 2
-    assert containers[0].args[0] == "-cx"
-    assert containers[0].args[1].endswith("/tmp/script.py")
+    # Ensure we pass input through a b64 encoded env var
+    decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
+    assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1": "kwarg1"}}
 
-    assert containers[0].env[-1].name == "__PYTHON_SCRIPT"
-    assert containers[0].env[-1].value
+    # Second container is xcom image
+    assert containers[1].image == XCOM_IMAGE
+    assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"