You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/03/29 21:20:23 UTC
incubator-airflow git commit: [AIRFLOW-858] Configurable database name for DB operators
Repository: incubator-airflow
Updated Branches:
refs/heads/v1-8-test eb12f0164 -> 5eb33358f
[AIRFLOW-858] Configurable database name for DB operators
Closes #2063 from s7anley/configurable-schema
(cherry picked from commit 94dc7fb0a6bb3c563d9df6566cd52a59bd0c4629)
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/5eb33358
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/5eb33358
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/5eb33358
Branch: refs/heads/v1-8-test
Commit: 5eb33358f62a13192e537296becc315476112afb
Parents: eb12f01
Author: Ján Koščo <3k...@gmail.com>
Authored: Sun Feb 12 15:43:41 2017 -0500
Committer: Chris Riccomini <cr...@apache.org>
Committed: Wed Mar 29 14:19:19 2017 -0700
----------------------------------------------------------------------
airflow/hooks/mssql_hook.py | 10 +++++--
airflow/hooks/mysql_hook.py | 15 ++++++----
airflow/hooks/postgres_hook.py | 4 +--
airflow/operators/mssql_operator.py | 11 ++++++--
airflow/operators/mysql_operator.py | 8 ++++--
airflow/operators/postgres_operator.py | 7 ++++-
tests/operators/operators.py | 43 +++++++++++++++++++++++++++++
7 files changed, 81 insertions(+), 17 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/hooks/mssql_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/mssql_hook.py b/airflow/hooks/mssql_hook.py
index 1450967..99a4c82 100644
--- a/airflow/hooks/mssql_hook.py
+++ b/airflow/hooks/mssql_hook.py
@@ -18,14 +18,18 @@ from airflow.hooks.dbapi_hook import DbApiHook
class MsSqlHook(DbApiHook):
- '''
+ """
Interact with Microsoft SQL Server.
- '''
+ """
conn_name_attr = 'mssql_conn_id'
default_conn_name = 'mssql_default'
supports_autocommit = True
+ def __init__(self, *args, **kwargs):
+ super(MsSqlHook, self).__init__(*args, **kwargs)
+ self.schema = kwargs.pop("schema", None)
+
def get_conn(self):
"""
Returns a mssql connection object
@@ -35,7 +39,7 @@ class MsSqlHook(DbApiHook):
server=conn.host,
user=conn.login,
password=conn.password,
- database=conn.schema,
+ database=self.schema or conn.schema,
port=conn.port)
return conn
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/hooks/mysql_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py
index e4f9533..bf1a721 100644
--- a/airflow/hooks/mysql_hook.py
+++ b/airflow/hooks/mysql_hook.py
@@ -19,18 +19,22 @@ from airflow.hooks.dbapi_hook import DbApiHook
class MySqlHook(DbApiHook):
- '''
+ """
Interact with MySQL.
You can specify charset in the extra field of your connection
as ``{"charset": "utf8"}``. Also you can choose cursor as
``{"cursor": "SSCursor"}``. Refer to the MySQLdb.cursors for more details.
- '''
+ """
conn_name_attr = 'mysql_conn_id'
default_conn_name = 'mysql_default'
supports_autocommit = True
+ def __init__(self, *args, **kwargs):
+ super(MySqlHook, self).__init__(*args, **kwargs)
+ self.schema = kwargs.pop("schema", None)
+
def get_conn(self):
"""
Returns a mysql connection object
@@ -38,17 +42,16 @@ class MySqlHook(DbApiHook):
conn = self.get_connection(self.mysql_conn_id)
conn_config = {
"user": conn.login,
- "passwd": conn.password or ''
+ "passwd": conn.password or '',
+ "host": conn.host or 'localhost',
+ "db": self.schema or conn.schema or ''
}
- conn_config["host"] = conn.host or 'localhost'
if not conn.port:
conn_config["port"] = 3306
else:
conn_config["port"] = int(conn.port)
- conn_config["db"] = conn.schema or ''
-
if conn.extra_dejson.get('charset', False):
conn_config["charset"] = conn.extra_dejson["charset"]
if (conn_config["charset"]).lower() == 'utf8' or\
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/hooks/postgres_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py
index 584930d..4b460c1 100644
--- a/airflow/hooks/postgres_hook.py
+++ b/airflow/hooks/postgres_hook.py
@@ -19,11 +19,11 @@ from airflow.hooks.dbapi_hook import DbApiHook
class PostgresHook(DbApiHook):
- '''
+ """
Interact with Postgres.
You can specify ssl parameters in the extra field of your connection
as ``{"sslmode": "require", "sslcert": "/path/to/cert.pem", etc}``.
- '''
+ """
conn_name_attr = 'postgres_conn_id'
default_conn_name = 'postgres_default'
supports_autocommit = True
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/operators/mssql_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/mssql_operator.py b/airflow/operators/mssql_operator.py
index 0590454..0f0cd63 100644
--- a/airflow/operators/mssql_operator.py
+++ b/airflow/operators/mssql_operator.py
@@ -27,6 +27,8 @@ class MsSqlOperator(BaseOperator):
:param sql: the sql code to be executed
:type sql: string or string pointing to a template file.
File must have a '.sql' extensions.
+ :param database: name of database which overwrite defined one in connection
+ :type database: string
"""
template_fields = ('sql',)
@@ -36,14 +38,17 @@ class MsSqlOperator(BaseOperator):
@apply_defaults
def __init__(
self, sql, mssql_conn_id='mssql_default', parameters=None,
- autocommit=False, *args, **kwargs):
+ autocommit=False, database=None, *args, **kwargs):
super(MsSqlOperator, self).__init__(*args, **kwargs)
self.mssql_conn_id = mssql_conn_id
self.sql = sql
self.parameters = parameters
self.autocommit = autocommit
+ self.database = database
def execute(self, context):
logging.info('Executing: ' + str(self.sql))
- hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
- hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+ hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id,
+ schema=self.database)
+ hook.run(self.sql, autocommit=self.autocommit,
+ parameters=self.parameters)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/operators/mysql_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/mysql_operator.py b/airflow/operators/mysql_operator.py
index b3a3c73..156ada8 100644
--- a/airflow/operators/mysql_operator.py
+++ b/airflow/operators/mysql_operator.py
@@ -29,6 +29,8 @@ class MySqlOperator(BaseOperator):
:type sql: Can receive a str representing a sql statement,
a list of str (sql statements), or reference to a template file.
Template reference are recognized by str ending in '.sql'
+ :param database: name of database which overwrite defined one in connection
+ :type database: string
"""
template_fields = ('sql',)
@@ -38,16 +40,18 @@ class MySqlOperator(BaseOperator):
@apply_defaults
def __init__(
self, sql, mysql_conn_id='mysql_default', parameters=None,
- autocommit=False, *args, **kwargs):
+ autocommit=False, database=None, *args, **kwargs):
super(MySqlOperator, self).__init__(*args, **kwargs)
self.mysql_conn_id = mysql_conn_id
self.sql = sql
self.autocommit = autocommit
self.parameters = parameters
+ self.database = database
def execute(self, context):
logging.info('Executing: ' + str(self.sql))
- hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
+ hook = MySqlHook(mysql_conn_id=self.mysql_conn_id,
+ schema=self.database)
hook.run(
self.sql,
autocommit=self.autocommit,
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/airflow/operators/postgres_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/postgres_operator.py b/airflow/operators/postgres_operator.py
index c4f56a4..0de5aa5 100644
--- a/airflow/operators/postgres_operator.py
+++ b/airflow/operators/postgres_operator.py
@@ -29,6 +29,8 @@ class PostgresOperator(BaseOperator):
:type sql: Can receive a str representing a sql statement,
a list of str (sql statements), or reference to a template file.
Template reference are recognized by str ending in '.sql'
+ :param database: name of database which overwrite defined one in connection
+ :type database: string
"""
template_fields = ('sql',)
@@ -40,14 +42,17 @@ class PostgresOperator(BaseOperator):
self, sql,
postgres_conn_id='postgres_default', autocommit=False,
parameters=None,
+ database=None,
*args, **kwargs):
super(PostgresOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.postgres_conn_id = postgres_conn_id
self.autocommit = autocommit
self.parameters = parameters
+ self.database = database
def execute(self, context):
logging.info('Executing: ' + str(self.sql))
- self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
+ self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id,
+ schema=self.database)
self.hook.run(self.sql, self.autocommit, parameters=self.parameters)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/5eb33358/tests/operators/operators.py
----------------------------------------------------------------------
diff --git a/tests/operators/operators.py b/tests/operators/operators.py
index 7aaf12e..19901ae 100644
--- a/tests/operators/operators.py
+++ b/tests/operators/operators.py
@@ -114,6 +114,27 @@ class MySqlTest(unittest.TestCase):
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ def test_overwrite_schema(self):
+ """
+ Verifies option to overwrite connection schema
+ """
+ import airflow.operators.mysql_operator
+
+ sql = "SELECT 1;"
+ t = operators.mysql_operator.MySqlOperator(
+ task_id='test_mysql_operator_test_schema_overwrite',
+ sql=sql,
+ dag=self.dag,
+ database="foobar",
+ )
+
+ from _mysql_exceptions import OperationalError
+ try:
+ t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+ ignore_ti_state=True)
+ except OperationalError as e:
+ assert "Unknown database 'foobar'" in str(e)
+
@skipUnlessImported('airflow.operators.postgres_operator', 'PostgresOperator')
class PostgresTest(unittest.TestCase):
@@ -193,6 +214,28 @@ class PostgresTest(unittest.TestCase):
autocommit=True)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ def test_overwrite_schema(self):
+ """
+ Verifies option to overwrite connection schema
+ """
+ import airflow.operators.postgres_operator
+
+ sql = "SELECT 1;"
+ t = operators.postgres_operator.PostgresOperator(
+ task_id='postgres_operator_test_schema_overwrite',
+ sql=sql,
+ dag=self.dag,
+ autocommit=True,
+ database="foobar",
+ )
+
+ from psycopg2._psycopg import OperationalError
+ try:
+ t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+ ignore_ti_state=True)
+ except OperationalError as e:
+ assert 'database "foobar" does not exist' in str(e)
+
@skipUnlessImported('airflow.operators.hive_operator', 'HiveOperator')
@skipUnlessImported('airflow.operators.postgres_operator', 'PostgresOperator')