Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/18 04:22:43 UTC

[GitHub] stale[bot] closed pull request #3467: [AIRFLOW-2568] Azure Container Instances operator

stale[bot] closed pull request #3467: [AIRFLOW-2568] Azure Container Instances operator
URL: https://github.com/apache/incubator-airflow/pull/3467
 
 
   

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/contrib/hooks/azure_container_hook.py b/airflow/contrib/hooks/azure_container_hook.py
new file mode 100644
index 0000000000..f74e8fc86b
--- /dev/null
+++ b/airflow/contrib/hooks/azure_container_hook.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.exceptions import AirflowException
+
+from azure.common.client_factory import get_client_from_auth_file
+from azure.common.credentials import ServicePrincipalCredentials
+
+from azure.mgmt.containerinstance import ContainerInstanceManagementClient
+from azure.mgmt.containerinstance.models import (ImageRegistryCredential,
+                                                 Volume,
+                                                 AzureFileVolume)
+
+
+class AzureContainerInstanceHook(BaseHook):
+
+    def __init__(self, conn_id='azure_default'):
+        self.conn_id = conn_id
+        self.connection = self.get_conn()
+
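+    # Credential lookup order: a JSON auth file given via the connection's
+    # `key_path` extra, then one pointed at by the AZURE_AUTH_LOCATION
+    # environment variable, and finally a service principal built from the
+    # connection's login/password and its `tenantId` extra.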
+    def get_conn(self):
+        conn = self.get_connection(self.conn_id)
+        key_path = conn.extra_dejson.get('key_path', False)
+        if key_path:
+            if key_path.endswith('.json'):
+                self.log.info('Getting connection using a JSON key file.')
+                return get_client_from_auth_file(ContainerInstanceManagementClient,
+                                                 key_path)
+            else:
+                raise AirflowException('Unrecognised extension for key file.')
+
+        if os.environ.get('AZURE_AUTH_LOCATION'):
+            key_path = os.environ.get('AZURE_AUTH_LOCATION')
+            if key_path.endswith('.json'):
+                self.log.info('Getting connection using a JSON key file.')
+                return get_client_from_auth_file(ContainerInstanceManagementClient,
+                                                 key_path)
+            else:
+                raise AirflowException('Unrecognised extension for key file.')
+
+        credentials = ServicePrincipalCredentials(
+            client_id=conn.login,
+            secret=conn.password,
+            tenant=conn.extra_dejson['tenantId']
+        )
+
+        subscription_id = conn.extra_dejson['subscriptionId']
+        return ContainerInstanceManagementClient(credentials, str(subscription_id))
+
+    def create_or_update(self, resource_group, name, container_group):
+        self.connection.container_groups.create_or_update(resource_group,
+                                                          name,
+                                                          container_group)
+
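+    # The container group is fetched raw so the instanceView JSON of the
+    # first (and only) container can be inspected; exitCode defaults to 0
+    # while the container has not reported a state yet.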
+    def get_state_exitcode(self, resource_group, name):
+        response = self.connection.container_groups.get(resource_group,
+                                                        name,
+                                                        raw=True).response.json()
+        containers = response['properties']['containers']
+        instance_view = containers[0]['properties'].get('instanceView', {})
+        current_state = instance_view.get('currentState', {})
+
+        return current_state.get('state'), current_state.get('exitCode', 0)
+
+    def get_messages(self, resource_group, name):
+        response = self.connection.container_groups.get(resource_group,
+                                                        name,
+                                                        raw=True).response.json()
+        containers = response['properties']['containers']
+        instance_view = containers[0]['properties'].get('instanceView', {})
+
+        return [event['message'] for event in instance_view.get('events', [])]
+
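+    # Fetch up to `tail` lines of output from the single container in the
+    # group; the container is assumed to share the group's name.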
+    def get_logs(self, resource_group, name, tail=1000):
+        logs = self.connection.container_logs.list(resource_group, name, name, tail=tail)
+        return logs.content.splitlines(True)
+
+    def delete(self, resource_group, name):
+        self.connection.container_groups.delete(resource_group, name)
+
+
+class AzureContainerRegistryHook(BaseHook):
+
+    def __init__(self, conn_id='azure_registry'):
+        self.conn_id = conn_id
+        self.connection = self.get_conn()
+
+    def get_conn(self):
+        conn = self.get_connection(self.conn_id)
+        return ImageRegistryCredential(conn.host, conn.login, conn.password)
+
+
+class AzureContainerVolumeHook(BaseHook):
+
+    def __init__(self, wasb_conn_id='wasb_default'):
+        self.conn_id = wasb_conn_id
+
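+    # Prefer the AccountKey embedded in a `connection_string` extra and
+    # fall back to the connection's password as the storage account key.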
+    def get_storagekey(self):
+        conn = self.get_connection(self.conn_id)
+        service_options = conn.extra_dejson
+
+        if 'connection_string' in service_options:
+            for keyvalue in service_options['connection_string'].split(";"):
+                key, value = keyvalue.split("=", 1)
+                if key == "AccountKey":
+                    return value
+        return conn.password
+
+    def get_file_volume(self, mount_name, share_name,
+                        storage_account_name, read_only=False):
+        return Volume(mount_name,
+                      AzureFileVolume(share_name, storage_account_name,
+                                      read_only, self.get_storagekey()))
diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py
new file mode 100644
index 0000000000..c10ba59315
--- /dev/null
+++ b/airflow/contrib/operators/azure_container_instances_operator.py
@@ -0,0 +1,231 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from time import sleep
+
+from airflow.contrib.hooks.azure_container_hook import (AzureContainerInstanceHook,
+                                                        AzureContainerRegistryHook,
+                                                        AzureContainerVolumeHook)
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
+from airflow.models import BaseOperator
+
+from azure.mgmt.containerinstance.models import (EnvironmentVariable,
+                                                 VolumeMount,
+                                                 ResourceRequests,
+                                                 ResourceRequirements,
+                                                 Container,
+                                                 ContainerGroup)
+from msrestazure.azure_exceptions import CloudError
+
+
+class AzureContainerInstancesOperator(BaseOperator):
+    """
+    Start a container on Azure Container Instances
+
+    :param ci_conn_id: connection id of a service principal which will be used
+        to start the container instance
+    :type ci_conn_id: str
+    :param registry_conn_id: connection id of a user which can login to a
+        private docker registry. If None, we assume a public registry
+    :type registry_conn_id: str
+    :param resource_group: name of the resource group wherein this container
+        instance should be started
+    :type resource_group: str
+    :param name: name of this container instance. Please note this name has
+        to be unique in order to run containers in parallel.
+    :type name: str
+    :param image: the docker image to be used
+    :type image: str
+    :param region: the region wherein this container instance should be started
+    :type region: str
+    :param environment_variables: key,value pairs containing environment variables
+        which will be passed to the running container
+    :type environment_variables: dict
+    :param volumes: list of volumes to be mounted to the container.
+        Currently only Azure Fileshares are supported.
+    :type volumes: list of (conn_id, account_name, share_name, mount_path, read_only) tuples
+    :param memory_in_gb: the amount of memory to allocate to this container
+    :type memory_in_gb: float
+    :param cpu: the number of cpus to allocate to this container
+    :type cpu: float
+
+    :Example:
+
+    >>> a = AzureContainerInstancesOperator(
+                'azure_service_principal',
+                'azure_registry_user',
+                'my-resource-group',
+                'my-container-name-{{ ds }}',
+                'myprivateregistry.azurecr.io/my_container:latest',
+                'westeurope',
+                {'EXECUTION_DATE': '{{ ds }}'},
+                [('azure_wasb_conn_id',
+                  'my_storage_container',
+                  'my_fileshare',
+                  '/input-data',
+                  True),],
+                memory_in_gb=14.0,
+                cpu=4.0,
+                task_id='start_container'
+            )
+    """
+
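+    # `name` and `environment_variables` are templated, so Jinja values such
+    # as {{ ds }} can be used to keep container names unique per run.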
+    template_fields = ('name', 'environment_variables')
+    template_ext = tuple()
+
+    def __init__(self, ci_conn_id, registry_conn_id, resource_group, name, image,
+                 region, environment_variables=None, volumes=None,
+                 memory_in_gb=2.0, cpu=1.0, *args, **kwargs):
+        self.ci_conn_id = ci_conn_id
+        self.resource_group = resource_group
+        self.name = name
+        self.image = image
+        self.region = region
+        self.registry_conn_id = registry_conn_id
+        # avoid the mutable-default-argument pitfall
+        self.environment_variables = environment_variables or {}
+        self.volumes = volumes or []
+        self.memory_in_gb = memory_in_gb
+        self.cpu = cpu
+
+        super(AzureContainerInstancesOperator, self).__init__(*args, **kwargs)
+
+    def execute(self, context):
+        ci_hook = AzureContainerInstanceHook(self.ci_conn_id)
+
+        if self.registry_conn_id:
+            registry_hook = AzureContainerRegistryHook(self.registry_conn_id)
+            image_registry_credentials = [registry_hook.connection, ]
+        else:
+            image_registry_credentials = None
+
+        environment_variables = []
+        for key, value in self.environment_variables.items():
+            environment_variables.append(EnvironmentVariable(key, value))
+
+        volumes = []
+        volume_mounts = []
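+        # Each requested volume becomes an Azure File share Volume plus a
+        # VolumeMount; mounts are named sequentially (mount-0, mount-1, ...).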
+        for conn_id, account_name, share_name, mount_path, read_only in self.volumes:
+            hook = AzureContainerVolumeHook(conn_id)
+
+            mount_name = "mount-%d" % len(volumes)
+            volumes.append(hook.get_file_volume(mount_name,
+                                                share_name,
+                                                account_name,
+                                                read_only))
+            volume_mounts.append(VolumeMount(mount_name, mount_path, read_only))
+
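+        # Start the group, poll until the container terminates, and always
+        # delete the group afterwards so failed runs do not leak instances.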
+        try:
+            self.log.info("Starting container group with %.1f cpu %.1f mem",
+                          self.cpu, self.memory_in_gb)
+
+            resources = ResourceRequirements(ResourceRequests(
+                self.memory_in_gb,
+                self.cpu))
+
+            container = Container(
+                self.name, self.image, resources,
+                environment_variables=environment_variables,
+                volume_mounts=volume_mounts)
+
+            container_group = ContainerGroup(
+                location=self.region,
+                containers=[container, ],
+                image_registry_credentials=image_registry_credentials,
+                volumes=volumes,
+                restart_policy='Never',
+                os_type='Linux')
+
+            ci_hook.create_or_update(self.resource_group, self.name, container_group)
+
+            self.log.info("Container group started")
+
+            exit_code = self._monitor_logging(ci_hook, self.resource_group, self.name)
+
+            self.log.info("Container had exit code: %s", exit_code)
+            if exit_code != 0:
+                raise AirflowException("Container had a non-zero exit code, %s"
+                                       % exit_code)
+
+        except CloudError:
+            self.log.exception("Could not start container group")
+            raise AirflowException("Could not start container group")
+
+        finally:
+            self.log.info("Deleting container group")
+            try:
+                ci_hook.delete(self.resource_group, self.name)
+            except Exception:
+                self.log.exception("Could not delete container group")
+
+    def _monitor_logging(self, ci_hook, resource_group, name):
+        last_state = None
+        last_message_logged = None
+        last_line_logged = None
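+        # Poll about once per second, remembering the last event and log line
+        # seen so each iteration only logs what is new.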
+        for _ in range(43200):  # roughly 12 hours
+            try:
+                state, exit_code = ci_hook.get_state_exitcode(resource_group, name)
+                if state != last_state:
+                    self.log.info("Container group state changed to %s", state)
+                    last_state = state
+
+                if state == "Terminated":
+                    return exit_code
+
+                messages = ci_hook.get_messages(resource_group, name)
+                last_message_logged = self._log_last(messages, last_message_logged)
+
+                if state == "Running":
+                    try:
+                        logs = ci_hook.get_logs(resource_group, name)
+                        last_line_logged = self._log_last(logs, last_line_logged)
+                    except CloudError:
+                        self.log.exception("Exception while getting logs from "
+                                           "container instance, retrying...")
+
+            except CloudError as err:
+                if 'ResourceNotFound' in str(err):
+                    self.log.warning("ResourceNotFound, container is probably removed "
+                                     "by another process "
+                                     "(make sure that the name is unique).")
+                    return 1
+                else:
+                    self.log.exception("Exception while getting container groups")
+            except Exception:
+                self.log.exception("Exception while getting container groups")
+
+            sleep(1)
+
+        # no return -> hence still running
+        raise AirflowTaskTimeout("Did not complete on time")
+
+    def _log_last(self, logs, last_line_logged):
+        if not logs:
+            # nothing new to log; keep the previous marker so the next poll
+            # does not re-log old lines
+            return last_line_logged
+
+        # determine the last line which was logged before
+        last_line_index = 0
+        for i in range(len(logs) - 1, -1, -1):
+            if logs[i] == last_line_logged:
+                # this line is the same, hence print from i+1
+                last_line_index = i + 1
+                break
+
+        # log all new ones
+        for line in logs[last_line_index:]:
+            self.log.info(line.rstrip())
+
+        return logs[-1]
diff --git a/docs/integration.rst b/docs/integration.rst
index 3d436858b8..41ce1ee701 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -138,6 +138,43 @@ AzureDataLakeHook
 
 .. autoclass:: airflow.contrib.hooks.azure_data_lake_hook.AzureDataLakeHook
 
+Azure Container Instances
+'''''''''''''''''''''''''
+
+Azure Container Instances provides a way to run a Docker container without having to
+worry about managing infrastructure. The AzureContainerInstanceHook requires a service
+principal. The credentials for this principal can be defined either in the extra field
+`key_path`, in an environment variable named `AZURE_AUTH_LOCATION`, or by providing a
+login/password and a tenantId in the extras.
+
+The AzureContainerRegistryHook requires a host/login/password to be defined in the connection.
+
+- :ref:`AzureContainerInstancesOperator` : Start/monitor a new ACI.
+- :ref:`AzureContainerInstanceHook` : Wrapper around a single ACI.
+- :ref:`AzureContainerRegistryHook` : Wrapper around an ACR.
+- :ref:`AzureContainerVolumeHook` : Wrapper around container volumes.
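+
+For example, a task that starts a container from a private registry could be
+defined as follows (the connection ids and names here are illustrative):
+
+.. code:: python
+
+    from airflow.contrib.operators.azure_container_instances_operator import (
+        AzureContainerInstancesOperator)
+
+    start_container = AzureContainerInstancesOperator(
+        ci_conn_id='azure_service_principal',
+        registry_conn_id='azure_registry_user',
+        resource_group='my-resource-group',
+        name='my-container-name-{{ ds }}',
+        image='myprivateregistry.azurecr.io/my_container:latest',
+        region='westeurope',
+        environment_variables={'EXECUTION_DATE': '{{ ds }}'},
+        memory_in_gb=4.0,
+        cpu=1.0,
+        task_id='start_container')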
+
+AzureContainerInstancesOperator
+"""""""""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.operators.azure_container_instances_operator.AzureContainerInstancesOperator
+
+AzureContainerInstanceHook
+""""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.hooks.azure_container_hook.AzureContainerInstanceHook
+
+AzureContainerRegistryHook
+""""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.hooks.azure_container_hook.AzureContainerRegistryHook
+
+AzureContainerVolumeHook
+""""""""""""""""""""""""
+
+.. autoclass:: airflow.contrib.hooks.azure_container_hook.AzureContainerVolumeHook
+
+
 .. _AWS:
 
 AWS: Amazon Web Services
diff --git a/setup.py b/setup.py
index 2a11f74cc6..07b961a510 100644
--- a/setup.py
+++ b/setup.py
@@ -114,6 +114,7 @@ def write_version(filename=os.path.join(*['airflow',
     'azure-mgmt-datalake-store==0.4.0',
     'azure-datalake-store==0.0.19'
 ]
+azure_container_instances = ['azure-mgmt-containerinstance']
 cassandra = ['cassandra-driver>=3.13.0']
 celery = [
     'celery>=4.1.1, <4.2.0',
@@ -220,7 +221,8 @@ def write_version(filename=os.path.join(*['airflow',
 devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle +
              docker + ssh + kubernetes + celery + azure_blob_storage + redis + gcp_api +
              datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
-             druid + pinot + segment + snowflake + elasticsearch + azure_data_lake, atlas)
+             druid + pinot + segment + snowflake + elasticsearch + azure_data_lake +
+             atlas + azure_container_instances)
 
 # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
 if PY3:
@@ -293,6 +295,7 @@ def do_setup():
             'async': async,
             'azure_blob_storage': azure_blob_storage,
             'azure_data_lake': azure_data_lake,
+            'azure_container_instances': azure_container_instances,
             'cassandra': cassandra,
             'celery': celery,
             'cgroups': cgroups,
diff --git a/tests/contrib/operators/test_azure_container_instances_operator.py b/tests/contrib/operators/test_azure_container_instances_operator.py
new file mode 100644
index 0000000000..d80398bc8e
--- /dev/null
+++ b/tests/contrib/operators/test_azure_container_instances_operator.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.azure_container_instances_operator import (
+    AzureContainerInstancesOperator)
+
+import unittest
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+class TestACIOperator(unittest.TestCase):
+
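+    # AzureContainerInstanceHook is patched in each test, so execute() runs
+    # the full operator flow without touching the Azure API.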
+    @mock.patch("airflow.contrib.operators."
+                "azure_container_instances_operator.AzureContainerInstanceHook")
+    def test_execute(self, aci_mock):
+        aci_mock.return_value.get_state_exitcode.return_value = "Terminated", 0
+
+        aci = AzureContainerInstancesOperator(None, None,
+                                              'resource-group', 'container-name',
+                                              'container-image', 'region',
+                                              task_id='task')
+        aci.execute(None)
+
+        self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+        (called_rg, called_cn, called_cg), _ = \
+            aci_mock.return_value.create_or_update.call_args
+
+        self.assertEqual(called_rg, 'resource-group')
+        self.assertEqual(called_cn, 'container-name')
+
+        self.assertEqual(called_cg.location, 'region')
+        self.assertEqual(called_cg.image_registry_credentials, None)
+        self.assertEqual(called_cg.restart_policy, 'Never')
+        self.assertEqual(called_cg.os_type, 'Linux')
+
+        called_cg_container = called_cg.containers[0]
+        self.assertEqual(called_cg_container.name, 'container-name')
+        self.assertEqual(called_cg_container.image, 'container-image')
+
+        self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+
+    @mock.patch("airflow.contrib.operators."
+                "azure_container_instances_operator.AzureContainerInstanceHook")
+    def test_execute_with_failures(self, aci_mock):
+        aci_mock.return_value.get_state_exitcode.return_value = "Terminated", 1
+
+        aci = AzureContainerInstancesOperator(None, None,
+                                              'resource-group', 'container-name',
+                                              'container-image', 'region',
+                                              task_id='task')
+        with self.assertRaises(AirflowException):
+            aci.execute(None)
+
+        self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+
+    @mock.patch("airflow.contrib.operators."
+                "azure_container_instances_operator.AzureContainerInstanceHook")
+    def test_execute_with_messages_logs(self, aci_mock):
+        aci_mock.return_value.get_state_exitcode.side_effect = [("Running", 0),
+                                                                ("Terminated", 0)]
+        aci_mock.return_value.get_messages.return_value = ["test", "messages"]
+        aci_mock.return_value.get_logs.return_value = ["test", "logs"]
+
+        aci = AzureContainerInstancesOperator(None, None,
+                                              'resource-group', 'container-name',
+                                              'container-image', 'region',
+                                              task_id='task')
+        aci.execute(None)
+
+        self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
+        self.assertEqual(aci_mock.return_value.get_state_exitcode.call_count, 2)
+        self.assertEqual(aci_mock.return_value.get_messages.call_count, 1)
+        self.assertEqual(aci_mock.return_value.get_logs.call_count, 1)
+
+        self.assertEqual(aci_mock.return_value.delete.call_count, 1)
+
+
+if __name__ == '__main__':
+    unittest.main()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services