Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2019/01/04 13:54:35 UTC

[GitHub] kaxil closed pull request #4350: [AIRFLOW-3527] Cloud SQL Proxy has shorter path for UNIX socket

URL: https://github.com/apache/incubator-airflow/pull/4350
 
 
   

This is a PR merged from a forked repository. As GitHub hides the original
diff once such a pull request is merged, it is reproduced below for the sake
of provenance:

diff --git a/airflow/contrib/hooks/gcp_sql_hook.py b/airflow/contrib/hooks/gcp_sql_hook.py
index 9872746b7b..e60087968d 100644
--- a/airflow/contrib/hooks/gcp_sql_hook.py
+++ b/airflow/contrib/hooks/gcp_sql_hook.py
@@ -19,8 +19,11 @@
 import errno
 import json
 import os
+import random
 import re
 import shutil
+import string
+
 import socket
 import platform
 import subprocess
@@ -45,6 +48,8 @@
 from airflow.models.connection import Connection
 from airflow.utils.db import provide_session
 
+UNIX_PATH_MAX = 108
+
 NUM_RETRIES = 5
 
 # Time to sleep between active checks of the operation results
@@ -437,8 +442,8 @@ def _download_sql_proxy_if_needed(self):
             download_url = CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL.format(
                 self.sql_proxy_version, system, processor)
         proxy_path_tmp = self.sql_proxy_path + ".tmp"
-        self.log.info("Downloading cloud_sql_proxy from {} to {}".
-                      format(download_url, proxy_path_tmp))
+        self.log.info("Downloading cloud_sql_proxy from %s to %s",
+                      download_url, proxy_path_tmp)
         r = requests.get(download_url, allow_redirects=True)
         # Downloading to a .tmp file first to avoid the case where a partially
         # downloaded binary is used by a parallel operator using the same fixed binary path
@@ -448,9 +453,8 @@ def _download_sql_proxy_if_needed(self):
             raise AirflowException(
                 "The cloud-sql-proxy could not be downloaded. Status code = {}. "
                 "Reason = {}".format(r.status_code, r.reason))
-        self.log.info("Moving sql_proxy binary from {} to {}".format(
-            proxy_path_tmp, self.sql_proxy_path
-        ))
+        self.log.info("Moving sql_proxy binary from %s to %s",
+                      proxy_path_tmp, self.sql_proxy_path)
         shutil.move(proxy_path_tmp, self.sql_proxy_path)
         os.chmod(self.sql_proxy_path, 0o744)  # Set executable bit
         self.sql_proxy_was_downloaded = True
@@ -468,7 +472,7 @@ def _get_credential_parameters(self, session):
         elif GCP_CREDENTIALS_KEYFILE_DICT in connection.extra_dejson:
             credential_file_content = json.loads(
                 connection.extra_dejson[GCP_CREDENTIALS_KEYFILE_DICT])
-            self.log.info("Saving credentials to {}".format(self.credentials_path))
+            self.log.info("Saving credentials to %s", self.credentials_path)
             with open(self.credentials_path, "w") as f:
                 json.dump(credential_file_content, f)
             credential_params = [
@@ -478,8 +482,8 @@ def _get_credential_parameters(self, session):
         else:
             self.log.info(
                 "The credentials are not supplied by neither key_path nor "
-                "keyfile_dict of the gcp connection {}. Falling back to "
-                "default activated account".format(self.gcp_conn_id))
+                "keyfile_dict of the gcp connection %s. Falling back to "
+                "default activated account", self.gcp_conn_id)
             credential_params = []
 
         if not self.instance_specification:
@@ -509,18 +513,17 @@ def start_proxy(self):
             command_to_run = [self.sql_proxy_path]
             command_to_run.extend(self.command_line_parameters)
             try:
-                self.log.info("Creating directory {}".format(
-                    self.cloud_sql_proxy_socket_directory))
+                self.log.info("Creating directory %s",
+                              self.cloud_sql_proxy_socket_directory)
                 os.makedirs(self.cloud_sql_proxy_socket_directory)
             except OSError:
                 # Needed for Python 2 compatibility (os.makedirs has no exist_ok)
                 pass
             command_to_run.extend(self._get_credential_parameters())
-            self.log.info("Running the command: `{}`".format(" ".join(command_to_run)))
+            self.log.info("Running the command: `%s`", " ".join(command_to_run))
             self.sql_proxy_process = Popen(command_to_run,
                                            stdin=PIPE, stdout=PIPE, stderr=PIPE)
-            self.log.info("The pid of cloud_sql_proxy: {}".format(
-                self.sql_proxy_process.pid))
+            self.log.info("The pid of cloud_sql_proxy: %s", self.sql_proxy_process.pid)
             while True:
                 line = self.sql_proxy_process.stderr.readline().decode('utf-8')
                 return_code = self.sql_proxy_process.poll()
@@ -548,16 +551,16 @@ def stop_proxy(self):
         if not self.sql_proxy_process:
             raise AirflowException("The sql proxy is not started yet")
         else:
-            self.log.info("Stopping the cloud_sql_proxy pid: {}".format(
-                self.sql_proxy_process.pid))
+            self.log.info("Stopping the cloud_sql_proxy pid: %s",
+                          self.sql_proxy_process.pid)
             self.sql_proxy_process.kill()
             self.sql_proxy_process = None
         # Cleanup!
-        self.log.info("Removing the socket directory: {}".
-                      format(self.cloud_sql_proxy_socket_directory))
+        self.log.info("Removing the socket directory: %s",
+                      self.cloud_sql_proxy_socket_directory)
         shutil.rmtree(self.cloud_sql_proxy_socket_directory, ignore_errors=True)
         if self.sql_proxy_was_downloaded:
-            self.log.info("Removing downloaded proxy: {}".format(self.sql_proxy_path))
+            self.log.info("Removing downloaded proxy: %s", self.sql_proxy_path)
             # Silently ignore if the file has already been removed (concurrency)
             try:
                 os.remove(self.sql_proxy_path)
@@ -565,11 +568,11 @@ def stop_proxy(self):
                 if not e.errno == errno.ENOENT:
                     raise
         else:
-            self.log.info("Skipped removing proxy - it was not downloaded: {}".
-                          format(self.sql_proxy_path))
+            self.log.info("Skipped removing proxy - it was not downloaded: %s",
+                          self.sql_proxy_path)
         if isfile(self.credentials_path):
-            self.log.info("Removing generated credentials file {}".
-                          format(self.credentials_path))
+            self.log.info("Removing generated credentials file %s",
+                          self.credentials_path)
             # Here the file cannot be deleted by a concurrent task (each task has its own copy)
             os.remove(self.credentials_path)
 
@@ -749,18 +752,38 @@ def _validate_inputs(self):
             self._check_ssl_file(self.sslcert, "sslcert")
             self._check_ssl_file(self.sslkey, "sslkey")
             self._check_ssl_file(self.sslrootcert, "sslrootcert")
+        if self.use_proxy and not self.sql_proxy_use_tcp:
+            if self.database_type == 'postgres':
+                suffix = "/.s.PGSQL.5432"
+            else:
+                suffix = ""
+            expected_path = "{}/{}:{}:{}{}".format(
+                self._generate_unique_path(),
+                self.project_id, self.instance,
+                self.database, suffix)
+            if len(expected_path) > UNIX_PATH_MAX:
+                self.log.info("The socket path is too long (%s characters): %s",
+                              len(expected_path), expected_path)
+                raise AirflowException(
+                    "The UNIX socket path length cannot exceed {} characters "
+                    "on Linux systems. Either use a shorter instance/database "
+                    "name or switch to a TCP connection. "
+                    "The socket path for the Cloud SQL proxy is now: "
+                    "{}".format(UNIX_PATH_MAX, expected_path))
 
-    def _generate_unique_path(self):
+    @staticmethod
+    def _generate_unique_path():
         # We are not using mkdtemp here as the path generated with mkdtemp
         # can be close to 60 characters, and the total length of a socket
         # path is limited to around 100 characters.
         # We append project/location/instance to it later and postgres
-        # appends its own prefix, so we chose a shorter "/tmp/{uuid1}" - based
-        # on host name and clock + clock sequence. This should be fairly
-        # sufficient for our needs and should even work if the time is set back.
-        # We are using db_conn_id generated with uuid1 so that connection
-        # id matches the folder - for easier debugging.
-        return "/tmp/" + self.db_conn_id
+        # appends its own prefix, so we chose a shorter "/tmp/[8 random characters]" path.
+        random.seed()
+        while True:
+            candidate = "/tmp/" + ''.join(
+                random.choice(string.ascii_lowercase + string.digits) for _ in range(8))
+            if not os.path.exists(candidate):
+                return candidate
 
     @staticmethod
     def _quote(value):
@@ -813,8 +836,8 @@ def _generate_connection_uri(self):
             client_key_file=self._quote(self.sslkey),
             server_ca_file=self._quote(self.sslrootcert)
         )
-        self.log.info("DB connection URI {}".format(connection_uri.replace(
-            quote_plus(self.password), 'XXXXXXXXXXXX')))
+        self.log.info("DB connection URI %s", connection_uri.replace(
+            quote_plus(self.password), 'XXXXXXXXXXXX'))
         return connection_uri
 
     def _get_instance_socket_name(self):
@@ -837,7 +860,7 @@ def create_connection(self, session=None):
         """
         connection = Connection(conn_id=self.db_conn_id)
         uri = self._generate_connection_uri()
-        self.log.info("Creating connection {}".format(self.db_conn_id))
+        self.log.info("Creating connection %s", self.db_conn_id)
         connection.parse_from_uri(uri)
         session.add(connection)
         session.commit()
@@ -850,7 +873,7 @@ def delete_connection(self, session=None):
         :param session: Session of the SQL Alchemy ORM (automatically generated with
                         decorator).
         """
-        self.log.info("Deleting connection {}".format(self.db_conn_id))
+        self.log.info("Deleting connection %s", self.db_conn_id)
         connection = session.query(Connection).filter(
             Connection.conn_id == self.db_conn_id)[0]
         session.delete(connection)
diff --git a/tests/contrib/operators/test_gcp_sql_operator.py b/tests/contrib/operators/test_gcp_sql_operator.py
index 8c2a4aa95c..ac0251aa74 100644
--- a/tests/contrib/operators/test_gcp_sql_operator.py
+++ b/tests/contrib/operators/test_gcp_sql_operator.py
@@ -580,7 +580,8 @@ def test_create_operator_with_wrong_parameters(self,
                                                    get_connections):
         connection = Connection()
         connection.parse_from_uri(
-            "gcpcloudsql://user:password@8.8.8.8:3200/testdb?database_type={database_type}&"
+            "gcpcloudsql://user:password@8.8.8.8:3200/testdb?"
+            "database_type={database_type}&"
             "project_id={project_id}&location={location}&instance={instance_name}&"
             "use_proxy={use_proxy}&use_ssl={use_ssl}".
             format(database_type=database_type,
@@ -797,6 +798,48 @@ def test_create_operator_with_correct_parameters_mysql_tcp(self, get_connections
         self.assertNotEqual(3200, conn.port)
         self.assertEqual('testdb', conn.schema)
 
+    @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
+    def test_create_operator_with_too_long_unix_socket_path(self, get_connections):
+        connection = Connection()
+        connection.parse_from_uri(
+            "gcpcloudsql://user:password@8.8.8.8:3200/testdb?database_type=postgres&"
+            "project_id=example-project&location=europe-west1&"
+            "instance="
+            "test_db_with_long_name_a_bit_above_the_limit_of_UNIX_socket&"
+            "use_proxy=True&sql_proxy_use_tcp=False")
+        get_connections.return_value = [connection]
+        with self.assertRaises(AirflowException) as cm:
+            operator = CloudSqlQueryOperator(
+                sql=['SELECT * FROM TABLE'],
+                task_id='task_id'
+            )
+            operator.cloudsql_db_hook.create_connection()
+        err = cm.exception
+        self.assertIn("The UNIX socket path length cannot exceed", str(err))
+
+    @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
+    def test_create_operator_with_not_too_long_unix_socket_path(self, get_connections):
+        connection = Connection()
+        connection.parse_from_uri(
+            "gcpcloudsql://user:password@8.8.8.8:3200/testdb?database_type=postgres&"
+            "project_id=example-project&location=europe-west1&"
+            "instance="
+            "test_db_with_longname_but_with_limit_of_UNIX_socket_aaaa&"
+            "use_proxy=True&sql_proxy_use_tcp=False")
+        get_connections.return_value = [connection]
+        operator = CloudSqlQueryOperator(
+            sql=['SELECT * FROM TABLE'],
+            task_id='task_id'
+        )
+        operator.cloudsql_db_hook.create_connection()
+        try:
+            db_hook = operator.cloudsql_db_hook.get_database_hook()
+            conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0]
+        finally:
+            operator.cloudsql_db_hook.delete_connection()
+        self.assertEqual('postgres', conn.conn_type)
+        self.assertEqual('testdb', conn.schema)
+
 
 @unittest.skipIf(
     BaseGcpIntegrationTestCase.skip_check(GCP_CLOUDSQL_KEY), SKIP_TEST_WARNING)
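
A note on the logging changes above: they replace eager str.format() calls
with lazy %-style arguments, so the message is interpolated by the logging
framework only if the record is actually emitted. A minimal standalone sketch
of the difference (plain Python, not Airflow code):

    import logging

    logging.basicConfig(level=logging.WARNING)
    log = logging.getLogger("example")

    path = "/tmp/abc123"

    # Eager: the full message string is built even though INFO is
    # filtered out at this level.
    log.info("Creating directory {}".format(path))

    # Lazy: the %s placeholder is filled in only when the record passes
    # the level check, so this call does no formatting work at all.
    log.info("Creating directory %s", path)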
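
The download hunk also illustrates the download-to-.tmp-then-move pattern the
hook uses so that a parallel operator sharing the same fixed binary path never
executes a partially written file. A hedged sketch of the same pattern (the
function name and paths are illustrative; requests is the same third-party
library the hook itself uses):

    import os
    import shutil
    import requests

    def download_binary(url, dest_path):
        tmp_path = dest_path + ".tmp"
        r = requests.get(url, allow_redirects=True)
        if r.status_code != 200:
            raise RuntimeError("Download failed: {} {}".format(
                r.status_code, r.reason))
        # Write to a temporary name first so readers of dest_path never
        # observe a half-written binary ...
        with open(tmp_path, "wb") as f:
            f.write(r.content)
        os.chmod(tmp_path, 0o744)  # set the executable bit before publishing
        # ... then publish it; on the same filesystem this is one rename.
        shutil.move(tmp_path, dest_path)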
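
The new UNIX_PATH_MAX = 108 constant matches the size of sun_path in struct
sockaddr_un on Linux, which caps how long a UNIX domain socket path can be. A
simplified sketch of the check _validate_inputs now performs (function and
parameter names here are illustrative, not the hook's API):

    UNIX_PATH_MAX = 108  # sizeof(sockaddr_un.sun_path) on Linux

    def check_socket_path(socket_dir, project_id, instance, database,
                          database_type):
        # postgres appends its own suffix inside the socket directory
        suffix = "/.s.PGSQL.5432" if database_type == "postgres" else ""
        expected_path = "{}/{}:{}:{}{}".format(
            socket_dir, project_id, instance, database, suffix)
        if len(expected_path) > UNIX_PATH_MAX:
            raise ValueError(
                "Socket path has {} characters, over the {}-character "
                "limit: {}".format(
                    len(expected_path), UNIX_PATH_MAX, expected_path))
        return expected_path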
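
Finally, the rewritten _generate_unique_path() trades mkdtemp (whose generated
names can approach 60 characters) for a short random suffix under /tmp. Unlike
mkdtemp, the existence check is not atomic with directory creation, which the
hook tolerates because start_proxy() creates the directory itself and ignores
an existing one. A standalone sketch, reusing check_socket_path from the
previous sketch to show the headroom the short name buys:

    import os
    import random
    import string

    def generate_unique_path():
        # "/tmp/" plus 8 lowercase letters/digits is 13 characters,
        # leaving ~95 of the 108-character budget for the socket name.
        random.seed()
        while True:
            candidate = "/tmp/" + "".join(
                random.choice(string.ascii_lowercase + string.digits)
                for _ in range(8))
            if not os.path.exists(candidate):
                return candidate

    # e.g. /tmp/x3k9q2ab/example-project:testinstance:testdb/.s.PGSQL.5432
    print(check_socket_path(generate_unique_path(), "example-project",
                            "testinstance", "testdb", "postgres"))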


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services