Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/18 14:25:26 UTC
[airflow] branch master updated: Simplified GCSTaskHandler configuration (#10365)
This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 083c3c1 Simplified GCSTaskHandler configuration (#10365)
083c3c1 is described below
commit 083c3c129bc3458d410f5ff37d7f5a9a7ad548b7
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Tue Aug 18 16:24:26 2020 +0200
Simplified GCSTaskHandler configuration (#10365)
---
UPDATING.md | 14 ++
airflow/config_templates/airflow_local_settings.py | 6 +-
airflow/config_templates/config.yml | 4 +-
airflow/config_templates/default_airflow.cfg | 4 +-
.../providers/google/cloud/log/gcs_task_handler.py | 115 +++++++------
docs/howto/write-logs.rst | 9 +-
.../google/cloud/log/test_gcs_task_handler.py | 188 +++++++++++++--------
.../log/test_stackdriver_task_handler_system.py | 2 +-
8 files changed, 215 insertions(+), 127 deletions(-)
diff --git a/UPDATING.md b/UPDATING.md
index 7392acc..4279b8c 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -561,6 +561,20 @@ better handle the case when a DAG file has multiple DAGs.
Sentry is disabled by default. To enable these integrations, you need to set the ``sentry_on`` option
in the ``[sentry]`` section to ``"True"``.
+#### Simplified GCSTaskHandler configuration
+
+In previous versions, in order to configure the service account key file, you had to create a connection entry.
+In the current version, you can configure the ``google_key_path`` option in the ``[logging]`` section to set
+the key file path.
+
+Users using Application Default Credentials (ADC) need not take any action.
+
+The change aims to simplify the logging configuration and to prevent the instance configuration from being
+corrupted by a change to a user-controlled value - the connection entry. If you configure a secrets backend,
+the webserver also no longer needs to connect to it. This simplifies setups with multiple GCP projects,
+because only one project will require the Secret Manager API to be enabled.
+
### Changes to the core operators/hooks
We strive to ensure that there are no changes that may affect the end user and your files, but this
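For illustration, the new option slots into the existing ``[logging]`` section of airflow.cfg. A minimal
sketch (bucket name and key path are placeholders):

    [logging]
    remote_logging = True
    remote_base_log_folder = gs://my-bucket/path/to/logs
    # Omit google_key_path to fall back to Application Default Credentials
    google_key_path = /files/keys/service-account.json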
diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py
index afe4aea..4b5d7f4 100644
--- a/airflow/config_templates/airflow_local_settings.py
+++ b/airflow/config_templates/airflow_local_settings.py
@@ -196,13 +196,15 @@ if REMOTE_LOGGING:
DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS)
elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'):
+ key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
'task': {
- 'class': 'airflow.utils.log.gcs_task_handler.GCSTaskHandler',
+ 'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler',
'formatter': 'airflow',
'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
'gcs_log_folder': REMOTE_BASE_LOG_FOLDER,
'filename_template': FILENAME_TEMPLATE,
+ 'gcp_key_path': key_path
},
}
@@ -222,7 +224,7 @@ if REMOTE_LOGGING:
DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS)
elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'):
- key_path = conf.get('logging', 'STACKDRIVER_KEY_PATH', fallback=None)
+ key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
# stackdriver:///airflow-tasks => airflow-tasks
log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:]
STACKDRIVER_REMOTE_HANDLERS = {
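A note on how the handler dictionary above takes effect: ``logging.config.dictConfig`` passes every
non-reserved key of a handler entry as a keyword argument to the handler class, so the new
``gcp_key_path`` key lands directly in ``GCSTaskHandler.__init__``. A minimal sketch, assuming the
google provider is installed (paths and bucket name are placeholders):

    import logging.config

    LOGGING_CONFIG = {
        "version": 1,
        "handlers": {
            "task": {
                "class": "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler",
                "base_log_folder": "/tmp/logs",            # placeholder local path
                "gcs_log_folder": "gs://my-bucket/logs",   # placeholder bucket
                "filename_template": "{try_number}.log",
                "gcp_key_path": None,                      # None -> Application Default Credentials
            },
        },
    }
    logging.config.dictConfig(LOGGING_CONFIG)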
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 5a97736..db8182e 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -402,9 +402,9 @@
type: string
example: ~
default: ""
- - name: stackdriver_key_path
+ - name: google_key_path
description: |
- Path to GCP Credential JSON file. If omitted, authorization based on `the Application Default
+ Path to Google Credential JSON file. If omitted, authorization based on `the Application Default
Credentials
<https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
be used.
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 89514b4..b9d644b 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -230,11 +230,11 @@ remote_logging = False
# location.
remote_log_conn_id =
-# Path to GCP Credential JSON file. If omitted, authorization based on `the Application Default
+# Path to Google Credential JSON file. If omitted, authorization based on `the Application Default
# Credentials
# <https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
# be used.
-stackdriver_key_path =
+google_key_path =
# Storage bucket URL for remote logging
# S3 buckets should start with "s3://"
diff --git a/airflow/providers/google/cloud/log/gcs_task_handler.py b/airflow/providers/google/cloud/log/gcs_task_handler.py
index fbd0dd4..077282c 100644
--- a/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -16,15 +16,21 @@
# specific language governing permissions and limitations
# under the License.
import os
-from urllib.parse import urlparse
+from typing import Collection, Optional
from cached_property import cached_property
+from google.api_core.client_info import ClientInfo
+from google.cloud import storage
-from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow import version
+from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import LoggingMixin
+_DEFAULT_SCOPESS = frozenset([
+ "https://www.googleapis.com/auth/devstorage.read_write",
+])
+
class GCSTaskHandler(FileTaskHandler, LoggingMixin):
"""
@@ -32,32 +38,64 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
task instance logs. It extends airflow FileTaskHandler and
uploads to and reads from GCS remote storage. Upon log reading
failure, it reads from host machine's local disk.
+
+ :param base_log_folder: Base log folder to place logs.
+ :type base_log_folder: str
+ :param gcs_log_folder: Path to a remote location where logs will be saved. It must have the prefix
+ ``gs://``. For example: ``gs://bucket/remote/log/location``
+ :type gcs_log_folder: str
+ :param filename_template: template filename string
+ :type filename_template: str
+ :param gcp_key_path: Path to GCP Credential JSON file. Mutually exclusive with gcp_keyfile_dict.
+ If omitted, authorization based on `the Application Default Credentials
+ <https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
+ be used.
+ :type gcp_key_path: str
+ :param gcp_keyfile_dict: Dictionary of keyfile parameters. Mutually exclusive with gcp_key_path.
+ :type gcp_keyfile_dict: dict
+ :param gcp_scopes: Collection of GCP scopes to use for the credentials
+ :type gcp_scopes: Collection[str]
+ :param project_id: Project ID of the Google Cloud project to use. If not passed, the project ID from
+ credentials will be used.
+ :type project_id: str
"""
- def __init__(self, base_log_folder, gcs_log_folder, filename_template):
+ def __init__(
+ self,
+ *,
+ base_log_folder: str,
+ gcs_log_folder: str,
+ filename_template: str,
+ gcp_key_path: Optional[str] = None,
+ gcp_keyfile_dict: Optional[dict] = None,
+ # See: https://github.com/PyCQA/pylint/issues/2377
+ gcp_scopes: Optional[Collection[str]] = _DEFAULT_SCOPESS, # pylint: disable=unsubscriptable-object
+ project_id: Optional[str] = None,
+ ):
super().__init__(base_log_folder, filename_template)
self.remote_base = gcs_log_folder
self.log_relative_path = ''
self._hook = None
self.closed = False
self.upload_on_close = True
+ self.gcp_key_path = gcp_key_path
+ self.gcp_keyfile_dict = gcp_keyfile_dict
+ self.scopes = gcp_scopes
+ self.project_id = project_id
@cached_property
- def hook(self):
- """
- Returns GCS hook.
- """
- remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
- try:
- from airflow.providers.google.cloud.hooks.gcs import GCSHook
- return GCSHook(
- google_cloud_storage_conn_id=remote_conn_id
- )
- except Exception as e: # pylint: disable=broad-except
- self.log.error(
- 'Could not create a GoogleCloudStorageHook with connection id '
- '"%s". %s\n\nPlease make sure that airflow[gcp] is installed '
- 'and the GCS connection exists.', remote_conn_id, str(e)
- )
+ def client(self) -> storage.Client:
+ """Returns GCS Client."""
+ credentials, project_id = get_credentials_and_project_id(
+ key_path=self.gcp_key_path,
+ keyfile_dict=self.gcp_keyfile_dict,
+ scopes=self.scopes,
+ disable_logging=True
+ )
+ return storage.Client(
+ credentials=credentials,
+ client_info=ClientInfo(client_library_version='airflow_v' + version.version),
+ project=self.project_id if self.project_id else project_id
+ )
def set_context(self, ti):
super().set_context(ti)
@@ -111,7 +149,8 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
remote_loc = os.path.join(self.remote_base, log_relative_path)
try:
- remote_log = self.gcs_read(remote_loc)
+ blob = storage.Blob.from_string(remote_loc, self.client)
+ remote_log = blob.download_as_string()
log = '*** Reading remote log from {}.\n{}\n'.format(
remote_loc, remote_log)
return log, {'end_of_log': True}
@@ -123,19 +162,9 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
log += local_log
return log, metadata
- def gcs_read(self, remote_log_location):
- """
- Returns the log found at the remote_log_location.
-
- :param remote_log_location: the log's location in remote storage
- :type remote_log_location: str (path)
- """
- bkt, blob = self.parse_gcs_url(remote_log_location)
- return self.hook.download(bkt, blob).decode('utf-8')
-
def gcs_write(self, log, remote_log_location):
"""
- Writes the log to the remote_log_location. Fails silently if no hook
+ Writes the log to the remote_log_location. Fails silently if no log
was created.
:param log: the log to write to the remote_log_location
@@ -144,28 +173,16 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
:type remote_log_location: str (path)
"""
try:
- old_log = self.gcs_read(remote_log_location)
+ blob = storage.Blob.from_string(remote_log_location, self.client)
+ old_log = blob.download_as_string()
log = '\n'.join([old_log, log]) if old_log else log
except Exception as e: # pylint: disable=broad-except
if not hasattr(e, 'resp') or e.resp.get('status') != '404': # pylint: disable=no-member
log = '*** Previous log discarded: {}\n\n'.format(str(e)) + log
+ self.log.info("Previous log discarded: %s", e)
try:
- bkt, blob = self.parse_gcs_url(remote_log_location)
- self.hook.upload(bkt, blob, data=log)
+ blob = storage.Blob.from_string(remote_log_location, self.client)
+ blob.upload_from_string(log, content_type="text/plain")
except Exception as e: # pylint: disable=broad-except
self.log.error('Could not write logs to %s: %s', remote_log_location, e)
-
- @staticmethod
- def parse_gcs_url(gsurl):
- """
- Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
- tuple containing the corresponding bucket and blob.
- """
- parsed_url = urlparse(gsurl)
- if not parsed_url.netloc:
- raise AirflowException('Please provide a bucket name')
- else:
- bucket = parsed_url.netloc
- blob = parsed_url.path.strip('/')
- return bucket, blob
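The handler now talks to GCS through the google-cloud-storage client directly instead of a GCSHook. A
standalone sketch of the new read-append-write path (bucket and object names are placeholders;
credentials come from Application Default Credentials here):

    from google.cloud import storage

    client = storage.Client()
    blob = storage.Blob.from_string("gs://my-bucket/remote/log/location/1.log", client)
    try:
        # download_as_string() returns bytes, hence the decode
        old_log = blob.download_as_string().decode("utf-8")
    except Exception:  # e.g. 404 when the object does not exist yet
        old_log = ""
    log = "\n".join(filter(None, [old_log, "*** new log line"]))
    blob.upload_from_string(log, content_type="text/plain")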
diff --git a/docs/howto/write-logs.rst b/docs/howto/write-logs.rst
index 3173b84..c8f858c 100644
--- a/docs/howto/write-logs.rst
+++ b/docs/howto/write-logs.rst
@@ -194,10 +194,11 @@ example:
# configuration requirements.
remote_logging = True
remote_base_log_folder = gs://my-bucket/path/to/logs
- remote_log_conn_id = MyGCSConn
-#. Install the ``google`` package first, like so: ``pip install 'apache-airflow[google]'``.
-#. Make sure a Google Cloud Platform connection hook has been defined in Airflow. The hook should have read and write access to the Google Cloud Storage bucket defined above in ``remote_base_log_folder``.
+#. By default, Application Default Credentials are used to obtain credentials. You can also
+ set the ``google_key_path`` option in the ``[logging]`` section if you want to use your own service account.
+#. Make sure the Google Cloud Platform account has read and write access to the Google Cloud Storage bucket defined above in ``remote_base_log_folder``.
+#. Install the ``google`` package, like so: ``pip install 'apache-airflow[google]'``.
#. Restart the Airflow webserver and scheduler, and trigger (or wait for) a new task execution.
#. Verify that logs are showing up for newly executed tasks in the bucket you've defined.
#. Verify that the Google Cloud Storage viewer is working in the UI. Pull up a newly executed task, and verify that you see something like:
@@ -311,7 +312,7 @@ For integration with Stackdriver, this option should start with ``stackdriver://
The path section of the URL specifies the name of the log e.g. ``stackdriver://airflow-tasks`` writes
logs under the name ``airflow-tasks``.
-You can set ``stackdriver_key_path`` option in the ``[logging]`` section to specify the path to `the service
+You can set ``google_key_path`` option in the ``[logging]`` section to specify the path to `the service
account key file <https://cloud.google.com/iam/docs/service-accounts>`__.
If omitted, authorization based on `the Application Default Credentials
<https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
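For the Stackdriver case the same key applies; an illustrative ``[logging]`` section (log name and key
path are placeholders):

    [logging]
    remote_logging = True
    remote_base_log_folder = stackdriver://airflow-tasks
    # One option now covers both the GCS and the Stackdriver handler
    google_key_path = /files/keys/service-account.json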
diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py b/tests/providers/google/cloud/log/test_gcs_task_handler.py
index 2d7042b..0ac5a54 100644
--- a/tests/providers/google/cloud/log/test_gcs_task_handler.py
+++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py
@@ -24,7 +24,6 @@ from unittest import mock
from airflow.models import TaskInstance
from airflow.models.dag import DAG
from airflow.operators.dummy_operator import DummyOperator
-from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.log.gcs_task_handler import GCSTaskHandler
from airflow.utils.state import State
from tests.test_utils.config import conf_vars
@@ -47,83 +46,113 @@ class TestGCSTaskHandler(unittest.TestCase):
self.filename_template = "{try_number}.log"
self.addCleanup(self.dag.clear)
self.gcs_task_handler = GCSTaskHandler(
- self.local_log_location, self.remote_log_base, self.filename_template
+ base_log_folder=self.local_log_location,
+ gcs_log_folder=self.remote_log_base,
+ filename_template=self.filename_template,
)
def tearDown(self) -> None:
clear_db_runs()
shutil.rmtree(self.local_log_location, ignore_errors=True)
- def test_hook(self):
- self.assertIsInstance(self.gcs_task_handler.hook, GCSHook)
-
- @conf_vars({("logging", "remote_log_conn_id"): "gcs_default"})
- @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
- def test_hook_raises(self, mock_hook):
- mock_hook.side_effect = Exception("Failed to connect")
-
- with self.assertLogs(self.gcs_task_handler.log) as cm:
- self.gcs_task_handler.hook
-
- self.assertEqual(
- cm.output,
- ['ERROR:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Could '
- 'not create a GoogleCloudStorageHook with connection id "gcs_default". Failed '
- 'to connect\n'
- '\n'
- 'Please make sure that airflow[gcp] is installed and the GCS connection '
- 'exists.']
+ @mock.patch(
+ "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+ return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+ )
+ @mock.patch("google.cloud.storage.Client")
+ def test_hook(self, mock_client, mock_creds):
+ return_value = self.gcs_task_handler.client
+ mock_client.assert_called_once_with(
+ client_info=mock.ANY, credentials="TEST_CREDENTIALS", project="TEST_PROJECT_ID"
)
+ self.assertEqual(mock_client.return_value, return_value)
@conf_vars({("logging", "remote_log_conn_id"): "gcs_default"})
- @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
- def test_should_read_logs_from_remote(self, mock_hook):
- mock_hook.return_value.download.return_value = b"CONTENT"
+ @mock.patch(
+ "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+ return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+ )
+ @mock.patch("google.cloud.storage.Client")
+ @mock.patch("google.cloud.storage.Blob")
+ def test_should_read_logs_from_remote(self, mock_blob, mock_client, mock_creds):
+ mock_blob.from_string.return_value.download_as_string.return_value = "CONTENT"
logs, metadata = self.gcs_task_handler._read(self.ti, self.ti.try_number)
+ mock_blob.from_string.assert_called_once_with(
+ "gs://bucket/remote/log/location/1.log", mock_client.return_value
+ )
- mock_hook.return_value.download.assert_called_once_with('bucket', 'remote/log/location/1.log')
self.assertEqual(
- '*** Reading remote log from gs://bucket/remote/log/location/1.log.\nCONTENT\n', logs)
- self.assertEqual({'end_of_log': True}, metadata)
+ "*** Reading remote log from gs://bucket/remote/log/location/1.log.\nCONTENT\n", logs
+ )
+ self.assertEqual({"end_of_log": True}, metadata)
- @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
- def test_should_read_from_local(self, mock_hook):
- mock_hook.return_value.download.side_effect = Exception("Failed to connect")
+ @mock.patch(
+ "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+ return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+ )
+ @mock.patch("google.cloud.storage.Client")
+ @mock.patch("google.cloud.storage.Blob")
+ def test_should_read_from_local(self, mock_blob, mock_client, mock_creds):
+ mock_blob.from_string.return_value.download_as_string.side_effect = Exception("Failed to connect")
self.gcs_task_handler.set_context(self.ti)
- return_val = self.gcs_task_handler._read(self.ti, self.ti.try_number)
+ log, metadata = self.gcs_task_handler._read(self.ti, self.ti.try_number)
- self.assertEqual(len(return_val), 2)
self.assertEqual(
- return_val[0],
+ log,
"*** Unable to read remote log from gs://bucket/remote/log/location/1.log\n*** "
f"Failed to connect\n\n*** Reading local file: {self.local_log_location}/1.log\n",
)
- self.assertDictEqual(return_val[1], {"end_of_log": True})
- mock_hook.return_value.download.assert_called_once()
+ self.assertDictEqual(metadata, {"end_of_log": True})
+ mock_blob.from_string.assert_called_once_with(
+ "gs://bucket/remote/log/location/1.log", mock_client.return_value
+ )
- @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
- def test_write_to_remote_on_close(self, mock_hook):
- mock_hook.return_value.download.return_value = b"CONTENT"
+ @mock.patch(
+ "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+ return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+ )
+ @mock.patch("google.cloud.storage.Client")
+ @mock.patch("google.cloud.storage.Blob")
+ def test_write_to_remote_on_close(self, mock_blob, mock_client, mock_creds):
+ mock_blob.from_string.return_value.download_as_string.return_value = "CONTENT"
self.gcs_task_handler.set_context(self.ti)
- self.gcs_task_handler.emit(logging.LogRecord(
- name="NAME", level="DEBUG", pathname=None, lineno=None,
- msg="MESSAGE", args=None, exc_info=None
- ))
+ self.gcs_task_handler.emit(
+ logging.LogRecord(
+ name="NAME",
+ level="DEBUG",
+ pathname=None,
+ lineno=None,
+ msg="MESSAGE",
+ args=None,
+ exc_info=None,
+ )
+ )
self.gcs_task_handler.close()
- mock_hook.return_value.download.assert_called_once_with('bucket', 'remote/log/location/1.log')
- mock_hook.return_value.upload.assert_called_once_with(
- 'bucket', 'remote/log/location/1.log', data='CONTENT\nMESSAGE\n'
+ mock_blob.assert_has_calls(
+ [
+ mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+ mock.call.from_string().download_as_string(),
+ mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+ mock.call.from_string().upload_from_string("CONTENT\nMESSAGE\n", content_type="text/plain"),
+ ],
+ any_order=False,
)
self.assertEqual(self.gcs_task_handler.closed, True)
- @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
- def test_failed_write_to_remote_on_close(self, mock_hook):
- mock_hook.return_value.upload.side_effect = Exception("Failed to connect")
- mock_hook.return_value.download.return_value = b"Old log"
+ @mock.patch(
+ "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+ return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+ )
+ @mock.patch("google.cloud.storage.Client")
+ @mock.patch("google.cloud.storage.Blob")
+ def test_failed_write_to_remote_on_close(self, mock_blob, mock_client, mock_creds):
+ mock_blob.from_string.return_value.upload_from_string.side_effect = Exception("Failed to connect")
+ mock_blob.from_string.return_value.download_as_string.return_value = b"Old log"
self.gcs_task_handler.set_context(self.ti)
with self.assertLogs(self.gcs_task_handler.log) as cm:
@@ -132,31 +161,56 @@ class TestGCSTaskHandler(unittest.TestCase):
self.assertEqual(
cm.output,
[
+ 'INFO:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Previous '
+ 'log discarded: sequence item 0: expected str instance, bytes found',
'ERROR:airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler:Could '
'not write logs to gs://bucket/remote/log/location/1.log: Failed to connect'
- ]
- )
- mock_hook.return_value.download.assert_called_once_with(
- 'bucket', 'remote/log/location/1.log'
+ ],
)
- mock_hook.return_value.upload.assert_called_once_with(
- 'bucket', 'remote/log/location/1.log', data='Old log\n'
+ mock_blob.assert_has_calls(
+ [
+ mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+ mock.call.from_string().download_as_string(),
+ mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+ mock.call.from_string().upload_from_string(
+ "*** Previous log discarded: sequence item 0: expected str instance, bytes found\n\n",
+ content_type="text/plain",
+ ),
+ ],
+ any_order=False,
)
- @mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook")
- def test_write_to_remote_on_close_failed_read_old_logs(self, mock_hook):
- mock_hook.return_value.download.side_effect = Exception("Fail to download")
+ @mock.patch(
+ "airflow.providers.google.cloud.log.gcs_task_handler.get_credentials_and_project_id",
+ return_value=("TEST_CREDENTIALS", "TEST_PROJECT_ID"),
+ )
+ @mock.patch("google.cloud.storage.Client")
+ @mock.patch("google.cloud.storage.Blob")
+ def test_write_to_remote_on_close_failed_read_old_logs(self, mock_blob, mock_client, mock_creds):
+ mock_blob.from_string.return_value.download_as_string.side_effect = Exception("Fail to download")
self.gcs_task_handler.set_context(self.ti)
- self.gcs_task_handler.emit(logging.LogRecord(
- name="NAME", level="DEBUG", pathname=None, lineno=None,
- msg="MESSAGE", args=None, exc_info=None
- ))
+ self.gcs_task_handler.emit(
+ logging.LogRecord(
+ name="NAME",
+ level="DEBUG",
+ pathname=None,
+ lineno=None,
+ msg="MESSAGE",
+ args=None,
+ exc_info=None,
+ )
+ )
self.gcs_task_handler.close()
- mock_hook.return_value.download.assert_called_once_with('bucket', 'remote/log/location/1.log')
- mock_hook.return_value.upload.assert_called_once_with(
- 'bucket', 'remote/log/location/1.log',
- data='*** Previous log discarded: Fail to download\n\nMESSAGE\n'
+ mock_blob.assert_has_calls(
+ [
+ mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+ mock.call.from_string().download_as_string(),
+ mock.call.from_string("gs://bucket/remote/log/location/1.log", mock_client.return_value),
+ mock.call.from_string().upload_from_string(
+ "*** Previous log discarded: Fail to download\n\nMESSAGE\n", content_type="text/plain"
+ ),
+ ],
+ any_order=False,
)
- self.assertEqual(self.gcs_task_handler.closed, True)
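The ``assert_has_calls`` blocks above lean on unittest.mock call chaining: a MagicMock records calls
made on its return values, so ``mock.call.from_string().download_as_string()`` matches a call chained
off ``from_string(...)``. A tiny self-contained illustration, independent of Airflow:

    from unittest import mock

    blob_cls = mock.MagicMock()
    blob_cls.from_string("gs://bucket/obj", "client").download_as_string()
    blob_cls.assert_has_calls([
        mock.call.from_string("gs://bucket/obj", "client"),
        mock.call.from_string().download_as_string(),
    ])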
diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
index 5687746..d35ffd9 100644
--- a/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
+++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py
@@ -54,7 +54,7 @@ class TestStackdriverLoggingHandlerSystemTest(unittest.TestCase):
'os.environ',
AIRFLOW__LOGGING__REMOTE_LOGGING="true",
AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER=f"stackdriver://{self.log_name}",
- AIRFLOW__LOGGING__STACKDRIVER_KEY_PATH=resolve_full_gcp_key_path(GCP_STACKDDRIVER),
+ AIRFLOW__LOGGING__GOOGLE_KEY_PATH=resolve_full_gcp_key_path(GCP_STACKDDRIVER),
AIRFLOW__CORE__LOAD_EXAMPLES="false",
AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__
):
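As the patched environment above suggests, any airflow.cfg option can also be supplied via an
``AIRFLOW__{SECTION}__{KEY}`` environment variable, which takes precedence over the config file. A
minimal sketch of the equivalent override outside of tests (key path is a placeholder):

    import os

    os.environ["AIRFLOW__LOGGING__REMOTE_LOGGING"] = "true"
    os.environ["AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER"] = "stackdriver://airflow-tasks"
    os.environ["AIRFLOW__LOGGING__GOOGLE_KEY_PATH"] = "/files/keys/service-account.json"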