Posted to commits@airflow.apache.org by fo...@apache.org on 2018/05/08 10:42:00 UTC

incubator-airflow git commit: [AIRFLOW-2427] Add tests to named hive sensor

Repository: incubator-airflow
Updated Branches:
  refs/heads/master baf15e11a -> b18b437c2


[AIRFLOW-2427] Add tests to named hive sensor

Closes #3323 from gglanzani/AIRFLOW-2427


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/b18b437c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/b18b437c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/b18b437c

Branch: refs/heads/master
Commit: b18b437c216b0c4b3ffb41e4934f3c2dd966c14b
Parents: baf15e1
Author: Giovanni Lanzani <gi...@lanzani.nl>
Authored: Tue May 8 12:41:51 2018 +0200
Committer: Fokko Driesprong <fo...@godatadriven.com>
Committed: Tue May 8 12:41:51 2018 +0200

----------------------------------------------------------------------
 airflow/sensors/named_hive_partition_sensor.py  |  68 +++++-----
 .../sensors/test_named_hive_partition_sensor.py | 130 +++++++++++++++++++
 2 files changed, 169 insertions(+), 29 deletions(-)
----------------------------------------------------------------------
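
For context: the refactored parse_partition_name() accepts names of the
form schema.table/partition and falls back to the 'default' schema when
the schema part is omitted. A minimal sketch of the new behaviour (the
partition names below are illustrative only):

    from airflow.sensors.named_hive_partition_sensor import \
        NamedHivePartitionSensor

    # Fully qualified name: schema.table/partition
    NamedHivePartitionSensor.parse_partition_name(
        'airflow.static_babynames_partitioned/ds=2015-01-01')
    # -> ('airflow', 'static_babynames_partitioned', 'ds=2015-01-01')

    # Schema omitted: parsed against the 'default' schema
    NamedHivePartitionSensor.parse_partition_name('users/ds=2016-01-01')
    # -> ('default', 'users', 'ds=2016-01-01')

    # No '/' separating table from partition: raises ValueError
    NamedHivePartitionSensor.parse_partition_name('incorrect.name')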


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b18b437c/airflow/sensors/named_hive_partition_sensor.py
----------------------------------------------------------------------
diff --git a/airflow/sensors/named_hive_partition_sensor.py b/airflow/sensors/named_hive_partition_sensor.py
index a42a360..4a076a3 100644
--- a/airflow/sensors/named_hive_partition_sensor.py
+++ b/airflow/sensors/named_hive_partition_sensor.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -48,6 +48,7 @@ class NamedHivePartitionSensor(BaseSensorOperator):
                  partition_names,
                  metastore_conn_id='metastore_default',
                  poke_interval=60 * 3,
+                 hook=None,
                  *args,
                  **kwargs):
         super(NamedHivePartitionSensor, self).__init__(
@@ -58,37 +59,46 @@ class NamedHivePartitionSensor(BaseSensorOperator):
 
         self.metastore_conn_id = metastore_conn_id
         self.partition_names = partition_names
-        self.next_poke_idx = 0
-
-    @classmethod
-    def parse_partition_name(self, partition):
-        try:
-            schema, table_partition = partition.split('.', 1)
-            table, partition = table_partition.split('/', 1)
-            return schema, table, partition
-        except ValueError as e:
-            raise ValueError('Could not parse ' + partition)
-
-    def poke(self, context):
-        if not hasattr(self, 'hook'):
+        self.hook = hook
+        if self.hook and metastore_conn_id != 'metastore_default':
+            self.log.warning('A hook was passed but a non-default '
+                             'metastore_conn_id='
+                             '{} was used'.format(metastore_conn_id))
+
+    @staticmethod
+    def parse_partition_name(partition):
+        first_split = partition.split('.', 1)
+        if len(first_split) == 1:
+            schema = 'default'
+            table_partition = max(first_split)  # the only element: no schema given
+        else:
+            schema, table_partition = first_split
+        second_split = table_partition.split('/', 1)
+        if len(second_split) == 1:
+            raise ValueError('Could not parse ' + partition +
+                             ' into table, partition')
+        else:
+            table, partition = second_split
+        return schema, table, partition
+
+    def poke_partition(self, partition):
+        if not self.hook:
             from airflow.hooks.hive_hooks import HiveMetastoreHook
             self.hook = HiveMetastoreHook(
                 metastore_conn_id=self.metastore_conn_id)
 
-        def poke_partition(partition):
-
-            schema, table, partition = self.parse_partition_name(partition)
+        schema, table, partition = self.parse_partition_name(partition)
 
-            self.log.info(
-                'Poking for {schema}.{table}/{partition}'.format(**locals())
-            )
-            return self.hook.check_for_named_partition(
-                schema, table, partition)
+        self.log.info(
+            'Poking for {schema}.{table}/{partition}'.format(**locals())
+        )
+        return self.hook.check_for_named_partition(
+            schema, table, partition)
 
-        while self.next_poke_idx < len(self.partition_names):
-            if poke_partition(self.partition_names[self.next_poke_idx]):
-                self.next_poke_idx += 1
-            else:
-                return False
+    def poke(self, context):
 
-        return True
+        self.partition_names = [
+            partition_name for partition_name in self.partition_names
+            if not self.poke_partition(partition_name)
+        ]
+        return not self.partition_names
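
The new hook argument makes the metastore hook injectable, which the
tests below rely on; when it is omitted, the sensor still lazily builds
a HiveMetastoreHook from metastore_conn_id on the first poke. A hedged
usage sketch (the task id and partition name are made up):

    from airflow.hooks.hive_hooks import HiveMetastoreHook
    from airflow.sensors.named_hive_partition_sensor import \
        NamedHivePartitionSensor

    sensor = NamedHivePartitionSensor(
        task_id='wait_for_partition',  # hypothetical task id
        partition_names=[
            'airflow.static_babynames_partitioned/ds=2015-01-01'],
        hook=HiveMetastoreHook(metastore_conn_id='metastore_default'),
        poke_interval=60 * 3,
    )

Note also the reworked poke(): partitions that are found are dropped
from partition_names, so each subsequent poke only re-checks the ones
still missing, and poke() returns True once the list is empty (the old
implementation tracked the same progress with next_poke_idx).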

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b18b437c/tests/sensors/test_named_hive_partition_sensor.py
----------------------------------------------------------------------
diff --git a/tests/sensors/test_named_hive_partition_sensor.py b/tests/sensors/test_named_hive_partition_sensor.py
new file mode 100644
index 0000000..4fef3e0
--- /dev/null
+++ b/tests/sensors/test_named_hive_partition_sensor.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import random
+import unittest
+from datetime import timedelta
+
+from airflow import configuration, DAG, operators
+from airflow.sensors.named_hive_partition_sensor import NamedHivePartitionSensor
+from airflow.utils.timezone import datetime
+from airflow.hooks.hive_hooks import HiveMetastoreHook
+
+configuration.load_test_config()
+
+DEFAULT_DATE = datetime(2015, 1, 1)
+DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
+DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
+
+
+class NamedHivePartitionSensorTests(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+        self.dag = DAG('test_dag_id', default_args=args)
+        self.next_day = (DEFAULT_DATE +
+                         timedelta(days=1)).isoformat()[:10]
+        self.database = 'airflow'
+        self.partition_by = 'ds'
+        self.table = 'static_babynames_partitioned'
+        self.hql = """
+                CREATE DATABASE IF NOT EXISTS {{ params.database }};
+                USE {{ params.database }};
+                DROP TABLE IF EXISTS {{ params.table }};
+                CREATE TABLE IF NOT EXISTS {{ params.table }} (
+                    state string,
+                    year string,
+                    name string,
+                    gender string,
+                    num int)
+                PARTITIONED BY ({{ params.partition_by }} string);
+                ALTER TABLE {{ params.table }}
+                ADD PARTITION({{ params.partition_by }}='{{ ds }}');
+                """
+        self.hook = HiveMetastoreHook()
+        t = operators.hive_operator.HiveOperator(
+            task_id='HiveHook_' + str(random.randint(1, 10000)),
+            params={
+                'database': self.database,
+                'table': self.table,
+                'partition_by': self.partition_by
+            },
+            hive_cli_conn_id='beeline_default',
+            hql=self.hql, dag=self.dag)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+              ignore_ti_state=True)
+
+    def tearDown(self):
+        hook = HiveMetastoreHook()
+        with hook.get_conn() as metastore:
+            metastore.drop_table(self.database, self.table, deleteData=True)
+
+    def test_parse_partition_name_correct(self):
+        schema = 'default'
+        table = 'users'
+        partition = 'ds=2016-01-01/state=IT'
+        name = '{schema}.{table}/{partition}'.format(schema=schema,
+                                                     table=table,
+                                                     partition=partition)
+        parsed_schema, parsed_table, parsed_partition = (
+            NamedHivePartitionSensor.parse_partition_name(name)
+        )
+        self.assertEqual(schema, parsed_schema)
+        self.assertEqual(table, parsed_table)
+        self.assertEqual(partition, parsed_partition)
+
+    def test_parse_partition_name_incorrect(self):
+        name = 'incorrect.name'
+        with self.assertRaises(ValueError):
+            NamedHivePartitionSensor.parse_partition_name(name)
+
+    def test_parse_partition_name_default(self):
+        table = 'users'
+        partition = 'ds=2016-01-01/state=IT'
+        name = '{table}/{partition}'.format(table=table,
+                                            partition=partition)
+        parsed_schema, parsed_table, parsed_partition = (
+            NamedHivePartitionSensor.parse_partition_name(name)
+        )
+        self.assertEqual('default', parsed_schema)
+        self.assertEqual(table, parsed_table)
+        self.assertEqual(partition, parsed_partition)
+
+    def test_poke_existing(self):
+        partitions = ["{}.{}/{}={}".format(self.database,
+                                           self.table,
+                                           self.partition_by,
+                                           DEFAULT_DATE_DS)]
+        sensor = NamedHivePartitionSensor(partition_names=partitions,
+                                          task_id='test_poke_existing',
+                                          poke_interval=1,
+                                          hook=self.hook,
+                                          dag=self.dag)
+        self.assertTrue(sensor.poke(None))
+
+    def test_poke_non_existing(self):
+        partitions = ["{}.{}/{}={}".format(self.database,
+                                           self.table,
+                                           self.partition_by,
+                                           self.next_day)]
+        sensor = NamedHivePartitionSensor(partition_names=partitions,
+                                          task_id='test_poke_non_existing',
+                                          poke_interval=1,
+                                          hook=self.hook,
+                                          dag=self.dag)
+        self.assertFalse(sensor.poke(None))
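
The incremental poke() semantics exercised by the tests above can also
be seen with a stub hook (StubMetastoreHook is a hypothetical stand-in,
not part of Airflow):

    from airflow.sensors.named_hive_partition_sensor import \
        NamedHivePartitionSensor

    class StubMetastoreHook(object):
        """Pretend only the partitions listed in `present` exist."""
        def __init__(self, present):
            self.present = present

        def check_for_named_partition(self, schema, table, partition):
            return '{}.{}/{}'.format(schema, table, partition) in self.present

    sensor = NamedHivePartitionSensor(
        task_id='wait_for_two_partitions',  # hypothetical task id
        partition_names=['default.users/ds=2018-05-08',
                         'default.events/ds=2018-05-08'],
        hook=StubMetastoreHook({'default.users/ds=2018-05-08'}),
    )
    sensor.poke(None)       # False: default.events is still missing
    sensor.partition_names  # ['default.events/ds=2018-05-08']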