Posted to commits@airflow.apache.org by da...@apache.org on 2018/03/08 00:12:43 UTC

incubator-airflow git commit: [AIRFLOW-2150] Use lighter call in HiveMetastoreHook().max_partition()

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 0f9f4605f -> b8c2cea36


[AIRFLOW-2150] Use lighter call in HiveMetastoreHook().max_partition()

In HiveMetastoreHook().max_partition(), call
self.metastore.get_partition_names() instead of
self.metastore.get_partitions(), which is extremely
expensive for large tables.

Closes #3082 from
yrqls21/kevin_yang_fix_hive_max_partition
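
For context, a partition name returned by the metastore is just a
'{key}={value}' string, so the maximum partition can be derived without
deserializing full Partition objects. A minimal sketch of the new helper
on made-up values (the 'ds' key and the dates below are illustrative only):

    from airflow.hooks.hive_hooks import HiveMetastoreHook

    # Partition names as returned by the metastore's get_partition_names()
    # call for a table partitioned on a single 'ds' key (values made up).
    part_names = ['ds=2015-01-01', 'ds=2015-01-02', 'ds=2015-01-03']

    # The helper strips the 'ds=' prefix and returns the lexicographic max.
    print(HiveMetastoreHook._get_max_partition_from_part_names(part_names, 'ds'))
    # '2015-01-03'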


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/b8c2cea3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/b8c2cea3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/b8c2cea3

Branch: refs/heads/master
Commit: b8c2cea36299d6a3264d8bb1dc5a3995732b8855
Parents: 0f9f460
Author: Kevin Yang <ke...@airbnb.com>
Authored: Wed Mar 7 16:12:14 2018 -0800
Committer: Dan Davydov <da...@airbnb.com>
Committed: Wed Mar 7 16:12:18 2018 -0800

----------------------------------------------------------------------
 airflow/hooks/hive_hooks.py   | 64 ++++++++++++++++++++++++++++++--------
 airflow/macros/hive.py        |  2 +-
 tests/hooks/test_hive_hook.py | 39 +++++++++++++++++++++++
 3 files changed, 91 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b8c2cea3/airflow/hooks/hive_hooks.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index cd7319d..128be41 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -429,9 +429,11 @@ class HiveCliHook(BaseHook):
 
 
 class HiveMetastoreHook(BaseHook):
-
     """ Wrapper to interact with the Hive Metastore"""
 
+    # java short max val
+    MAX_PART_COUNT = 32767
+
     def __init__(self, metastore_conn_id='metastore_default'):
         self.metastore_conn = self.get_connection(metastore_conn_id)
         self.metastore = self.get_metastore_client()
@@ -601,16 +603,46 @@ class HiveMetastoreHook(BaseHook):
             if filter:
                 parts = self.metastore.get_partitions_by_filter(
                     db_name=schema, tbl_name=table_name,
-                    filter=filter, max_parts=32767)
+                    filter=filter, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
             else:
                 parts = self.metastore.get_partitions(
-                    db_name=schema, tbl_name=table_name, max_parts=32767)
+                    db_name=schema, tbl_name=table_name,
+                    max_parts=HiveMetastoreHook.MAX_PART_COUNT)
 
             self.metastore._oprot.trans.close()
             pnames = [p.name for p in table.partitionKeys]
             return [dict(zip(pnames, p.values)) for p in parts]
 
-    def max_partition(self, schema, table_name, field=None, filter=None):
+    @staticmethod
+    def _get_max_partition_from_part_names(part_names, key_name):
+        """
+        Helper method to get max partition from part names. Works only
+        when partition format follows '{key}={value}' and key_name is name of
+        the only partition key.
+        :param part_names: list of partition names
+        :type part_names: list
+        :param key_name: partition key name
+        :type key_name: str
+        :return: Max partition or None if part_names is empty.
+        """
+        if not part_names:
+            return None
+
+        prefix = key_name + '='
+        prefix_len = len(key_name) + 1
+        max_val = None
+        for part_name in part_names:
+            if part_name.startswith(prefix):
+                if max_val is None:
+                    max_val = part_name[prefix_len:]
+                else:
+                    max_val = max(max_val, part_name[prefix_len:])
+            else:
+                raise AirflowException(
+                    "Partition name mal-formatted: {}".format(part_name))
+        return max_val
+
+    def max_partition(self, schema, table_name, field=None):
         """
         Returns the maximum value for all partitions in a table. Works only
         for tables that have a single partition key. For subpartitioned
@@ -621,17 +653,23 @@ class HiveMetastoreHook(BaseHook):
         >>> hh.max_partition(schema='airflow', table_name=t)
         '2015-01-01'
         """
-        parts = self.get_partitions(schema, table_name, filter)
-        if not parts:
-            return None
-        elif len(parts[0]) == 1:
-            field = list(parts[0].keys())[0]
-        elif not field:
+        self.metastore._oprot.trans.open()
+        table = self.metastore.get_table(dbname=schema, tbl_name=table_name)
+        if len(table.partitionKeys) != 1:
             raise AirflowException(
-                "Please specify the field you want the max "
-                "value for")
+                "The table isn't partitioned by a single partition key")
+
+        key_name = table.partitionKeys[0].name
+        if field is not None and key_name != field:
+            raise AirflowException("Provided field is not the partition key")
+
+        part_names = \
+            self.metastore.get_partition_names(schema,
+                                               table_name,
+                                               max_parts=HiveMetastoreHook.MAX_PART_COUNT)
+        self.metastore._oprot.trans.close()
 
-        return max([p[field] for p in parts])
+        return HiveMetastoreHook._get_max_partition_from_part_names(part_names, key_name)
 
     def table_exists(self, table_name, db='default'):
         """

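Illustrative usage of the reworked max_partition(): the filter argument is
gone, and field, when given, must match the table's single partition key.
The connection id below is Airflow's default; the schema and table name are
placeholders:

    from airflow.hooks.hive_hooks import HiveMetastoreHook

    hh = HiveMetastoreHook(metastore_conn_id='metastore_default')

    # Only partition name strings travel over the Thrift connection now; the
    # full Partition objects are never fetched, which is what makes the call
    # lighter. The partition key is read from the table itself, so `field`
    # can be omitted for single-key tables.
    latest = hh.max_partition(schema='mydb', table_name='my_partitioned_table')
    print(latest)
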
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b8c2cea3/airflow/macros/hive.py
----------------------------------------------------------------------
diff --git a/airflow/macros/hive.py b/airflow/macros/hive.py
index c68c293..ef80fc6 100644
--- a/airflow/macros/hive.py
+++ b/airflow/macros/hive.py
@@ -44,7 +44,7 @@ def max_partition(
         schema, table = table.split('.')
     hh = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
     return hh.max_partition(
-        schema=schema, table_name=table, field=field, filter=filter)
+        schema=schema, table_name=table, field=field)
 
 
 def _closest_date(target_dt, date_list, before_target=None):
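
From a caller's perspective the macro behaves as before; only the hook call
behind it changes. A sketch of invoking it directly, assuming the
schema-qualified table name is the macro's first argument (as the
table.split('.') handling above suggests) and using placeholder names:

    from airflow.macros.hive import max_partition

    # 'mydb.my_partitioned_table' is split into schema and table inside the
    # macro, which then delegates to HiveMetastoreHook.max_partition().
    print(max_partition('mydb.my_partitioned_table'))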

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b8c2cea3/tests/hooks/test_hive_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py
new file mode 100644
index 0000000..c7da8e5
--- /dev/null
+++ b/tests/hooks/test_hive_hook.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.hive_hooks import HiveMetastoreHook
+
+
+class TestHiveMetastoreHook(unittest.TestCase):
+    def test_get_max_partition_from_empty_part_names(self):
+        max_partition = \
+            HiveMetastoreHook._get_max_partition_from_part_names([], 'some_key')
+        self.assertIsNone(max_partition)
+
+    def test_get_max_partition_from_mal_formatted_part_names(self):
+        with self.assertRaises(AirflowException):
+            HiveMetastoreHook._get_max_partition_from_part_names(
+                ['bad_partition_name'], 'some_key')
+
+    def test_get_max_partition_from_mal_valid_part_names(self):
+        max_partition = \
+            HiveMetastoreHook._get_max_partition_from_part_names(['some_key=value1',
+                                                                  'some_key=value2',
+                                                                  'some_key=value3'],
+                                                                 'some_key')
+        self.assertEqual(max_partition, 'value3')