Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2019/01/11 19:35:10 UTC

[GitHub] ashb closed pull request #4112: [AIRFLOW-3212] Add AwsGlueCatalogPartitionSensor

ashb closed pull request #4112: [AIRFLOW-3212] Add AwsGlueCatalogPartitionSensor
URL: https://github.com/apache/airflow/pull/4112
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/airflow/contrib/hooks/aws_glue_catalog_hook.py b/airflow/contrib/hooks/aws_glue_catalog_hook.py
new file mode 100644
index 0000000000..687f0fddb9
--- /dev/null
+++ b/airflow/contrib/hooks/aws_glue_catalog_hook.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+
+class AwsGlueCatalogHook(AwsHook):
+    """
+    Interact with AWS Glue Catalog
+
+    :param aws_conn_id: ID of the Airflow connection where
+        credentials and extra configuration are stored
+    :type aws_conn_id: str
+    :param region_name: AWS region name (example: us-east-1)
+    :type region_name: str
+    """
+
+    def __init__(self,
+                 aws_conn_id='aws_default',
+                 region_name=None,
+                 *args,
+                 **kwargs):
+        self.region_name = region_name
+        super(AwsGlueCatalogHook, self).__init__(aws_conn_id=aws_conn_id, *args, **kwargs)
+
+    def get_conn(self):
+        """
+        Returns a Glue connection object.
+        """
+        self.conn = self.get_client_type('glue', self.region_name)
+        return self.conn
+
+    def get_partitions(self,
+                       database_name,
+                       table_name,
+                       expression='',
+                       page_size=None,
+                       max_items=None):
+        """
+        Retrieves the partition values for a table.
+
+        :param database_name: The name of the catalog database where the partitions reside.
+        :type database_name: str
+        :param table_name: The name of the partitions' table.
+        :type table_name: str
+        :param expression: An expression filtering the partitions to be returned.
+            Please see official AWS documentation for further information.
+            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
+        :type expression: str
+        :param page_size: pagination size
+        :type page_size: int
+        :param max_items: maximum items to return
+        :type max_items: int
+        :return: set of partition values where each value is a tuple since
+            a partition may be composed of multiple columns. For example:
+            {('2018-01-01','1'), ('2018-01-01','2')}
+        """
+        config = {
+            'PageSize': page_size,
+            'MaxItems': max_items,
+        }
+
+        paginator = self.get_conn().get_paginator('get_partitions')
+        response = paginator.paginate(
+            DatabaseName=database_name,
+            TableName=table_name,
+            Expression=expression,
+            PaginationConfig=config
+        )
+
+        partitions = set()
+        for page in response:
+            for p in page['Partitions']:
+                partitions.add(tuple(p['Values']))
+
+        return partitions
+
+    def check_for_partition(self, database_name, table_name, expression):
+        """
+        Checks whether a partition exists.
+
+        :param database_name: Name of the Glue catalog database (schema) @table belongs to
+        :type database_name: str
+        :param table_name: Name of the Glue table @partition belongs to
+        :type table_name: str
+        :param expression: Expression that matches the partitions to check for
+            (e.g. ``a = 'b' AND c = 'd'``)
+        :type expression: str
+        :rtype: bool
+
+        >>> hook = AwsGlueCatalogHook()
+        >>> t = 'static_babynames_partitioned'
+        >>> hook.check_for_partition('airflow', t, "ds='2015-01-01'")
+        True
+        """
+        partitions = self.get_partitions(database_name,
+                                          table_name,
+                                          expression,
+                                          max_items=1)
+
+        return bool(partitions)
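
For reference, a minimal usage sketch of the new AwsGlueCatalogHook (illustrative only, not part of the diff): the connection id, region, database and table names below are placeholders, and the snippet assumes the referenced connection is allowed to call Glue's GetPartitions API.

    from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook

    # Placeholder connection/database/table names; adjust to your environment.
    hook = AwsGlueCatalogHook(aws_conn_id='aws_default', region_name='us-east-1')

    # True as soon as at least one partition matches the expression.
    if hook.check_for_partition('my_database', 'my_table', "ds='2018-01-01'"):
        print('partition is present')

    # Or fetch the matching partition value tuples directly.
    print(hook.get_partitions('my_database', 'my_table',
                              expression="ds='2018-01-01'",
                              max_items=10))
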
diff --git a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py b/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
new file mode 100644
index 0000000000..74e25a29dc
--- /dev/null
+++ b/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
+    """
+    Waits for a partition to show up in AWS Glue Catalog.
+
+    :param table_name: The name of the table to wait for, supports the dot
+        notation (my_database.my_table)
+    :type table_name: str
+    :param expression: The partition clause to wait for. This is passed as-is
+        to the AWS Glue Catalog API's get_partitions function,
+        and supports SQL-like notation as in ``ds='2015-01-01'
+        AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
+        See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
+        #aws-glue-api-catalog-partitions-GetPartitions
+    :type expression: str
+    :param aws_conn_id: ID of the Airflow connection where
+        credentials and extra configuration are stored
+    :type aws_conn_id: str
+    :param region_name: Optional AWS region name (example: us-east-1). Uses the region from
+        the connection if not specified.
+    :type region_name: str
+    :param database_name: The name of the catalog database where the partitions reside.
+    :type database_name: str
+    :param poke_interval: Time in seconds that the job should wait in
+        between each try
+    :type poke_interval: int
+    """
+    template_fields = ('database_name', 'table_name', 'expression',)
+    ui_color = '#C5CAE9'
+
+    @apply_defaults
+    def __init__(self,
+                 table_name, expression="ds='{{ ds }}'",
+                 aws_conn_id='aws_default',
+                 region_name=None,
+                 database_name='default',
+                 poke_interval=60 * 3,
+                 *args,
+                 **kwargs):
+        super(AwsGlueCatalogPartitionSensor, self).__init__(
+            poke_interval=poke_interval, *args, **kwargs)
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.table_name = table_name
+        self.expression = expression
+        self.database_name = database_name
+
+    def poke(self, context):
+        """
+        Checks for existence of the partition in the AWS Glue Catalog table
+        """
+        if '.' in self.table_name:
+            self.database_name, self.table_name = self.table_name.split('.')
+        self.log.info(
+            'Poking for table %s.%s, expression %s',
+            self.database_name, self.table_name, self.expression)
+
+        return self.get_hook().check_for_partition(
+            self.database_name, self.table_name, self.expression)
+
+    def get_hook(self):
+        """
+        Gets the AwsGlueCatalogHook
+        """
+        if not hasattr(self, 'hook'):
+            from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
+            self.hook = AwsGlueCatalogHook(
+                aws_conn_id=self.aws_conn_id,
+                region_name=self.region_name)
+
+        return self.hook
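
Likewise, a minimal DAG sketch wiring up the new sensor. The dag id, schedule, database and table names here are illustrative placeholders rather than anything taken from this PR.

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.sensors.aws_glue_catalog_partition_sensor import (
        AwsGlueCatalogPartitionSensor
    )

    # Illustrative DAG; database/table names are placeholders.
    dag = DAG(dag_id='example_glue_partition_wait',
              start_date=datetime(2019, 1, 1),
              schedule_interval='@daily')

    wait_for_partition = AwsGlueCatalogPartitionSensor(
        task_id='wait_for_partition',
        database_name='my_database',
        table_name='my_table',
        # expression is in template_fields, so the default "ds='{{ ds }}'"
        # resolves to the execution date at runtime.
        expression="ds='{{ ds }}'",
        poke_interval=60 * 3,
        dag=dag)
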
diff --git a/docs/code.rst b/docs/code.rst
index a670a2d8fd..4da2b32f7e 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -235,6 +235,7 @@ Sensors
 ^^^^^^^
 
 .. autoclass:: airflow.contrib.sensors.aws_athena_sensor.AthenaSensor
+.. autoclass:: airflow.contrib.sensors.aws_glue_catalog_partition_sensor.AwsGlueCatalogPartitionSensor
 .. autoclass:: airflow.contrib.sensors.aws_redshift_cluster_sensor.AwsRedshiftClusterSensor
 .. autoclass:: airflow.contrib.sensors.azure_cosmos_sensor.AzureCosmosDocumentSensor
 .. autoclass:: airflow.contrib.sensors.bash_sensor.BashSensor
@@ -419,6 +420,7 @@ Community contributed hooks
 .. autoclass:: airflow.contrib.hooks.aws_athena_hook.AWSAthenaHook
 .. autoclass:: airflow.contrib.hooks.aws_dynamodb_hook.AwsDynamoDBHook
 .. autoclass:: airflow.contrib.hooks.aws_firehose_hook.AwsFirehoseHook
+.. autoclass:: airflow.contrib.hooks.aws_glue_catalog_hook.AwsGlueCatalogHook
 .. autoclass:: airflow.contrib.hooks.aws_hook.AwsHook
 .. autoclass:: airflow.contrib.hooks.aws_lambda_hook.AwsLambdaHook
 .. autoclass:: airflow.contrib.hooks.aws_sns_hook.AwsSnsHook
diff --git a/setup.py b/setup.py
index 410502c302..42421e0329 100644
--- a/setup.py
+++ b/setup.py
@@ -251,7 +251,7 @@ def write_version(filename=os.path.join(*['airflow',
     'lxml>=4.0.0',
     'mock',
     'mongomock',
-    'moto==1.1.19',
+    'moto==1.3.5',
     'nose',
     'nose-ignore-docstring==0.2',
     'nose-timer',
diff --git a/tests/contrib/hooks/test_aws_glue_catalog_hook.py b/tests/contrib/hooks/test_aws_glue_catalog_hook.py
new file mode 100644
index 0000000000..cddc5cfb6f
--- /dev/null
+++ b/tests/contrib/hooks/test_aws_glue_catalog_hook.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
+
+try:
+    from moto import mock_glue
+except ImportError:
+    mock_glue = None
+
+try:
+    from unittest import mock
+except ImportError:
+    import mock
+
+
+@unittest.skipIf(mock_glue is None,
+                 "Skipping test because moto.mock_glue is not available")
+class TestAwsGlueCatalogHook(unittest.TestCase):
+
+    @mock_glue
+    def test_get_conn_returns_a_boto3_connection(self):
+        hook = AwsGlueCatalogHook(region_name="us-east-1")
+        self.assertIsNotNone(hook.get_conn())
+
+    @mock_glue
+    def test_conn_id(self):
+        hook = AwsGlueCatalogHook(aws_conn_id='my_aws_conn_id', region_name="us-east-1")
+        self.assertEqual(hook.aws_conn_id, 'my_aws_conn_id')
+
+    @mock_glue
+    def test_region(self):
+        hook = AwsGlueCatalogHook(region_name="us-west-2")
+        self.assertEqual(hook.region_name, 'us-west-2')
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'get_conn')
+    def test_get_partitions_empty(self, mock_get_conn):
+        response = set()
+        mock_get_conn.return_value.get_paginator.return_value.paginate.return_value = response
+        hook = AwsGlueCatalogHook(region_name="us-east-1")
+
+        self.assertEqual(hook.get_partitions('db', 'tbl'), set())
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'get_conn')
+    def test_get_partitions(self, mock_get_conn):
+        response = [{
+            'Partitions': [{
+                'Values': ['2015-01-01']
+            }]
+        }]
+        mock_paginator = mock.Mock()
+        mock_paginator.paginate.return_value = response
+        mock_conn = mock.Mock()
+        mock_conn.get_paginator.return_value = mock_paginator
+        mock_get_conn.return_value = mock_conn
+        hook = AwsGlueCatalogHook(region_name="us-east-1")
+        result = hook.get_partitions('db',
+                                     'tbl',
+                                     expression='foo=bar',
+                                     page_size=2,
+                                     max_items=3)
+
+        self.assertEqual(result, set([('2015-01-01',)]))
+        mock_conn.get_paginator.assert_called_once_with('get_partitions')
+        mock_paginator.paginate.assert_called_once_with(DatabaseName='db',
+                                                        TableName='tbl',
+                                                        Expression='foo=bar',
+                                                        PaginationConfig={
+                                                            'PageSize': 2,
+                                                            'MaxItems': 3})
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'get_partitions')
+    def test_check_for_partition(self, mock_get_partitions):
+        mock_get_partitions.return_value = set([('2018-01-01',)])
+        hook = AwsGlueCatalogHook(region_name="us-east-1")
+
+        self.assertTrue(hook.check_for_partition('db', 'tbl', 'expr'))
+        mock_get_partitions.assert_called_once_with('db', 'tbl', 'expr', max_items=1)
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'get_partitions')
+    def test_check_for_partition_false(self, mock_get_partitions):
+        mock_get_partitions.return_value = set()
+        hook = AwsGlueCatalogHook(region_name="us-east-1")
+
+        self.assertFalse(hook.check_for_partition('db', 'tbl', 'expr'))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/contrib/sensors/test_aws_glue_catalog_partition_sensor.py b/tests/contrib/sensors/test_aws_glue_catalog_partition_sensor.py
new file mode 100644
index 0000000000..fc12318e1a
--- /dev/null
+++ b/tests/contrib/sensors/test_aws_glue_catalog_partition_sensor.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import configuration
+from airflow.contrib.hooks.aws_glue_catalog_hook import AwsGlueCatalogHook
+from airflow.contrib.sensors.aws_glue_catalog_partition_sensor import AwsGlueCatalogPartitionSensor
+
+try:
+    from moto import mock_glue
+except ImportError:
+    mock_glue = None
+
+try:
+    from unittest import mock
+except ImportError:
+    import mock
+
+
+@unittest.skipIf(mock_glue is None,
+                 "Skipping test because moto.mock_glue is not available")
+class TestAwsGlueCatalogPartitionSensor(unittest.TestCase):
+
+    task_id = 'test_glue_catalog_partition_sensor'
+
+    def setUp(self):
+        configuration.load_test_config()
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
+    def test_poke(self, mock_check_for_partition):
+        mock_check_for_partition.return_value = True
+        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
+                                           table_name='tbl')
+        self.assertTrue(op.poke(None))
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
+    def test_poke_false(self, mock_check_for_partition):
+        mock_check_for_partition.return_value = False
+        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
+                                           table_name='tbl')
+        self.assertFalse(op.poke(None))
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
+    def test_poke_default_args(self, mock_check_for_partition):
+        table_name = 'test_glue_catalog_partition_sensor_tbl'
+        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
+                                           table_name=table_name)
+        op.poke(None)
+
+        self.assertEqual(op.hook.region_name, None)
+        self.assertEqual(op.hook.aws_conn_id, 'aws_default')
+        mock_check_for_partition.assert_called_once_with('default',
+                                                         table_name,
+                                                         "ds='{{ ds }}'")
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
+    def test_poke_nondefault_args(self, mock_check_for_partition):
+        table_name = 'my_table'
+        expression = 'col=val'
+        aws_conn_id = 'my_aws_conn_id'
+        region_name = 'us-west-2'
+        database_name = 'my_db'
+        poke_interval = 2
+        timeout = 3
+        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
+                                           table_name=table_name,
+                                           expression=expression,
+                                           aws_conn_id=aws_conn_id,
+                                           region_name=region_name,
+                                           database_name=database_name,
+                                           poke_interval=poke_interval,
+                                           timeout=timeout)
+        op.poke(None)
+
+        self.assertEqual(op.hook.region_name, region_name)
+        self.assertEqual(op.hook.aws_conn_id, aws_conn_id)
+        self.assertEqual(op.poke_interval, poke_interval)
+        self.assertEqual(op.timeout, timeout)
+        mock_check_for_partition.assert_called_once_with(database_name,
+                                                         table_name,
+                                                         expression)
+
+    @mock_glue
+    @mock.patch.object(AwsGlueCatalogHook, 'check_for_partition')
+    def test_dot_notation(self, mock_check_for_partition):
+        db_table = 'my_db.my_tbl'
+        op = AwsGlueCatalogPartitionSensor(task_id=self.task_id,
+                                           table_name=db_table)
+        op.poke(None)
+
+        mock_check_for_partition.assert_called_once_with('my_db',
+                                                         'my_tbl',
+                                                         "ds='{{ ds }}'")
+
+
+if __name__ == '__main__':
+    unittest.main()
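
For anyone who wants to exercise the new tests locally, here is one possible way to run just the two new modules from the repository root, assuming a working Airflow dev environment with moto>=1.3.5 installed (per the setup.py bump above); how you invoke tests may differ in your setup.

    import unittest

    from tests.contrib.hooks import test_aws_glue_catalog_hook
    from tests.contrib.sensors import test_aws_glue_catalog_partition_sensor

    # Build a suite from the two new test modules and run it.
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite([
        loader.loadTestsFromModule(test_aws_glue_catalog_hook),
        loader.loadTestsFromModule(test_aws_glue_catalog_partition_sensor),
    ])
    unittest.TextTestRunner(verbosity=2).run(suite)
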


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services