Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/18 14:25:26 UTC

[airflow] branch master updated: Simplified GCSTaskHandler configuration (#10365)

This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 083c3c1  Simplified GCSTaskHandler configuration (#10365)
083c3c1 is described below

commit 083c3c129bc3458d410f5ff37d7f5a9a7ad548b7
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Tue Aug 18 16:24:26 2020 +0200

    Simplified GCSTaskHandler configuration (#10365)
---
 UPDATING.md                                        |  14 ++
 airflow/config_templates/airflow_local_settings.py |   6 +-
 airflow/config_templates/config.yml                |   4 +-
 airflow/config_templates/default_airflow.cfg       |   4 +-
 .../providers/google/cloud/log/gcs_task_handler.py | 115 +++++++------
 docs/howto/write-logs.rst                          |   9 +-
 .../google/cloud/log/test_gcs_task_handler.py      | 188 +++++++++++++--------
 .../log/test_stackdriver_task_handler_system.py    |   2 +-
 8 files changed, 215 insertions(+), 127 deletions(-)

diff --git a/UPDATING.md b/UPDATING.md
index 7392acc..4279b8c 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -561,6 +561,20 @@ better handle the case when a DAG file has multiple DAGs.
 Sentry is disabled by default. To enable these integrations, you need set ``sentry_on`` option
 in ``[sentry]`` section to ``"True"``.
 
+#### Simplified GCSTaskHandler configuration
+
+In previous versions, you had to create a connection entry to configure the service account key file.
+In the current version, you can set the ``google_key_path`` option in the ``[logging]`` section to
+point to the key file.
+
+Users relying on Application Default Credentials (ADC) do not need to take any action.
+
+This change aims to simplify the logging configuration and to prevent the instance configuration
+from being broken by changes to a user-controlled value - the connection entry. If you use a
+secrets backend, it also means the webserver no longer needs to connect to it. This simplifies
+setups with multiple GCP projects, because only one project needs the Secret Manager API to be
+enabled.
+
 ### Changes to the core operators/hooks
 
 We strive to ensure that there are no changes that may affect the end user and your files, but this
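
For illustration, a minimal sketch (not part of this commit) of the new configuration: the same
``google_key_path`` option can be supplied through environment variables instead of editing the
``[logging]`` section of ``airflow.cfg``; the key file path below is a placeholder.

    import os

    os.environ["AIRFLOW__LOGGING__REMOTE_LOGGING"] = "true"
    os.environ["AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER"] = "gs://my-bucket/path/to/logs"
    os.environ["AIRFLOW__LOGGING__GOOGLE_KEY_PATH"] = "/path/to/service-account-key.json"
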
diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py
index afe4aea..4b5d7f4 100644
--- a/airflow/config_templates/airflow_local_settings.py
+++ b/airflow/config_templates/airflow_local_settings.py
@@ -196,13 +196,15 @@ if REMOTE_LOGGING:
 
         DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS)
     elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'):
+        key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
         GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
             'task': {
-                'class': 'airflow.utils.log.gcs_task_handler.GCSTaskHandler',
+                'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler',
                 'formatter': 'airflow',
                 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
                 'gcs_log_folder': REMOTE_BASE_LOG_FOLDER,
                 'filename_template': FILENAME_TEMPLATE,
+                'gcp_key_path': key_path
             },
         }
 
@@ -222,7 +224,7 @@ if REMOTE_LOGGING:
 
         DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS)
     elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'):
-        key_path = conf.get('logging', 'STACKDRIVER_KEY_PATH', fallback=None)
+        key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
         # stackdriver:///airflow-tasks => airflow-tasks
         log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:]
         STACKDRIVER_REMOTE_HANDLERS = {
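
As a sketch of how a handler entry like the GCS one above is consumed (for illustration only,
not part of this commit): logging.config.dictConfig() passes every key other than ``class``,
``level``, ``formatter`` and ``filters`` as keyword arguments to the handler class, so the new
``gcp_key_path`` key reaches GCSTaskHandler.__init__ directly. The local folder and filename
template below are placeholders.

    import logging.config

    logging.config.dictConfig({
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"airflow": {"format": "%(asctime)s %(levelname)s - %(message)s"}},
        "handlers": {
            "task": {
                "class": "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler",
                "formatter": "airflow",
                "base_log_folder": "/tmp/airflow/logs",
                "gcs_log_folder": "gs://my-bucket/path/to/logs",
                "filename_template": "{try_number}.log",
                "gcp_key_path": None,  # None falls back to Application Default Credentials
            },
        },
    })
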
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 5a97736..db8182e 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -402,9 +402,9 @@
       type: string
       example: ~
       default: ""
-    - name: stackdriver_key_path
+    - name: google_key_path
       description: |
-        Path to GCP Credential JSON file. If omitted, authorization based on `the Application Default
+        Path to Google Credential JSON file. If omitted, authorization based on `the Application Default
         Credentials
         <https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
         be used.
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 89514b4..b9d644b 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -230,11 +230,11 @@ remote_logging = False
 # location.
 remote_log_conn_id =
 
-# Path to GCP Credential JSON file. If omitted, authorization based on `the Application Default
+# Path to Google Credential JSON file. If omitted, authorization based on `the Application Default
 # Credentials
 # <https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
 # be used.
-stackdriver_key_path =
+google_key_path =
 
 # Storage bucket URL for remote logging
 # S3 buckets should start with "s3://"
diff --git a/airflow/providers/google/cloud/log/gcs_task_handler.py b/airflow/providers/google/cloud/log/gcs_task_handler.py
index fbd0dd4..077282c 100644
--- a/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -16,15 +16,21 @@
 # specific language governing permissions and limitations
 # under the License.
 import os
-from urllib.parse import urlparse
+from typing import Collection, Optional
 
 from cached_property import cached_property
+from google.api_core.client_info import ClientInfo
+from google.cloud import storage
 
-from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow import version
+from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
 
+_DEFAULT_SCOPESS = frozenset([
+    "https://www.googleapis.com/auth/devstorage.read_write",
+])
+
 
 class GCSTaskHandler(FileTaskHandler, LoggingMixin):
     """
@@ -32,32 +38,64 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
     task instance logs. It extends airflow FileTaskHandler and
     uploads to and reads from GCS remote storage. Upon log reading
     failure, it reads from host machine's local disk.
+
+    :param base_log_folder: Base log folder to place logs.
+    :type base_log_folder: str
+    :param gcs_log_folder: Path to a remote location where logs will be saved. It must have the prefix
+        ``gs://``. For example: ``gs://bucket/remote/log/location``
+    :type gcs_log_folder: str
+    :param filename_template: template filename string
+    :type filename_template: str
+    :param gcp_key_path: Path to GCP Credential JSON file. Mutually exclusive with gcp_keyfile_dict.
+        If omitted, authorization based on `the Application Default Credentials
+        <https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
+        be used.
+    :type gcp_key_path: str
+    :param gcp_keyfile_dict: Dictionary of keyfile parameters. Mutually exclusive with gcp_key_path.
+    :type gcp_keyfile_dict: dict
+    :param gcp_scopes: Collection of GCP scopes to request for the credentials
+    :type gcp_scopes: Collection[str]
+    :param project_id: Project ID to use for the GCS client. If not passed, the project ID from the
+        credentials will be used.
+    :type project_id: str
     """
-    def __init__(self, base_log_folder, gcs_log_folder, filename_template):
+    def __init__(
+        self,
+        *,
+        base_log_folder: str,
+        gcs_log_folder: str,
+        filename_template: str,
+        gcp_key_path: Optional[str] = None,
+        gcp_keyfile_dict: Optional[dict] = None,
+        # See: https://github.com/PyCQA/pylint/issues/2377
+        gcp_scopes: Optional[Collection[str]] = _DEFAULT_SCOPESS,  # pylint: disable=unsubscriptable-object
+        project_id: Optional[str] = None,
+    ):
         super().__init__(base_log_folder, filename_template)
         self.remote_base = gcs_log_folder
         self.log_relative_path = ''
         self._hook = None
         self.closed = False
         self.upload_on_close = True
+        self.gcp_key_path = gcp_key_path
+        self.gcp_keyfile_dict = gcp_keyfile_dict
+        self.scopes = gcp_scopes
+        self.project_id = project_id
 
     @cached_property
-    def hook(self):
-        """
-        Returns GCS hook.
-        """
-        remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
-        try:
-            from airflow.providers.google.cloud.hooks.gcs import GCSHook
-            return GCSHook(
-                google_cloud_storage_conn_id=remote_conn_id
-            )
-        except Exception as e:  # pylint: disable=broad-except
-            self.log.error(
-                'Could not create a GoogleCloudStorageHook with connection id '
-                '"%s". %s\n\nPlease make sure that airflow[gcp] is installed '
-                'and the GCS connection exists.', remote_conn_id, str(e)
-            )
+    def client(self) -> storage.Client:
+        """Returns GCS Client."""
+        credentials, project_id = get_credentials_and_project_id(
+            key_path=self.gcp_key_path,
+            keyfile_dict=self.gcp_keyfile_dict,
+            scopes=self.scopes,
+            disable_logging=True
+        )
+        return storage.Client(
+            credentials=credentials,
+            client_info=ClientInfo(client_library_version='airflow_v' + version.version),
+            project=self.project_id if self.project_id else project_id
+        )
 
     def set_context(self, ti):
         super().set_context(ti)
@@ -111,7 +149,8 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
         remote_loc = os.path.join(self.remote_base, log_relative_path)
 
         try:
-            remote_log = self.gcs_read(remote_loc)
+            blob = storage.Blob.from_string(remote_loc, self.client)
+            remote_log = blob.download_as_string()
             log = '*** Reading remote log from {}.\n{}\n'.format(
                 remote_loc, remote_log)
             return log, {'end_of_log': True}
@@ -123,19 +162,9 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
             log += local_log
             return log, metadata
 
-    def gcs_read(self, remote_log_location):
-        """
-        Returns the log found at the remote_log_location.
-
-        :param remote_log_location: the log's location in remote storage
-        :type remote_log_location: str (path)
-        """
-        bkt, blob = self.parse_gcs_url(remote_log_location)
-        return self.hook.download(bkt, blob).decode('utf-8')
-
     def gcs_write(self, log, remote_log_location):
         """
-        Writes the log to the remote_log_location. Fails silently if no hook
+        Writes the log to the remote_log_location. Fails silently if no log
         was created.
 
         :param log: the log to write to the remote_log_location
@@ -144,28 +173,16 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
         :type remote_log_location: str (path)
         """
         try:
-            old_log = self.gcs_read(remote_log_location)
+            blob = storage.Blob.from_string(remote_log_location, self.client)
+            old_log = blob.download_as_string()
             log = '\n'.join([old_log, log]) if old_log else log
         except Exception as e:  # pylint: disable=broad-except
             if not hasattr(e, 'resp') or e.resp.get('status') != '404':  # pylint: disable=no-member
                 log = '*** Previous log discarded: {}\n\n'.format(str(e)) + log
+                self.log.info("Previous log discarded: %s", e)
 
         try:
-            bkt, blob = self.parse_gcs_url(remote_log_location)
-            self.hook.upload(bkt, blob, data=log)
+            blob = storage.Blob.from_string(remote_log_location, self.client)
+            blob.upload_from_string(log, content_type="text/plain")
         except Exception as e:  # pylint: disable=broad-except
             self.log.error('Could not write logs to %s: %s', remote_log_location, e)
-
-    @staticmethod
-    def parse_gcs_url(gsurl):
-        """
-        Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
-        tuple containing the corresponding bucket and blob.
-        """
-        parsed_url = urlparse(gsurl)
-        if not parsed_url.netloc:
-            raise AirflowException('Please provide a bucket name')
-        else:
-            bucket = parsed_url.netloc
-            blob = parsed_url.path.strip('/')
-            return bucket, blob
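
For illustration, a minimal sketch (not part of this commit) of the storage calls the handler now
uses in place of GCSHook: Blob.from_string() resolves the ``gs://`` URL, download_as_string()
fetches any previous log, and upload_from_string() writes the merged log back. It assumes working
Application Default Credentials; the bucket and object name are placeholders.

    from google.cloud import storage

    client = storage.Client()
    blob = storage.Blob.from_string("gs://my-bucket/path/to/logs/1.log", client)
    try:
        # download_as_string() returns bytes, so decode before joining with new text
        old_log = blob.download_as_string().decode("utf-8")
    except Exception:  # the blob may not exist yet on the first write
        old_log = ""
    new_log = "\n".join([old_log, "MESSAGE"]) if old_log else "MESSAGE"
    blob.upload_from_string(new_log, content_type="text/plain")
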
diff --git a/docs/howto/write-logs.rst b/docs/howto/write-logs.rst
index 3173b84..c8f858c 100644
--- a/docs/howto/write-logs.rst
+++ b/docs/howto/write-logs.rst
@@ -194,10 +194,11 @@ example:
     # configuration requirements.
     remote_logging = True
     remote_base_log_folder = gs://my-bucket/path/to/logs
-    remote_log_conn_id = MyGCSConn
 
-#. Install the ``google`` package first, like so: ``pip install 'apache-airflow[google]'``.
-#. Make sure a Google Cloud Platform connection hook has been defined in Airflow. The hook should have read and write access to the Google Cloud Storage bucket defined above in ``remote_base_log_folder``.
+#. By default, Application Default Credentials are used to obtain credentials. You can also
+   set the ``google_key_path`` option in the ``[logging]`` section if you want to use your own service account.
+#. Make sure the Google Cloud Platform account has read and write access to the Google Cloud Storage bucket defined above in ``remote_base_log_folder``.
+#. Install the ``google`` package, like so: ``pip install 'apache-airflow[google]'``.
 #. Restart the Airflow webserver and scheduler, and trigger (or wait for) a new task execution.
 #. Verify that logs are showing up for newly executed tasks in the bucket you've defined.
 #. Verify that the Google Cloud Storage viewer is working in the UI. Pull up a newly executed task, and verify that you see something like:
@@ -311,7 +312,7 @@ For integration with Stackdriver, this option should start with ``stackdriver://
 The path section of the URL specifies the name of the log e.g. ``stackdriver://airflow-tasks`` writes
 logs under the name ``airflow-tasks``.
 
-You can set ``stackdriver_key_path`` option in the ``[logging]`` section to specify the path to `the service
+You can set the ``google_key_path`` option in the ``[logging]`` section to specify the path to `the service
 account key file <https://cloud.google.com/iam/docs/service-accounts>`__.
 If omitted, authorization based on `the Application Default Credentials
 <https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py b/tests/providers/google/cloud/log/test_gcs_task_handler.py
index 2d7042b..0ac5a54 100644
--- a/tests/providers/google/cloud/log/test_gcs_task_handler.py
+++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py
@@ -24,7 +24,6 @@ from unittest import mock
 from airflow.models import TaskInstance
 from airflow.models.dag import DAG
 from airflow.operators.dummy_operator import DummyOperator
-from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler
 from airflow.utils.state import State
 from tests.test_utils.config import conf_vars
@@ -47,83 +46,113 @@ class TestGCSTaskHandler(unittest.TestCase):
         self.filename_template = "{try_number}.log"
         self.addCleanup(self.dag.clear)
         self.gcs_task_handler = GCSTaskHandler(
-            self.local_log_location, self.remote_log_base, self.filename_template
+            base_log_folder=self.local_log_location,
+            gcs_log_folder=self.remote_log_base,
+            filename_template=self.filename_template,
         )
 
     def tearDown(self) -> None:
         clear_db_runs()
         shutil.rmtree(self.local_log_location, ignore_errors=True)
 
-    def test_hook(self):
-        self.assertIsInstance(self.gcs_task_handler.hook, GCSHook)
-
-    @conf_vars({("logging", "remote_log_conn_id"): "gcs_default"})
-    @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
-    def test_hook_raises(self, mock_hook):
-        mock_hook.side_effect = Exception("Failed to connect")
-
-        with self.assertLogs(self.gcs_task_handler.log) as cm:
-            self.gcs_task_handler.hook
-
-        self.assertEqual(
-            cm.output,
-            ['ERROR:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Could '
-             'not create a GoogleCloudStorageHook with connection id "gcs_default". Failed '
-             'to connect\n'
-             '\n'
-             'Please make sure that airflow[gcp] is installed and the GCS connection '
-             'exists.']
+    @mock.patch(
+        "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+        return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+    )
+    @mock.patch("google.cloud.storage.Client")
+    def test_hook(self, mock_client, mock_creds):
+        return_value = self.gcs_task_handler.client
+        mock_client.assert_called_once_with(
+            client_info=mock.ANY, credentials="TEST_CREDENTIALS", project="TEST_PROJECT_ID"
         )
+        self.assertEqual(mock_client.return_value, return_value)
 
     @conf_vars({("logging", "remote_log_conn_id"): "gcs_default"})
-    @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
-    def test_should_read_logs_from_remote(self, mock_hook):
-        mock_hook.return_value.download.return_value = b"CONTENT"
+    @mock.patch(
+        "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+        return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+    )
+    @mock.patch("google.cloud.storage.Client")
+    @mock.patch("google.cloud.storage.Blob")
+    def test_should_read_logs_from_remote(self, mock_blob, mock_client, mock_creds):
+        mock_blob.from_string.return_value.download_as_string.return_value = "CONTENT"
 
         logs, metadata = self.gcs_task_handler._read(self.ti, self.ti.try_number)
+        mock_blob.from_string.assert_called_once_with(
+            "gs://bucket/remote/log/location/1.log", mock_client.return_value
+        )
 
-        mock_hook.return_value.download.assert_called_once_with('bucket', 'remote/log/location/1.log')
         self.assertEqual(
-            '*** Reading remote log from gs://bucket/remote/log/location/1.log.\nCONTENT\n', logs)
-        self.assertEqual({'end_of_log': True}, metadata)
+            "*** Reading remote log from gs://bucket/remote/log/location/1.log.\nCONTENT\n", logs
+        )
+        self.assertEqual({"end_of_log": True}, metadata)
 
-    @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
-    def test_should_read_from_local(self, mock_hook):
-        mock_hook.return_value.download.side_effect = Exception("Failed to connect")
+    @mock.patch(
+        "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+        return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+    )
+    @mock.patch("google.cloud.storage.Client")
+    @mock.patch("google.cloud.storage.Blob")
+    def test_should_read_from_local(self, mock_blob, mock_client, mock_creds):
+        mock_blob.from_string.return_value.download_as_string.side_effect = Exception("Failed to connect")
 
         self.gcs_task_handler.set_context(self.ti)
-        return_val = self.gcs_task_handler._read(self.ti, self.ti.try_number)
+        log, metadata = self.gcs_task_handler._read(self.ti, self.ti.try_number)
 
-        self.assertEqual(len(return_val), 2)
         self.assertEqual(
-            return_val[0],
+            log,
             "*** Unable to read remote log from gs://bucket/remote/log/location/1.log\n*** "
             f"Failed to connect\n\n*** Reading local file: {self.local_log_location}/1.log\n",
         )
-        self.assertDictEqual(return_val[1], {"end_of_log": True})
-        mock_hook.return_value.download.assert_called_once()
+        self.assertDictEqual(metadata, {"end_of_log": True})
+        mock_blob.from_string.assert_called_once_with(
+            "gs://bucket/remote/log/location/1.log", mock_client.return_value
+        )
 
-    @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
-    def test_write_to_remote_on_close(self, mock_hook):
-        mock_hook.return_value.download.return_value = b"CONTENT"
+    @mock.patch(
+        "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+        return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+    )
+    @mock.patch("google.cloud.storage.Client")
+    @mock.patch("google.cloud.storage.Blob")
+    def test_write_to_remote_on_close(self, mock_blob, mock_client, mock_creds):
+        mock_blob.from_string.return_value.download_as_string.return_value = "CONTENT"
 
         self.gcs_task_handler.set_context(self.ti)
-        self.gcs_task_handler.emit(logging.LogRecord(
-            name="NAME", level="DEBUG", pathname=None, lineno=None,
-            msg="MESSAGE", args=None, exc_info=None
-        ))
+        self.gcs_task_handler.emit(
+            logging.LogRecord(
+                name="NAME",
+                level="DEBUG",
+                pathname=None,
+                lineno=None,
+                msg="MESSAGE",
+                args=None,
+                exc_info=None,
+            )
+        )
         self.gcs_task_handler.close()
 
-        mock_hook.return_value.download.assert_called_once_with('bucket', 'remote/log/location/1.log')
-        mock_hook.return_value.upload.assert_called_once_with(
-            'bucket', 'remote/log/location/1.log', data='CONTENT\nMESSAGE\n'
+        mock_blob.assert_has_calls(
+            [
+                mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+                mock.call.from_string().download_as_string(),
+                mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+                mock.call.from_string().upload_from_string("CONTENT\nMESSAGE\n", content_type="text/plain"),
+            ],
+            any_order=False,
         )
+        mock_blob.from_string.return_value.upload_from_string(data="CONTENT\nMESSAGE\n")
         self.assertEqual(self.gcs_task_handler.closed, True)
 
-    @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
-    def test_failed_write_to_remote_on_close(self, mock_hook):
-        mock_hook.return_value.upload.side_effect = Exception("Failed to connect")
-        mock_hook.return_value.download.return_value = b"Old log"
+    @mock.patch(
+        "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+        return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+    )
+    @mock.patch("google.cloud.storage.Client")
+    @mock.patch("google.cloud.storage.Blob")
+    def test_failed_write_to_remote_on_close(self, mock_blob, mock_client, mock_creds):
+        mock_blob.from_string.return_value.upload_from_string.side_effect = Exception("Failed to connect")
+        mock_blob.from_string.return_value.download_as_string.return_value = b"Old log"
 
         self.gcs_task_handler.set_context(self.ti)
         with self.assertLogs(self.gcs_task_handler.log) as cm:
@@ -132,31 +161,56 @@ class TestGCSTaskHandler(unittest.TestCase):
         self.assertEqual(
             cm.output,
             [
+                'INFO:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Previous '
+                'log discarded: sequence item 0: expected str instance, bytes found',
                 'ERROR:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Could '
                 'not write logs to gs://bucket/remote/log/location/1.log: Failed to connect'
-            ]
-        )
-        mock_hook.return_value.download.assert_called_once_with(
-            'bucket', 'remote/log/location/1.log'
+            ],
         )
-        mock_hook.return_value.upload.assert_called_once_with(
-            'bucket', 'remote/log/location/1.log', data='Old log\n'
+        mock_blob.assert_has_calls(
+            [
+                mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+                mock.call.from_string().download_as_string(),
+                mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+                mock.call.from_string().upload_from_string(
+                    "*** Previous log discarded: sequence item 0: expected str instance, bytes found\n\n",
+                    content_type="text/plain",
+                ),
+            ],
+            any_order=False,
         )
 
-    @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
-    def test_write_to_remote_on_close_failed_read_old_logs(self, mock_hook):
-        mock_hook.return_value.download.side_effect = Exception("Fail to download")
+    @mock.patch(
+        "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+        return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+    )
+    @mock.patch("google.cloud.storage.Client")
+    @mock.patch("google.cloud.storage.Blob")
+    def test_write_to_remote_on_close_failed_read_old_logs(self, mock_blob, mock_client, mock_creds):
+        mock_blob.from_string.return_value.download_as_string.side_effect = Exception("Fail to download")
 
         self.gcs_task_handler.set_context(self.ti)
-        self.gcs_task_handler.emit(logging.LogRecord(
-            name="NAME", level="DEBUG", pathname=None, lineno=None,
-            msg="MESSAGE", args=None, exc_info=None
-        ))
+        self.gcs_task_handler.emit(
+            logging.LogRecord(
+                name="NAME",
+                level="DEBUG",
+                pathname=None,
+                lineno=None,
+                msg="MESSAGE",
+                args=None,
+                exc_info=None,
+            )
+        )
         self.gcs_task_handler.close()
 
-        mock_hook.return_value.download.assert_called_once_with('bucket', 'remote/log/location/1.log')
-        mock_hook.return_value.upload.assert_called_once_with(
-            'bucket', 'remote/log/location/1.log',
-            data='*** Previous log discarded: Fail to download\n\nMESSAGE\n'
+        mock_blob.assert_has_calls(
+            [
+                mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+                mock.call.from_string().download_as_string(),
+                mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+                mock.call.from_string().upload_from_string(
+                    "*** Previous log discarded: Fail to download\n\nMESSAGE\n", content_type="text/plain"
+                ),
+            ],
+            any_order=False,
         )
-        self.assertEqual(self.gcs_task_handler.closed, True)
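
A condensed sketch (not part of this commit) of the mocking pattern the updated tests rely on: the
credentials provider and the GCS client are patched so the handler can be exercised without real
GCP access; the local log folder is a placeholder.

    from unittest import mock

    from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler


    @mock.patch(
        "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
        return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
    )
    @mock.patch("google.cloud.storage.Client")
    def check_client(mock_client, mock_creds):
        handler = GCSTaskHandler(
            base_log_folder="/tmp/airflow/logs",
            gcs_log_folder="gs://bucket/remote/log/location",
            filename_template="{try_number}.log",
        )
        # The cached client is built from the patched credentials and project id
        assert handler.client is mock_client.return_value
        mock_client.assert_called_once_with(
            client_info=mock.ANY, credentials="TEST_CREDENTIALS", project="TEST_PROJECT_ID"
        )


    check_client()
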
diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
index 5687746..d35ffd9 100644
--- a/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
+++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
@@ -54,7 +54,7 @@ class TestStackdriverLoggingHandlerSystemTest(unittest.TestCase):
             'os.environ',
             AIRFLOW__LOGGING__REMOTE_LOGGING="true",
             AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER=f"stackdriver://{self.log_name}",
-            AIRFLOW__LOGGING__STACKDRIVER_KEY_PATH=resolve_full_gcp_key_path(GCP_STACKDDRIVER),
+            AIRFLOW__LOGGING__GOOGLE_KEY_PATH=resolve_full_gcp_key_path(GCP_STACKDDRIVER),
             AIRFLOW__CORE__LOAD_EXAMPLES="false",
             AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__
         ):