You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2019/01/09 20:36:41 UTC

[GitHub] kaxil closed pull request #4451: [AIRFLOW-3630] Cleanup of generated connection is done also on exception

kaxil closed pull request #4451: [AIRFLOW-3630] Cleanup of generated connection is done also on exception
URL: https://github.com/apache/airflow/pull/4451
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not show its diff
after the merge, so the diff is reproduced below:

diff --git a/airflow/contrib/hooks/gcp_sql_hook.py b/airflow/contrib/hooks/gcp_sql_hook.py
index e60087968d..af9ad57003 100644
--- a/airflow/contrib/hooks/gcp_sql_hook.py
+++ b/airflow/contrib/hooks/gcp_sql_hook.py
@@ -748,6 +748,8 @@ def _validate_inputs(self):
             raise AirflowException("Cloud SQL Proxy does not support SSL connections."
                                    " SSL is not needed as Cloud SQL Proxy "
                                    "provides encryption on its own")
+
+    def validate_ssl_certs(self):
         if self.use_ssl:
             self._check_ssl_file(self.sslcert, "sslcert")
             self._check_ssl_file(self.sslkey, "sslkey")
@@ -913,8 +915,9 @@ def cleanup_database_hook(self):
         Clean up database hook after it was used.
         """
         if self.database_type == 'postgres':
-            for output in self.db_hook.conn.notices:
-                self.log.info(output)
+            if self.db_hook.conn and self.db_hook.conn.notices:
+                for output in self.db_hook.conn.notices:
+                    self.log.info(output)
 
     def reserve_free_tcp_port(self):
         """
diff --git a/airflow/contrib/operators/gcp_sql_operator.py b/airflow/contrib/operators/gcp_sql_operator.py
index abdefb5190..e5c6cb8dbc 100644
--- a/airflow/contrib/operators/gcp_sql_operator.py
+++ b/airflow/contrib/operators/gcp_sql_operator.py
@@ -704,30 +704,30 @@ def __init__(self,
         self.cloud_sql_proxy_runner = None
         self.database_hook = None
 
-    def pre_execute(self, context):
-        self.cloudsql_db_hook.create_connection()
-        self.database_hook = self.cloudsql_db_hook.get_database_hook()
-        if self.cloudsql_db_hook.use_proxy:
-            self.cloud_sql_proxy_runner = self.cloudsql_db_hook.get_sqlproxy_runner()
-            self.cloudsql_db_hook.free_reserved_port()
-            # There is very, very slim chance that the socket will be taken over
-            # here by another bind(0). It's quite unlikely to happen though!
-            self.cloud_sql_proxy_runner.start_proxy()
-
     def execute(self, context):
-        self.log.info('Executing: "%s"', self.sql)
-        self.database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
-
-    def post_execute(self, context, result=None):
-        # Make sure that all the cleanups happen, no matter if there are some
-        # exceptions thrown
+        self.cloudsql_db_hook.validate_ssl_certs()
+        self.cloudsql_db_hook.create_connection()
         try:
-            self.cloudsql_db_hook.cleanup_database_hook()
-        finally:
+            self.database_hook = self.cloudsql_db_hook.get_database_hook()
             try:
-                if self.cloud_sql_proxy_runner:
-                    self.cloud_sql_proxy_runner.stop_proxy()
-                    self.cloud_sql_proxy_runner = None
+                try:
+                    if self.cloudsql_db_hook.use_proxy:
+                        self.cloud_sql_proxy_runner = self.cloudsql_db_hook.\
+                            get_sqlproxy_runner()
+                        self.cloudsql_db_hook.free_reserved_port()
+                        # There is very, very slim chance that the socket will
+                        # be taken over here by another bind(0).
+                        # It's quite unlikely to happen though!
+                        self.cloud_sql_proxy_runner.start_proxy()
+                    self.log.info('Executing: "%s"', self.sql)
+                    self.database_hook.run(self.sql, self.autocommit,
+                                           parameters=self.parameters)
+                finally:
+                    if self.cloud_sql_proxy_runner:
+                        self.cloud_sql_proxy_runner.stop_proxy()
+                        self.cloud_sql_proxy_runner = None
             finally:
-                self.cloudsql_db_hook.delete_connection()
-                self.cloudsql_db_hook = None
+                self.cloudsql_db_hook.cleanup_database_hook()
+        finally:
+            self.cloudsql_db_hook.delete_connection()
+            self.cloudsql_db_hook = None
diff --git a/tests/contrib/operators/test_gcp_sql_operator.py b/tests/contrib/operators/test_gcp_sql_operator.py
index ac0251aa74..646b2e771f 100644
--- a/tests/contrib/operators/test_gcp_sql_operator.py
+++ b/tests/contrib/operators/test_gcp_sql_operator.py
@@ -592,10 +592,11 @@ def test_create_operator_with_wrong_parameters(self,
                    use_ssl=use_ssl))
         get_connections.return_value = [connection]
         with self.assertRaises(AirflowException) as cm:
-            CloudSqlQueryOperator(
+            op = CloudSqlQueryOperator(
                 sql=sql,
                 task_id='task_id'
             )
+            op.execute(None)
         err = cm.exception
         self.assertIn(message, str(err))
 
@@ -809,11 +810,11 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connections):
             "use_proxy=True&sql_proxy_use_tcp=False")
         get_connections.return_value = [connection]
         with self.assertRaises(AirflowException) as cm:
-            operator = CloudSqlQueryOperator(
+            op = CloudSqlQueryOperator(
                 sql=['SELECT * FROM TABLE'],
                 task_id='task_id'
             )
-            operator.cloudsql_db_hook.create_connection()
+            op.execute(None)
         err = cm.exception
         self.assertIn("The UNIX socket path length cannot exceed", str(err))
 
@@ -840,6 +841,39 @@ def test_create_operator_with_not_too_long_unix_socket_path(self, get_connection
         self.assertEqual('postgres', conn.conn_type)
         self.assertEqual('testdb', conn.schema)
 
+    @mock.patch("airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook."
+                "delete_connection")
+    @mock.patch("airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook."
+                "get_connection")
+    @mock.patch("airflow.hooks.mysql_hook.MySqlHook.run")
+    @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
+    def test_cloudsql_hook_delete_connection_on_exception(
+            self, get_connections, run, get_connection, delete_connection):
+        connection = Connection()
+        connection.parse_from_uri(
+            "gcpcloudsql://user:password@8.8.8.8:3200/testdb?database_type=mysql&"
+            "project_id=example-project&location=europe-west1&instance=testdb&"
+            "use_proxy=False")
+        get_connection.return_value = connection
+
+        db_connection = Connection()
+        db_connection.host = "8.8.8.8"
+        db_connection.set_extra(json.dumps({"project_id": "example-project",
+                                            "location": "europe-west1",
+                                            "instance": "testdb",
+                                            "database_type": "mysql"}))
+        get_connections.return_value = [db_connection]
+        run.side_effect = Exception("Exception when running a query")
+        operator = CloudSqlQueryOperator(
+            sql=['SELECT * FROM TABLE'],
+            task_id='task_id'
+        )
+        with self.assertRaises(Exception) as cm:
+            operator.execute(None)
+        err = cm.exception
+        self.assertEqual("Exception when running a query", str(err))
+        delete_connection.assert_called_once_with()
+
 
 @unittest.skipIf(
     BaseGcpIntegrationTestCase.skip_check(GCP_CLOUDSQL_KEY), SKIP_TEST_WARNING)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services