Posted to commits@airflow.apache.org by xu...@apache.org on 2016/11/16 05:23:44 UTC

incubator-airflow git commit: [AIRFLOW-343] Fix schema plumbing in HiveServer2Hook

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 664e63a72 -> 448e06f69


[AIRFLOW-343] Fix schema plumbing in HiveServer2Hook

Allow HiveServer2Hook to be used with databases other than the connection's default.
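
For example, after this change a query can target a database other than
the connection's default (a minimal usage sketch; the connection id,
table, and schema names are illustrative):

    from airflow.hooks.hive_hooks import HiveServer2Hook

    hook = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')
    # `schema` is now plumbed through get_conn(), so the statement runs
    # against the named database rather than the connection default.
    records = hook.get_records("SELECT count(*) FROM my_table",
                               schema="my_schema")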

Testing Done:
- Added new unit test coverage in tests/operators/hive_operator.py to
  cover these changes (this is the established testing norm for HiveHook).

Closes #1743 from gr8routdoors/airflow_343


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/448e06f6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/448e06f6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/448e06f6

Branch: refs/heads/master
Commit: 448e06f697589de5bc19f3af76f736ae043a6e6b
Parents: 664e63a
Author: Devon Berry <de...@livingsocial.com>
Authored: Wed Nov 16 00:21:38 2016 -0500
Committer: Li Xuanji <xu...@gmail.com>
Committed: Wed Nov 16 00:21:43 2016 -0500

----------------------------------------------------------------------
 airflow/hooks/hive_hooks.py      |  8 ++--
 tests/operators/hive_operator.py | 86 +++++++++++++++++++++++++++++++++--
 2 files changed, 86 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/448e06f6/airflow/hooks/hive_hooks.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index a9fac48..0bada7a 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -573,7 +573,7 @@ class HiveServer2Hook(BaseHook):
     def __init__(self, hiveserver2_conn_id='hiveserver2_default'):
         self.hiveserver2_conn_id = hiveserver2_conn_id
 
-    def get_conn(self):
+    def get_conn(self, schema=None):
         db = self.get_connection(self.hiveserver2_conn_id)
         auth_mechanism = db.extra_dejson.get('authMechanism', 'PLAIN')
         kerberos_service_name = None
@@ -594,11 +594,11 @@ class HiveServer2Hook(BaseHook):
             auth_mechanism=auth_mechanism,
             kerberos_service_name=kerberos_service_name,
             user=db.login,
-            database=db.schema or 'default')
+            database=schema or db.schema or 'default')
 
     def get_results(self, hql, schema='default', arraysize=1000):
         from impala.error import ProgrammingError
-        with self.get_conn() as conn:
+        with self.get_conn(schema) as conn:
             if isinstance(hql, basestring):
                 hql = [hql]
             results = {
@@ -633,7 +633,7 @@ class HiveServer2Hook(BaseHook):
             output_header=True,
             fetch_size=1000):
         schema = schema or 'default'
-        with self.get_conn() as conn:
+        with self.get_conn(schema) as conn:
             with conn.cursor() as cur:
                 logging.info("Running query: " + hql)
                 cur.execute(hql)
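
The database resolution order introduced above is: an explicit schema
argument wins, then the schema stored on the Airflow connection, then
'default'. Since `schema or db.schema or 'default'` is plain short-circuit
`or`, the fallback can be illustrated directly (hypothetical values):

    >>> 'reporting' or 'analytics' or 'default'   # explicit argument wins
    'reporting'
    >>> None or 'analytics' or 'default'          # falls back to db.schema
    'analytics'
    >>> None or None or 'default'                 # neither set
    'default'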

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/448e06f6/tests/operators/hive_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/hive_operator.py b/tests/operators/hive_operator.py
index fd5e096..9f90999 100644
--- a/tests/operators/hive_operator.py
+++ b/tests/operators/hive_operator.py
@@ -17,15 +17,13 @@ from __future__ import print_function
 import datetime
 import os
 import unittest
+import mock
+import nose
 import six
 
 from airflow import DAG, configuration, operators, utils
 configuration.load_test_config()
 
-import os
-import unittest
-import nose
-
 
 DEFAULT_DATE = datetime.datetime(2015, 1, 1)
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
@@ -40,6 +38,7 @@ if 'AIRFLOW_RUNALL_TESTS' in os.environ:
     class HiveServer2Test(unittest.TestCase):
         def setUp(self):
             configuration.load_test_config()
+            self.nondefault_schema = "nondefault"
 
         def test_select_conn(self):
             from airflow.hooks.hive_hooks import HiveServer2Hook
@@ -68,6 +67,85 @@ if 'AIRFLOW_RUNALL_TESTS' in os.environ:
             hook = HiveServer2Hook()
             hook.to_csv(hql=sql, csv_filepath="/tmp/test_to_csv")
 
+        def connect_mock(self, host, port, auth_mechanism,
+                         kerberos_service_name, user, database):
+            # Stand-in for impala.dbapi.connect that checks the schema
+            # was plumbed through to the underlying connection.
+            self.assertEqual(database, self.nondefault_schema)
+
+        # get_conn() imports connect from impala.dbapi, so patch it there.
+        @mock.patch('impala.dbapi.connect', return_value="foo")
+        def test_select_conn_with_schema(self, connect_mock):
+            from airflow.hooks.hive_hooks import HiveServer2Hook
+
+            # Configure
+            hook = HiveServer2Hook()
+
+            # Run
+            hook.get_conn(self.nondefault_schema)
+
+            # Verify
+            assert connect_mock.called
+            (args, kwargs) = connect_mock.call_args_list[0]
+            assert kwargs['database'] == self.nondefault_schema
+
+        def test_get_results_with_schema(self):
+            from airflow.hooks.hive_hooks import HiveServer2Hook
+
+            # Configure
+            sql = "select 1"
+            hook = HiveServer2Hook()
+            # get_results() uses the connection and its cursor as context
+            # managers, so both mocks return themselves from __enter__.
+            cursor_mock = mock.MagicMock()
+            cursor_mock.__enter__.return_value = cursor_mock
+            cursor_mock.fetchall.return_value = []
+            conn_mock = mock.MagicMock()
+            conn_mock.__enter__.return_value = conn_mock
+            conn_mock.cursor.return_value = cursor_mock
+            get_conn_mock = mock.MagicMock(return_value=conn_mock)
+            hook.get_conn = get_conn_mock
+
+            # Run
+            hook.get_results(sql, schema=self.nondefault_schema)
+
+            # Verify
+            get_conn_mock.assert_called_with(self.nondefault_schema)
+
+        @mock.patch('airflow.hooks.hive_hooks.HiveServer2Hook.get_results',
+                    return_value={'data': [], 'header': []})
+        def test_get_records_with_schema(self, get_results_mock):
+            from airflow.hooks.hive_hooks import HiveServer2Hook
+
+            # Configure
+            sql = "select 1"
+            hook = HiveServer2Hook()
+
+            # Run
+            hook.get_records(sql, self.nondefault_schema)
+
+            # Verify
+            assert get_results_mock.called
+            (args, kwargs) = get_results_mock.call_args_list[0]
+            assert args[0] == sql
+            assert kwargs['schema'] == self.nondefault_schema
+
+        @mock.patch('airflow.hooks.hive_hooks.HiveServer2Hook.get_results',
+                    return_value={'data': [], 'header': []})
+        def test_get_pandas_df_with_schema(self, get_results_mock):
+            from airflow.hooks.hive_hooks import HiveServer2Hook
+
+            # Configure
+            sql = "select 1"
+            hook = HiveServer2Hook()
+
+            # Run
+            hook.get_pandas_df(sql, schema=self.nondefault_schema)
+
+            # Verify
+            assert get_results_mock.called
+            (args, kwargs) = get_results_mock.call_args_list[0]
+            assert args[0] == sql
+            assert kwargs['schema'] == self.nondefault_schema
+
     class HivePrestoTest(unittest.TestCase):
 
         def setUp(self):