You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by xu...@apache.org on 2016/11/16 05:23:44 UTC
incubator-airflow git commit: [AIRFLOW-343] Fix schema plumbing in HiveServer2Hook
Repository: incubator-airflow
Updated Branches:
refs/heads/master 664e63a72 -> 448e06f69
[AIRFLOW-343] Fix schema plumbing in HiveServer2Hook
Allow HiveServer2Hook to be used with databases other than the default.
Testing Done:
- Added new unit test coverage to the HiveOperator test module
(tests/operators/hive_operator.py) to cover these changes.
(This is the established testing norm for HiveHook.)
Closes #1743 from gr8routdoors/airflow_343
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/448e06f6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/448e06f6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/448e06f6
Branch: refs/heads/master
Commit: 448e06f697589de5bc19f3af76f736ae043a6e6b
Parents: 664e63a
Author: Devon Berry <de...@livingsocial.com>
Authored: Wed Nov 16 00:21:38 2016 -0500
Committer: Li Xuanji <xu...@gmail.com>
Committed: Wed Nov 16 00:21:43 2016 -0500
----------------------------------------------------------------------
airflow/hooks/hive_hooks.py | 8 ++--
tests/operators/hive_operator.py | 86 +++++++++++++++++++++++++++++++++--
2 files changed, 86 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/448e06f6/airflow/hooks/hive_hooks.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index a9fac48..0bada7a 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -573,7 +573,7 @@ class HiveServer2Hook(BaseHook):
def __init__(self, hiveserver2_conn_id='hiveserver2_default'):
self.hiveserver2_conn_id = hiveserver2_conn_id
- def get_conn(self):
+ def get_conn(self, schema=None):
db = self.get_connection(self.hiveserver2_conn_id)
auth_mechanism = db.extra_dejson.get('authMechanism', 'PLAIN')
kerberos_service_name = None
@@ -594,11 +594,11 @@ class HiveServer2Hook(BaseHook):
auth_mechanism=auth_mechanism,
kerberos_service_name=kerberos_service_name,
user=db.login,
- database=db.schema or 'default')
+ database=schema or db.schema or 'default')
def get_results(self, hql, schema='default', arraysize=1000):
from impala.error import ProgrammingError
- with self.get_conn() as conn:
+ with self.get_conn(schema) as conn:
if isinstance(hql, basestring):
hql = [hql]
results = {
@@ -633,7 +633,7 @@ class HiveServer2Hook(BaseHook):
output_header=True,
fetch_size=1000):
schema = schema or 'default'
- with self.get_conn() as conn:
+ with self.get_conn(schema) as conn:
with conn.cursor() as cur:
logging.info("Running query: " + hql)
cur.execute(hql)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/448e06f6/tests/operators/hive_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/hive_operator.py b/tests/operators/hive_operator.py
index fd5e096..9f90999 100644
--- a/tests/operators/hive_operator.py
+++ b/tests/operators/hive_operator.py
@@ -17,15 +17,13 @@ from __future__ import print_function
import datetime
import os
import unittest
+import mock
+import nose
import six
from airflow import DAG, configuration, operators, utils
configuration.load_test_config()
-import os
-import unittest
-import nose
-
DEFAULT_DATE = datetime.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
@@ -40,6 +38,7 @@ if 'AIRFLOW_RUNALL_TESTS' in os.environ:
class HiveServer2Test(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
+ self.nondefault_schema = "nondefault"
def test_select_conn(self):
from airflow.hooks.hive_hooks import HiveServer2Hook
@@ -68,6 +67,85 @@ if 'AIRFLOW_RUNALL_TESTS' in os.environ:
hook = HiveServer2Hook()
hook.to_csv(hql=sql, csv_filepath="/tmp/test_to_csv")
+ def connect_mock(host, port, auth_mechanism, kerberos_service_name, user, database):
+ self.assertEqual(database, self.nondefault_schema)
+
+ @patch('HiveServer2Hook.connect', return_value="foo")
+ def test_select_conn_with_schema(self, connect_mock):
+ from airflow.hooks.hive_hooks import HiveServer2Hook
+
+ # Configure
+ hook = HiveServer2Hook()
+
+ # Run
+ hook.get_conn(self.nondefault_schema)
+
+ # Verify
+ assert connect_mock.called
+ (args, kwargs) = connect_mock.call_args_list[0]
+ assert kwargs['database'] == self.nondefault_schema
+
+ def test_get_results_with_schema(self):
+ from airflow.hooks.hive_hooks import HiveServer2Hook
+ from unittest.mock import MagicMock
+
+ # Configure
+ sql = "select 1"
+ schema = "notdefault"
+ hook = HiveServer2Hook()
+ cursor_mock = MagicMock(
+ __enter__ = cursor_mock,
+ __exit__ = None,
+ execute = None,
+ fetchall = [],
+ )
+ get_conn_mock = MagicMock(
+ __enter__ = get_conn_mock,
+ __exit__ = None,
+ cursor = cursor_mock,
+ )
+ hook.get_conn = get_conn_mock
+
+ # Run
+ hook.get_results(sql, schema)
+
+ # Verify
+ get_conn_mock.assert_called_with(self.nondefault_schema)
+
+ @patch('HiveServer2Hook.get_results', return_value={data:[]})
+ def test_get_records_with_schema(self, get_results_mock):
+ from airflow.hooks.hive_hooks import HiveServer2Hook
+
+ # Configure
+ sql = "select 1"
+ hook = HiveServer2Hook()
+
+ # Run
+ hook.get_records(sql, self.nondefault_schema)
+
+ # Verify
+ assert connect_mock.called
+ (args, kwargs) = connect_mock.call_args_list[0]
+ assert args[0] == sql
+ assert kwargs['schema'] == self.nondefault_schema
+
+ @patch('HiveServer2Hook.get_results', return_value={data:[]})
+ def test_get_pandas_df_with_schema(self, get_results_mock):
+ from airflow.hooks.hive_hooks import HiveServer2Hook
+
+ # Configure
+ sql = "select 1"
+ hook = HiveServer2Hook()
+
+ # Run
+ hook.get_pandas_df(sql, schema)
+
+ # Verify
+ assert connect_mock.called
+ (args, kwargs) = connect_mock.call_args_list[0]
+ assert args[0] == sql
+ assert kwargs['schema'] == self.nondefault_schema
+
class HivePrestoTest(unittest.TestCase):
def setUp(self):