You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by "ASF GitHub Bot (JIRA)" <ji...@apache.org> on 2018/12/18 04:22:01 UTC

[jira] [Commented] (AIRFLOW-1310) Kubernetes execute operator

    [ https://issues.apache.org/jira/browse/AIRFLOW-1310?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16723665#comment-16723665 ] 

ASF GitHub Bot commented on AIRFLOW-1310:
-----------------------------------------

stale[bot] closed pull request #2456: [AIRFLOW-1310] Basic operator to run docker container on Kubernetes
URL: https://github.com/apache/incubator-airflow/pull/2456
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/contrib/hooks/kubernetes_hook.py b/airflow/contrib/hooks/kubernetes_hook.py
new file mode 100644
index 0000000000..fbabf81966
--- /dev/null
+++ b/airflow/contrib/hooks/kubernetes_hook.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import requests
+import json
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+
+from kubernetes import client, config
+
+class KubernetesHook(BaseHook):
+    """
+    Kubernetes interaction hook
+
+    :param k8s_conn_id: reference to a pre-defined K8s Connection
+    :type k8s_conn_id: string
+    """
+
+    def __init__(self, k8s_conn_id="k8s_default"):
+        self.conn_id = k8s_conn_id
+        self.core_client = None
+
+    def get_conn(self):
+        """
+        Initializes the api client. Only config file or env
+        configuration supported at the moment.
+        """
+        if not self.core_client:
+            config.load_kube_config()
+            self.core_client = client.CoreV1Api()
+
+        return self.core_client
+
+    def get_env_definitions(self, env):
+        def get_env(name, definition):
+            if isinstance(definition, str):
+                return client.V1EnvVar(name=name, value=definition)
+            elif isinstance(definition, dict):
+                source = definition['source']
+                if source == 'configMap':
+                    return client.V1EnvVar(name=name,
+                            value_from=client.V1EnvVarSource(
+                                config_map_key_ref=client.V1ConfigMapKeySelector(
+                                    key=definition['key'], name=definition['name'])))
+                elif source == 'secret':
+                    return client.V1EnvVar(name=name,
+                            value_from=client.V1EnvVarSource(
+                                secret_key_ref=client.V1SecretKeySelector(
+                                    key=definition['key'], name=definition['name'])))
+                else:
+                    raise AirflowException('Creating env vars from %s not implemented',
+                            source)
+            else:
+                raise AirflowException('Environment variable definition \
+                        has to be either string or a dictionary. %s given instead',
+                        type(definition))
+
+        return [get_env(name, definition) for name, definition in env.items()]
+
+    def get_env_from_definitions(self, env_from):
+        def get_env_from(definition):
+            configmap = definition.get('configMap')
+            secret = definition.get('secret')
+            prefix = definition.get('prefix')
+
+            cfg_ref = client.V1ConfigMapEnvSource(name=configmap) if configmap else None
+            secret_ref = client.V1SecretEnvSource(name=secret) if secret else None
+
+            return client.V1EnvFromSource(
+                config_map_ref=cfg_ref,
+                secret_ref=secret_ref,
+                prefix=prefix
+            )
+        return [get_env_from(definition) for definition in env_from]
+
+    def get_volume_definitions(self, volumes):
+        def get_volume(name, definition):
+            if definition['type'] == 'emptyDir':
+                volume = client.V1Volume(
+                    name=name,
+                    empty_dir=client.V1EmptyDirVolumeSource()
+                )
+                volume_mount = client.V1VolumeMount(
+                    mount_path=definition['mountPath'],
+                    name=name
+                )
+            elif definition['type'] == 'hostPath':
+                volume = client.V1Volume(
+                    name=name,
+                    host_path=client.V1HostPathVolumeSource(
+                        path=definition['path']
+                    )
+                )
+                volume_mount = client.V1VolumeMount(
+                    mount_path=definition['mountPath'],
+                    name=name
+                )
+            elif definition['type'] == 'secret':
+                volume = client.V1Volume(
+                    name=name,
+                    secret=client.V1SecretVolumeSource(
+                        secret_name=definition['secret']
+                    )
+                )
+                volume_mount = client.V1VolumeMount(
+                    mount_path=definition['mountPath'],
+                    name=name
+                )
+            else:
+               raise AirflowException('Volume source %s not implemented',
+                    definition['type'])
+
+            return (volume, volume_mount)
+
+        [volume_defs, volume_mount_defs] = \
+                zip(*[get_volume(name, definition) for name, definition in volumes.items()])
+        return (list(volume_defs), list(volume_mount_defs))
+
+    def get_pod_definition(
+            self,
+            image,
+            name,
+            namespace=None,
+            restart_policy="Never",
+            command=None,
+            args=None,
+            env=None,
+            env_from=None,
+            volumes=None,
+            labels=None):
+        """
+            Builds pod definition based on supplied arguments
+        """
+        env_defs = self.get_env_definitions(env) if env else None
+        env_from_defs = self.get_env_from_definitions(env_from) if env_from else None
+        volume_defs, volume_mount_defs = \
+                self.get_volume_definitions(volumes) if volumes else (None, None)
+
+        return client.V1Pod(
+            api_version="v1",
+            kind="Pod",
+            metadata=client.V1ObjectMeta(
+                name=name,
+                namespace=namespace,
+                labels=labels
+            ),
+            spec=client.V1PodSpec(
+                restart_policy=restart_policy,
+                containers=[client.V1Container(
+                    name=name,
+                    command=command,
+                    args=args,
+                    image=image,
+                    env=env_defs,
+                    env_from=env_from_defs,
+                    volume_mounts=volume_mount_defs
+                )],
+                volumes=volume_defs
+            )
+        )
+
+    def create_pod(self, pod):
+        namespace = pod.metadata.namespace
+        self.get_conn().create_namespaced_pod(namespace, pod)
+
+    def delete_pod(
+            self,
+            pod=None,
+            name=None,
+            namespace=None):
+        """
+            Delete a pod based on pod definition or name
+        """
+        if pod:
+            name = pod.metadata.name
+            namespace = pod.metadata.namespace
+        self.get_conn().delete_namespaced_pod(name, namespace, client.V1DeleteOptions())
+
+    def get_pod_state(self, pod=None, name=None, namespace=None):
+        """
+            Fetches pod status and returns phase
+        """
+        if pod:
+            name = pod.metadata.name
+            namespace = pod.metadata.namespace
+
+        pod_status = self.get_conn().read_namespaced_pod_status(name, namespace)
+
+        if not pod_status:
+            raise AirflowException("Cannot find the requested pod!")
+
+        return pod_status.status.phase
+
+    def relay_pod_logs(self, pod=None, name=None, namespace=None):
+        if pod:
+            name = pod.metadata.name
+            namespace = pod.metadata.namespace
+
+        logging.info("Start container log")
+        logging.info("-------------------")
+
+        if not self._stream_log(name, namespace):
+            self._client_log(name, namespace)
+
+    def _get_authorization(self):
+        if client.configuration.api_key['authorization'] is not None:
+            return {'Authorization': client.configuration.api_key['authorization']}
+        else:
+            return None
+
+    def _stream_log(self, name, namespace):
+        """
+            Stream logs for pod.
+            The python-client for kubernetes does not (yet) support iterating over a
+            streaming log.
+
+            Only bearer authenticated requests for now.
+            (Which is enough if running the worker in kubernetes)
+        """
+        headers = self._get_authorization()
+        if not headers:
+            return False
+
+        try:
+            url = "%s/api/v1/namespaces/%s/pods/%s/log" % \
+                    (client.configuration.host, namespace, name)
+            with requests.get(url,
+                             params={'follow':'true'},
+                             verify=client.configuration.ssl_ca_cert,
+                             headers=headers,
+                             stream=True) as r:
+
+                if r.encoding is None:
+                    r.encoding = 'utf-8'
+
+                for line in r.iter_lines(decode_unicode=True):
+                    logging.info(line.strip())
+        except Exception as e:
+            logging.info("Streaming container log terminated unexpectedly: %s", e)
+            return False
+
+        return True
+
+    def _client_log(self, name, namespace):
+        """
+            Fetch log from k8s client.
+            read_namespaced_pod_log with follow=True, only returns once the log is
+            closed.
+        """
+        try:
+            log = self.get_conn().read_namespaced_pod_log(
+                    name,
+                    namespace,
+                    follow=True)
+
+            log_lines = log.rstrip().split("\n")
+            for line in log_lines:
+                logging.info(line.rstrip())
+        except Exception as e:
+            logging.info("Container log from client terminated unexpectedly: %s", e)
+
+    def relay_pod_events(self, pod=None, name=None, namespace=None, timeout=60):
+        """
+            Stream kubernetes events for the pod into logging.info
+
+            Watches the events for the specified pod, until either an event with
+            reason "Started" is encountered, or a timeout is reached. Some events
+            might be missed as the api does not necessarily return
+            events in order, however this should not be a real problem as the value
+            of these are diagnosing startup problems.
+        """
+        if pod:
+            name = pod.metadata.name
+            namespace = pod.metadata.namespace
+
+        headers = self._get_authorization()
+        if not headers:
+            return False
+
+        params = {
+            'fieldSelector': 'involvedObject.name=%s' % name,
+            'watch': 'true',
+            'timeoutSeconds': timeout
+        }
+
+        url = "%s/api/v1/namespaces/%s/events/" % (
+                client.configuration.host,
+                namespace)
+
+        logging.info("Start pod event log")
+
+        try:
+            with requests.get(url,
+                             params=params,
+                             verify=client.configuration.ssl_ca_cert,
+                             headers=headers,
+                             stream=True) as r:
+
+                if r.encoding is None:
+                    r.encoding = 'utf-8'
+
+                for line in r.iter_lines(decode_unicode=True):
+                    data = json.loads(line)
+
+                    ob = data['object']
+
+                    logging.info("event: %s (component: %s, host: %s, reason: %s)",
+                            ob.get('message'),
+                            ob.get('source', {}).get('component'),
+                            ob.get('source', {}).get('host'),
+                            ob.get('reason'),
+                    )
+
+                    if ob.get('reason') == "Started":
+                        break
+
+        except Exception as e:
+            logging.info("Streaming events terminated unexpectedly: %s", e)
+            return False
+
+        return True
diff --git a/airflow/contrib/operators/kubernetes_operator.py b/airflow/contrib/operators/kubernetes_operator.py
new file mode 100644
index 0000000000..bac9163558
--- /dev/null
+++ b/airflow/contrib/operators/kubernetes_operator.py
@@ -0,0 +1,192 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from time import sleep, time
+from random import randint
+
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
+from airflow.contrib.hooks.kubernetes_hook import KubernetesHook
+
+
+class KubernetesPodOperator(BaseOperator):
+    """
+    Deploys docker container to k8s pod and waits for its completion
+
+    :param name: Name of the pod, optional if not given an unique name will
+        be created automatically
+    :type name: string
+    :param namespace: Namespace the pod will be deployed to
+    :type namespace: string
+    :param image: Fully qualified name of the image in form
+        of repo/image:tag
+    :type image: string
+    :param command: Commands to execute in the image,
+        default image command will be executed if none supplied
+    :type command: string or list
+    :param op_args: Arguments for the command
+    :type op_args: list
+    :param wait: Wait for the completion. Default True. If set to false
+        the operator waits for the pod to start running to ensure
+        successful creation.
+    :type wait: boolean
+    :param unique_name: Whether the operator should ensure the uniqueness
+        of the pod's name. Default is true
+    :type unique_name: boolean
+    :param cleanup: Perform cleanup on completion.
+        Allowed values: Always, Never, OnSuccess, OnFailure.
+        Default Always. Settign wait == False forces Never,
+        as cleanup can be only performed on terminated container.
+    :param labels: Labels and presets to apply to the pod.
+    :type labels: dict
+    :param env: Environment variables defintion as a dictionary
+        of a form name:definition, where definition is a string or
+        a dictionary with following fields:
+        source (configMap|secret), name, and key
+    :type env: dict
+    :param conn_id: Id of pre-defined k8s connection. Currently not used,
+        as only preconfigured environment with kube config or env variables
+        is supported.
+    :type conn_id: string
+    :param poke_interval: Interval between checking the status in seconds
+    :type poke_interval: integer
+    :param wait_timeout: Time in seconds to wait for the pod to reach running state.
+    :type wait_timeout: integer
+
+    """
+    template_fields = ('name', 'command', 'op_args', 'namespace', 'env')
+    ui_color = '#f0ede4'
+
+    @apply_defaults
+    def __init__(
+            self,
+            image,
+            name=None,
+            namespace="default",
+            command=None,
+            op_args=None,
+            wait=True,
+            unique_name=True,
+            cleanup="Always",
+            labels=None,
+            env=None,
+            env_from=None,
+            volumes=None,
+            conn_id="k8s_default",
+            poke_interval=3,
+            wait_timeout=60,
+            *args, **kwargs):
+        super(KubernetesPodOperator, self).__init__(*args, **kwargs)
+        self.image = image
+        self.name = name
+        self.namespace = namespace
+        self.command = command
+        self.op_args = op_args
+        self.wait = wait
+        self.unique_name = unique_name
+        self.cleanup = cleanup if self.wait else "Never"
+        self.labels = labels
+        self.env = env
+        self.env_from = env_from
+        self.volumes = volumes
+        self.poke_interval = poke_interval
+        self.wait_timeout = wait_timeout
+        self.conn_id = conn_id
+
+    def _create_hook(self):
+        return KubernetesHook(self.conn_id)
+
+    def _base_name(self, context):
+        if self.name is not None:
+            return self.name
+
+        base_name = "%s-%s" % (context['ti'].dag_id, context['ti'].task_id)
+        r = re.compile('[^a-z0-9-]+')
+        base_name = r.sub('-', base_name)
+        self.name = base_name
+        return base_name
+
+    def _unique_name(self, context):
+        name = self._base_name(context)
+
+        if not self.unique_name:
+            return name
+
+        job_id = context['ti'].job_id
+        if job_id is None:
+            # job_id is None when running "airflow test"
+            job_id = int(time()*1000)
+
+        return "%s-%s" % (name, job_id)
+
+    def should_do_cleanup(self, status):
+        return ((self.cleanup == "Always") or
+                ((self.cleanup == "OnFailure") and (status == "Failed")) or
+                ((self.cleanup == "OnSuccess") and (status == "Succeeded")))
+
+    def execute(self, context):
+        exit_statuses = ["Succeeded"] if self.wait else ["Running", "Succeeded"]
+
+        hook = self._create_hook()
+
+        pod_name = self._unique_name(context)
+        pod = hook.get_pod_definition(
+            image=self.image,
+            name=pod_name,
+            namespace=self.namespace,
+            restart_policy="Never",
+            command=self.command,
+            args=self.op_args,
+            env=self.env,
+            env_from=self.env_from,
+            volumes=self.volumes,
+            labels=self.labels)
+
+        logging.info("Creating pod %s in namespace %s",
+                pod_name, self.namespace)
+
+        logging.debug("Pod definition: %s", pod.spec)
+
+        hook.create_pod(pod)
+        max_wait = time() + self.wait_timeout
+
+        try:
+            hook.relay_pod_events(pod, timeout=self.wait_timeout)
+
+            status = None
+            while hook.get_pod_state(pod) == 'Pending':
+                if time() > max_wait:
+                    raise AirflowException("Timeout while waiting for \
+                            pod to reach Running state.")
+
+                sleep(self.poke_interval)
+
+            if self.wait:
+                hook.relay_pod_logs(pod)
+
+            while not (status in exit_statuses):
+                status = hook.get_pod_state(pod)
+                logging.info("Checking pod status => %s", status)
+
+                if (status == "Failed"):
+                    raise AirflowException("Pod failed!")
+
+                sleep(self.poke_interval)
+        finally:
+            if (self.wait and self.should_do_cleanup(status)):
+                hook.delete_pod(pod)
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index 670335c8b4..cc581f195d 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -49,6 +49,7 @@ ipython
 jaydebeapi
 jinja2<2.9.0
 jira
+kubernetes>=2.0.0
 ldap3
 lxml
 markdown
diff --git a/setup.py b/setup.py
index dedcf76794..f2d470fcad 100644
--- a/setup.py
+++ b/setup.py
@@ -148,6 +148,7 @@ def check_previous():
 hdfs = ['snakebite>=2.7.8']
 webhdfs = ['hdfs[dataframe,avro,kerberos]>=2.0.4']
 jira = ['JIRA>1.0.7']
+kubernetes = ['kubernetes>=2.0.0']
 hive = [
     'hive-thrift-py>=0.0.1',
     'pyhive>=0.1.3',
@@ -202,7 +203,7 @@ def check_previous():
 ]
 devel_minreq = devel + mysql + doc + password + s3 + cgroups
 devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
-devel_all = devel + all_dbs + doc + samba + s3 + slack + crypto + oracle + docker + ssh
+devel_all = devel + all_dbs + doc + samba + s3 + slack + crypto + oracle + docker + kubernetes + ssh
 
 
 def do_setup():
@@ -274,6 +275,7 @@ def do_setup():
             'hive': hive,
             'jdbc': jdbc,
             'kerberos': kerberos,
+            'kubernetes': kubernetes,
             'ldap': ldap,
             'mssql': mssql,
             'mysql': mysql,
diff --git a/tests/contrib/hooks/test_kubernetes_hook.py b/tests/contrib/hooks/test_kubernetes_hook.py
new file mode 100644
index 0000000000..8419464af8
--- /dev/null
+++ b/tests/contrib/hooks/test_kubernetes_hook.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import mock
+
+from airflow import configuration, models
+from airflow.contrib.hooks.kubernetes_hook import KubernetesHook
+from airflow.exceptions import AirflowException
+from airflow.utils import db
+from mock import patch, call, Mock, MagicMock
+
+from io import StringIO
+
+class TestKubernetesHook(unittest.TestCase):
+    def setUp(self):
+        super(TestKubernetesHook, self).setUp()
+
+        self.conn_mock = mock.MagicMock(name='coreV1Api')
+        self.get_conn_orig = KubernetesHook.get_conn
+
+        def _get_conn_mock(hook):
+            hook.core_client = self.conn_mock
+            return self.conn_mock
+
+        KubernetesHook.get_conn = _get_conn_mock
+
+    def tearDown(self):
+        KubernetesHook.get_conn = self.get_conn_orig
+        super(TestKubernetesHook, self).tearDown()
+
+    def test_get_pod_definition(self):
+        hook = KubernetesHook()
+
+        pod = hook.get_pod_definition(name="test_pod", namespace="test_namespace", image="image:tag")
+        self.assertTrue(pod.metadata.name == "test_pod")
+        self.assertTrue(pod.metadata.namespace == "test_namespace")
+        self.assertTrue(pod.metadata.labels is None)
+
+        self.assertTrue(len(pod.spec.containers) == 1)
+        self.assertTrue(pod.spec.containers[0].name == "test_pod")
+        self.assertTrue(pod.spec.containers[0].image == "image:tag")
+        self.assertTrue(pod.spec.containers[0].command is None)
+        self.assertTrue(pod.spec.containers[0].args is None)
+
+    def test_get_pod_definition_with_extras(self):
+        hook = KubernetesHook()
+
+        pod = hook.get_pod_definition(
+                name="test_pod",
+                namespace="test_namespace",
+                image="image:tag",
+                command=["python", "test.py"],
+                args=["arg1", "arg2"],
+                env={
+                    'param1': 'string!',
+                    'param2': {'source': 'configMap', 'name': 'config1', 'key': 'key1'},
+                    'param3': {'source': 'secret', 'name': 'config1', 'key': 'key1'}},
+                labels={"label1": "true", "label2": "true"})
+
+        self.assertTrue(pod.metadata.name == "test_pod")
+        self.assertTrue(pod.metadata.namespace == "test_namespace")
+        self.assertTrue(pod.metadata.labels == {"label1": "true", "label2": "true"})
+
+        self.assertTrue(len(pod.spec.containers) == 1)
+        self.assertTrue(pod.spec.containers[0].name == "test_pod")
+        self.assertTrue(pod.spec.containers[0].image == "image:tag")
+        
+        result = pod.spec.containers[0].env
+        result.sort(key = lambda x: x.name)
+        self.assertTrue(result[0].name == 'param1' and result[0].value == 'string!')
+        self.assertTrue(result[1].name == 'param2' and
+                result[1].value_from.config_map_key_ref.name == 'config1' and
+                result[1].value_from.config_map_key_ref.key == 'key1')
+        self.assertTrue(result[2].name == 'param3' and
+                result[2].value_from.secret_key_ref.name == 'config1' and
+                result[2].value_from.secret_key_ref.key == 'key1')
+
+        self.assertTrue(pod.spec.containers[0].command == ["python", "test.py"])
+        self.assertTrue(pod.spec.containers[0].args == ["arg1", "arg2"])
+
+    def test_create_pod(self):
+        hook = KubernetesHook()
+
+        pod = Mock()
+
+        hook.create_pod(pod)
+        hook.get_conn().create_namespaced_pod.assert_called_once_with(pod.metadata.namespace, pod)
+
+    def test_delete_pod(self):
+        hook = KubernetesHook()
+
+        pod = Mock()
+
+        hook.delete_pod(pod)
+        hook.get_conn().delete_namespaced_pod.assert_called_once_with(pod.metadata.name, pod.metadata.namespace, mock.ANY)
+
+    def test_delete_pod_by_name(self):
+        hook = KubernetesHook()
+
+        hook.delete_pod(name="test_pod", namespace="test_namespace")
+        hook.get_conn().delete_namespaced_pod.assert_called_once_with("test_pod", "test_namespace", mock.ANY)
+
+    def test_get_pod_status_not_found(self):
+        hook = KubernetesHook()
+
+        hook.get_conn().read_namespaced_pod_status.return_value = None
+
+        pod = Mock()
+        with self.assertRaises(AirflowException) as context:
+            hook.get_pod_state(pod)
+
+        hook.get_conn().read_namespaced_pod_status.assert_called_once_with(pod.metadata.name, pod.metadata.namespace)
+
+    def test_get_pod_status(self):
+        hook = KubernetesHook()
+
+        pod_mock = Mock()
+        pod_mock.metadata.name = "test_pod"
+        pod_mock.metadata.namespace = "test_namespace"
+        pod_mock.status.phase = "Running"
+        self.conn_mock.read_namespaced_pod_status.return_value = pod_mock
+
+        state = hook.get_pod_state(pod_mock)
+
+        self.assertTrue(state == "Running")
+        hook.get_conn().read_namespaced_pod_status.assert_called_once_with(pod_mock.metadata.name, pod_mock.metadata.namespace)
+
+    def test_get_pod_status_by_name(self):
+        hook = KubernetesHook()
+
+        pod_mock = Mock()
+        pod_mock.metadata.name = "test_pod"
+        pod_mock.metadata.namespace = "test_namespace"
+        pod_mock.status.phase = "Running"
+        self.conn_mock.read_namespaced_pod_status.return_value = pod_mock
+
+        state = hook.get_pod_state(name=pod_mock.metadata.name, namespace=pod_mock.metadata.namespace)
+
+        self.assertTrue(state == "Running")
+        hook.get_conn().read_namespaced_pod_status.assert_called_once_with(pod_mock.metadata.name, pod_mock.metadata.namespace)
+
diff --git a/tests/contrib/operators/test_kubernetes_operator.py b/tests/contrib/operators/test_kubernetes_operator.py
new file mode 100644
index 0000000000..505cb7db39
--- /dev/null
+++ b/tests/contrib/operators/test_kubernetes_operator.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import datetime
+import sys
+
+import mock
+from mock import MagicMock, Mock
+
+from airflow import DAG, configuration
+from airflow.models import TaskInstance
+
+from airflow.contrib.operators.kubernetes_operator import KubernetesPodOperator
+from airflow.contrib.hooks.kubernetes_hook import KubernetesHook
+from airflow.exceptions import AirflowException
+
+DEFAULT_DATE = datetime.datetime(2017, 1, 1)
+END_DATE = datetime.datetime(2016, 1, 2)
+INTERVAL = datetime.timedelta(hours=12)
+
+class TestKubernetesOperator(unittest.TestCase):
+    def setUp(self):
+        super(TestKubernetesOperator, self).setUp()
+        self._create_hook_orig = KubernetesPodOperator._create_hook
+        
+        self.context_mock = MagicMock()
+        self.context_mock.__getitem__.side_effect = {'ti': type('', (object,), {'job_id': 1234})}.__getitem__
+
+        self.pod_mock = MagicMock()
+        self.pod_mock.metadata.name = "pod_name-1234"
+        self.pod_mock.metadata.namespace = "pod_namespace"
+
+        self.hook_mock = MagicMock(spec=KubernetesHook)
+
+        def _create_hook_mock(sensor):
+            return self.hook_mock
+
+        KubernetesPodOperator._create_hook = _create_hook_mock
+
+    def tearDown(self):
+        TestKubernetesOperator._create_hook = self._create_hook_orig
+        super(TestKubernetesOperator, self).tearDown()
+
+    def test_should_do_cleanup(self):
+        def get_operator(cleanup):
+            return KubernetesPodOperator(
+                    task_id="task",
+                    name=None,
+                    namespace=None,
+                    image=None,
+                    cleanup=cleanup)
+
+        self.assertFalse(get_operator("Never").should_do_cleanup("Succeeded"))
+        self.assertFalse(get_operator("Never").should_do_cleanup("Failed"))
+        self.assertTrue(get_operator("Always").should_do_cleanup("Succeeded"))
+        self.assertTrue(get_operator("Always").should_do_cleanup("Failed"))
+        self.assertTrue(get_operator("OnSuccess").should_do_cleanup("Succeeded"))
+        self.assertFalse(get_operator("OnSuccess").should_do_cleanup("Failed"))
+        self.assertTrue(get_operator("OnFailure").should_do_cleanup("Failed"))
+        self.assertFalse(get_operator("OnFailure").should_do_cleanup("Succeeded"))
+
+    def test_execute_container_fails(self):
+        self.hook_mock.reset_mock()
+        self.context_mock.reset_mock()
+        self.pod_mock.reset_mock()
+
+        self.hook_mock.get_pod_definition.return_value = self.pod_mock
+        self.hook_mock.get_pod_state.side_effect = ["Pending", "Failed", "Failed"]
+
+        operator = KubernetesPodOperator(
+                task_id='k8s_test',
+                name="pod_name",
+                namespace="pod_namespace",
+                image="image:test",
+                poke_interval=0,
+                cleanup="Always")
+
+        with self.assertRaises(AirflowException):
+            operator.execute(self.context_mock)
+
+        self.hook_mock.delete_pod.assert_called_once_with(self.pod_mock)
+
+    def test_execute_container(self):
+        self.hook_mock.reset_mock()
+        self.context_mock.reset_mock()
+        self.pod_mock.reset_mock()
+
+        self.hook_mock.get_pod_definition.return_value = self.pod_mock
+        self.hook_mock.get_pod_state.side_effect = ["Pending", "Running", "Succeeded"]
+
+        operator = KubernetesPodOperator(
+                task_id='k8s_test',
+                name="pod_name",
+                namespace="pod_namespace",
+                image="image:test",
+                poke_interval=0,
+                cleanup="Always")
+
+        operator.execute(self.context_mock)
+
+        self.hook_mock.get_pod_definition.assert_called_once_with(
+                args=None,
+                command=None,
+                image='image:test',
+                labels=None,
+                env=None,
+                env_from=None,
+                name="pod_name-1234",
+                namespace='pod_namespace',
+                restart_policy='Never',
+                volumes=None)
+        self.hook_mock.create_pod.assert_called_once_with(self.pod_mock)
+        self.hook_mock.delete_pod.assert_called_once_with(self.pod_mock)
+        self.assertEqual(self.hook_mock.get_pod_state.call_count, 3)
+
+    def test_execute_container_no_wait(self):
+        self.hook_mock.reset_mock()
+        self.context_mock.reset_mock()
+        self.pod_mock.reset_mock()
+
+        self.hook_mock.get_pod_definition.return_value = self.pod_mock
+        self.hook_mock.get_pod_state.side_effect = ["Pending", "Running", "Succeeded"]
+
+        operator = KubernetesPodOperator(
+                task_id='k8s_test',
+                name="pod_name",
+                namespace="pod_namespace",
+                image="image:test",
+                env={'param1': 'val'},
+                poke_interval=0,
+                cleanup="Never",
+                wait=False)
+
+        operator.execute(self.context_mock)
+
+        self.hook_mock.get_pod_definition.assert_called_once_with(
+                args=None,
+                command=None,
+                image='image:test',
+                env={'param1': 'val'},
+                env_from=None,
+                labels=None,
+                name='pod_name-1234',
+                namespace='pod_namespace',
+                restart_policy='Never',
+                volumes=None)
+        self.hook_mock.create_pod.assert_called_once_with(self.pod_mock)
+        self.assertFalse(self.hook_mock.delete_pod.called)
+        self.hook_mock.get_pod_state.called_twice_with(self.pod_mock)
\ No newline at end of file


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> Kubernetes execute operator
> ---------------------------
>
>                 Key: AIRFLOW-1310
>                 URL: https://issues.apache.org/jira/browse/AIRFLOW-1310
>             Project: Apache Airflow
>          Issue Type: New Feature
>          Components: operators
>            Reporter: Yu Ishikawa
>            Assignee: Dennis Docter
>            Priority: Major
>
> We should support operators for Kubernetes to execute containers remotely.
> - Create a Kubernetes cluster
> - Execute PODs
> - https://github.com/kubernetes-incubator/client-python



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)